# Copyright 2015-2023 VMware, Inc.
# All rights reserved. -- VMware Confidential

import sys
import os
import json
import platform
import logging
import socket
import vcsa_utils
import uuid
import subprocess
import time
import os_utils

from . import utils
from extensions import extend, Hook
from patch_specs import DiscoveryResult, ValidationResult, Question, Mismatch, \
    Requirements, PatchInfo, RequirementsResult
from l10n import msgMetadata as _T, localizedString as _
from reporting import getProgressReporter
from fss_utils import getTargetFSS  # pylint: disable=E0401

log = logging.getLogger(__name__)
VMAFD_PAYLOAD_DIR = os.path.dirname(__file__)

# Import patches python directory where all individual python module reside
sys.path.append(os.path.join(VMAFD_PAYLOAD_DIR, 'patches'))
patches = [
]

NDU_LIMITED_DOWNTIME_FSS = "NDU_Limited_Downtime"

VMAFD_ENABLE_REPOINT_REG_VALUE = '1'
VMAFD_DISABLE_REPOINT_REG_VALUE = '0'

def invoke_cmd(cmd):

    p = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE)
    (output, error) = p.communicate()
    if error:
        print (error)
        return error
    else:
        return output

def getLDUInstallParam():
    cmd = b'install-parameter vmdir.ldu-guid'
    output = invoke_cmd(cmd)
    output = output.rstrip()
    return output

def setLDUInstallParam(guid):
    cmd = b'install-parameter vmdir.ldu-guid -s '
    cmd += guid
    output = invoke_cmd(cmd)
    log.info('set install parameter output: %s' % output)
    return output

def setLDURegistry(guid):
    vmdirKey = '[HKEY_THIS_MACHINE\\Services\\vmdir]'
    vmdirVal = 'LduGuid'
    vmafdKey = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters]'
    vmafdVal = 'LDU'

    utils = getUtil()

    ret = utils.setRegValue(vmdirKey, vmdirVal, guid)
    log.info('Set registry value in vmdir')
    if ret != 0:
        return ret

    ret = utils.setRegValue(vmafdKey, vmafdVal, guid)
    log.info('Set registry value in vmafd')

    return ret

def setLDU():
    guid = getLDUInstallParam()

    if (guid == None or not guid):
        log.info('There is no LDU-guid - making a new one')
        guid = str(uuid.uuid1())
        log.info('Setting LDU %s' % guid)
        setLDUInstallParam(guid)
    else:
        log.info('LDU-guid already exists: %s' % guid)
    log.info('Setting LDU to VMAFD')

    result = setLDURegistry(guid)
    if (result != 0):
        raise Exception('Setting the LDU Failed')
    return

def getUtil():
    return utils.PatchUtilsLin()

def getCurrentVersion():
    return getUtil().getSourceVersion()

def getLatestVersion():
    if len(patches) > 0:
        return patches[-1][0]
    return '6.5.0.0'

def getChangesSummary():
    latestPatchScript = patches[-1][1]
    mod = __import__(latestPatchScript)
    changesSummary = None
    if hasattr(mod, 'getChanges'):
        changesSummary = mod.getChanges()

    return changesSummary

'''
Wait for DomainID / DC value before patch finishes
These values are expected to be updated by CacheDCThread of VMAFD
'''
def wait_for_dc_reg_val_update():
    retryCount = 0
    domainGUID = ''
    affinitizedDCValue = ''
    maxRetry = 300       # 5 Minutes (1 seconds per try)

    utils = getUtil()
    vmafdParamsRegKey = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters]'
    vmafdDomainsRegKey = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters\\Domains]'

    while retryCount < maxRetry:
        retryCount = retryCount + 1
        try:
            domainGUID = utils.getRegValue(vmafdParamsRegKey, 'DomainGUID')
            if domainGUID is not None:
                vmafdDomainGUIDRegKey = vmafdDomainsRegKey.strip("]") + "\\" + domainGUID + "]"
                affinitizedDCValue = utils.getRegValue(vmafdDomainGUIDRegKey, 'AffinitizedDC')
                if affinitizedDCValue is not None:
                    log.info("Affinitized DC value set to %s" % affinitizedDCValue)
                    return
                else:
                    log.info("Affinitized DC value is not set yet. Retrying...")
            else:
                log.info("Domain ID value is not set yet. Retrying...")

        except Exception as e:
            log.info("Failed to read DomainID/Affinitized DC reg value. Retrying...")

        if retryCount >= maxRetry:
            log.exception("Reached maximum attempt to read Domain ID/Affinitized DC value from registry. Giving up...")
            raise Exception("Reached maximum attempt to wait for DC Name info")
        time.sleep(1)

def update_registry_data():
    utils = getUtil()
    utils.set_vmafd_lwreg("LogFile", "/var/log/vmware/vmafdd/vmafdd.log")
    utils.set_vmafd_lwreg("Krb5Conf", "/var/lib/vmware/vmafdd_config/krb5/krb5.lotus.conf")
    utils.set_vmafd_lwreg("KeytabPath", "/var/lib/vmware/vmafdd_config/krb5/krb5.keytab")
    utils.upgrade_vmafd_reg_tree()

@extend(Hook.Discovery)
def discover(ctx):
    '''DiscoveryResult discover(PatchContext ctx) throw UserError'''

    result = vcsa_utils.getComponentDiscoveryResult('vmafdd')

    etc_machine_ssl_crt = "/etc/vmware/vmware-vmafd/machine-ssl.crt"
    etc_machine_ssl_key = "/etc/vmware/vmware-vmafd/machine-ssl.key"

    src_machine_ssl_crt = etc_machine_ssl_crt
    src_machine_ssl_key = etc_machine_ssl_key

    var_machine_ssl_crt = "/var/lib/vmware/vmafdd_data/machine-ssl.crt"
    var_machine_ssl_key = "/var/lib/vmware/vmafdd_data/machine-ssl.key"

    if os.path.islink(etc_machine_ssl_crt):
        src_machine_ssl_crt = var_machine_ssl_crt

    if os.path.islink(etc_machine_ssl_key):
        src_machine_ssl_key = var_machine_ssl_key

    symlink_krb5_lotus = "/etc/krb5.lotus.conf"
    symlink_krb5_keytab = "/usr/lib/vmware-vmafd/share/config/krb5.keytab"

    src_krb5_lotus = symlink_krb5_lotus
    src_krb5_keytab = symlink_krb5_keytab

    var_krb5_lotus = "/var/lib/vmware/vmafdd_config/krb5/krb5.lotus.conf"
    var_krb5_keytab = "/var/lib/vmware/vmafdd_config/krb5/krb5.keytab"

    if os.path.islink(symlink_krb5_lotus):
        src_krb5_lotus = var_krb5_lotus

    if os.path.islink(symlink_krb5_keytab):
        src_krb5_keytab = var_krb5_keytab

    #backup certool.cfg
    symlink_certool = "/usr/lib/vmware-vmca/share/config/certool.cfg"
    src_certool = symlink_certool
    var_certool = "/var/lib/vmware/vmca_config/certool.cfg"
    if os.path.islink(symlink_certool):
        src_certool = var_certool

    # Add vmafd available in VCHA and missing from backup manifest
    result.replicationConfig = {
        src_machine_ssl_crt: var_machine_ssl_crt,
        src_machine_ssl_key: var_machine_ssl_key,
        src_krb5_lotus: var_krb5_lotus,
        src_krb5_keytab: var_krb5_keytab,
        src_certool: var_certool,
        "/storage/db/vmware-vmafd/vecs": "/storage/db/vmware-vmafd/vecs",
        "/storage/db/vmware-vmafd/vmevent": "/storage/db/vmware-vmafd/vmevent"
    }

    return result

@extend(Hook.Requirements)
def collectRequirements(ctx):
    '''RequirementsResult collectRequirements(PatchContext ctx)'''
    mismatches = []
    requirements = Requirements()
    patchInfo = PatchInfo()

    utils = getUtil()

    try:
        if utils.is_mgmt_node():
            passwordQuestion = Question(userDataId="vmdir.password",
                                         text=_(_T("vmdir.password.text",
                                                   "Single Sign-On administrator password")),
                                         description=_(_T("vmdir.password.desc",
                                                         "For the first instance of the identity domain, this is the password given to the Administrator account.  Otherwise, this is the password of the Administrator account of the replication partner.")),
                                         kind=Question.PASSWORD_KIND)

            requirements = Requirements(requiredDiskSpace={'/storage/core' :
                                                           getUtil().getRequiredDiskSpace()},
                                    questions=[passwordQuestion],
                                    rebootRequired=False)
        else:
            requirements = Requirements(requiredDiskSpace={'/storage/core' :
                                                           getUtil().getRequiredDiskSpace()},
                                    questions=[],
                                    rebootRequired=False)
    except Exception: #pylint: disable=W0703
        log.exception('Cannot connect to VMware Authentication Framework Daemon')
        mismatches.append(Mismatch(text=_(_T('vmafd.error.text',
                                             'Unable to connect to VMware Authentication Framework Daemon')),
                                   description=_(_T('vmafd.error.description',
                                                    'VMware Afd Service is down')),
                                   resolution=_(_T('vmafd.error.resolution',
                                                   'Please check that VMware Afd Service is running')),
                                   severity=Mismatch.ERROR))

    return RequirementsResult(requirements, patchInfo, mismatches)

@extend(Hook.Validation)
def validate(ctx):
    '''ValidationResult validate(PatchContext ctx)'''
    mismatches = []
    utils = getUtil()

    if utils.is_mgmt_node() and not utils.validatePassword(ctx.userData['vmdir.password']):
        mismatches.append(
                Mismatch(text=_(_T("vmdir.password.error.text",
                                   "Single Sign-On administrator password is missing or incorrect")),
                         description=_(_T("vmdir.password.error.description",
                                          "Single Sign-On administrator login is required for VMware Directory Service patching.")),
                         resolution=_(_T("vmdir.password.error.resolution",
                                         "Please provide valid Single Sign-On password.")),
                         severity=Mismatch.ERROR,
                         relatedUserDataId='vmdir.password'))

    return ValidationResult(mismatches)

def makeLdif(domainName, hostName):
    f = open('/tmp/tmp.ldif', 'w')
    dn = getDN('dn: cn=Administrators,cn=Builtin', domainName)
    dn = dn + '\n'
    f.write(dn)
    f.write('changetype: modify\n')
    f.write('add: member\n')
    attr = 'member: cn=' + hostName
    attr = attr + ',ou=Computers'
    attr = getDN(attr, domainName)
    attr = attr + '\n'
    f.write(attr)
    f.close()

def getDN(partialDN, domainName):
    dn = partialDN
    for d in domainName:
        dn = dn + ",dc="
        dn = dn + d
    return dn

def _addThisMachineToLDAP(dcName, password):
    hostName = ''
    utils = getUtil()
    domainDN = utils.getDomainName().split(".")

    pnid_cmd = ['/usr/lib/vmware-vmafd/bin/vmafd-cli', 'get-pnid', '--server-name', 'localhost']
    (stdout, stderr, rc) = os_utils.executeCommand(pnid_cmd)
    if rc == 0:
        hostName = stdout.rstrip()
        log.info("Hostname is %s" % hostName)
    else:
        log.error(stderr)
        raise Exception("Failed to add machine to LDAP: "
                        "Couldn't run command get-pnid")
    makeLdif(domainDN, hostName)

    user = getDN('cn=Administrator,cn=Users', domainDN)
    adminUPN = utils.getAdminUPN()

    modifyCmd = [
            'ldapmodify',
            '-h',
            dcName,
            '-p',
            '389',
            '-U',
            adminUPN,
            '-Y',
            'SRP',
            '-y',
            '/dev/fd/0',
            '-f',
            '/tmp/tmp.ldif']

    (rc, stdout, stderr) = utils.run_command(modifyCmd, stdin=password)
    if rc == 20: # Entry already exists
        log.info('This node already exists in Administrator group at DC %s!', dcName)
    elif rc != 0:
        log.error(stderr)
        raise Exception("Failed to add machine to LDAP: "
                        "Couldn't run command ldapmodify")
    else:
        log.info("Successfully added machine to Administrator group in DC %s!", dcName)

    try:
        os.remove('/tmp/tmp.ldif')
    except OSError as e:
        log.error("Failed to delete temp ldif file due to %s" % e)

@extend(Hook.Prepatch)
def prePatch(ctx): #pylint: disable=W0613
    '''void prePatch(PatchContext sharedCtx) throw UserUpgradeError'''

def _doIncrementalPatching(ctx):
    utils = getUtil()
    currVersion = getCurrentVersion()
    log.info('Vmafd current version: %s', currVersion)
    patchCurrentVersion = currVersion is None

    for ver, modulePath in patches:
        compResult = utils.versionCompare(currVersion, ver)
        patchCurrentVersion |= compResult < 0

        if patchCurrentVersion:
            log.info('Applying patch %s on version %s' % (ver, modulePath))
            mod = __import__(modulePath)
            mod.doPatching(ctx)
            log.info('Patch %s applied' % (modulePath))

    # Add this management node to Administrator group
    # This doesn't add Gateway node to Administrator group
    if utils.is_mgmt_node():
        password = ctx.userData['vmdir.password']
        dcName = utils.getDCName()
        dcNameEx = utils.getDCNameEx()
        _addThisMachineToLDAP(dcName, password)
        # Need to add machine to both DC if they are pointing to different machines
        if dcName != dcNameEx:
            _addThisMachineToLDAP(dcNameEx, password)
    else:
        log.info("No need to add machine to Admin group on non-management node!")

    utils.setSourceVersion(getLatestVersion())

@extend(Hook.SwitchOver)
def switchover(ctx):
    '''void switchover(Switchover ctx) throw UserError'''

    '''This function only invokes the patch logic when it is NDU.'''

    if not getTargetFSS(NDU_LIMITED_DOWNTIME_FSS):
        return

    utils = getUtil()
    is_lw_started = utils.start_lwsmd()

    update_registry_data()

    if is_lw_started:
        utils.stop_lwsmd()

    vmdnsRootKey = '[HKEY_THIS_MACHINE\\Services\\vmdns]'
    utils = getUtil()

    progressReporter = getProgressReporter()

    progressReporter.updateProgress(0, _(_T('vmafd.patch.begin',
                                            'Start VMware Authentication Framework Daemon patching')))

    # Start lwsmd and lwreg services via service control
    utils.start_lwsmd_by_service_control()

    utils.update_path_ownership()

    if utils.is_infra_or_embedded_node():
        utils.deleteRegTree(vmdnsRootKey)

    setLDU()

    # Avoid repointing during patch process
    # It may be already repointed during stopping all service after pre-patch
    # This will just ensure that who ever calls get-dc-name-ex doesn't get different
    # response during patch process
    utils.update_repoint_config(VMAFD_DISABLE_REPOINT_REG_VALUE)

    # start vmafd service
    utils.start_service()

    if utils.is_mgmt_node() or utils.is_gateway():
        wait_for_dc_reg_val_update()

    _doIncrementalPatching(ctx)

    utils.restart_svc('lwsmd')

    progressReporter.updateProgress(100, _(_T('vmafd.patch.complete',
                                              'Patching VMware Authentication Framework Daemon completed')))

@extend(Hook.Patch)
def patch(ctx):
    '''void patch(PatchContext sharedCtx) throw UserUpgradeError'''

    if not vcsa_utils.isDisruptiveUpgrade(ctx) and getTargetFSS(NDU_LIMITED_DOWNTIME_FSS):
        return

    vmdnsRootKey = '[HKEY_THIS_MACHINE\\Services\\vmdns]'
    utils = getUtil()

    progressReporter = getProgressReporter()

    progressReporter.updateProgress(0, _(_T('vmafd.patch.begin',
                                            'Start VMware Authentication Framework Daemon patching')))

    LwIsStarted = utils.start_lwsmd()

    update_registry_data()

    if utils.is_infra_or_embedded_node():
        utils.deleteRegTree(vmdnsRootKey)

    setLDU()

    # Avoid repointing during patch process
    # It may be already repointed during stopping all service after pre-patch
    # This will just ensure that who ever calls get-dc-name-ex doesn't get different
    # response during patch process
    utils.update_repoint_config(VMAFD_DISABLE_REPOINT_REG_VALUE)

    utils.update_path_ownership()

    # start vmafd service
    utils.start_service()

    if utils.is_mgmt_node() or utils.is_gateway():
        wait_for_dc_reg_val_update()

    _doIncrementalPatching(ctx)

    if LwIsStarted:
        utils.stop_lwsmd()
    elif utils.is_service_running('lwsmd'):
        utils.restart_svc('lwsmd')

    # cleanup vmdns RPMs
    utils.cleanup_vmdns_rpms()

    progressReporter.updateProgress(100, _(_T('vmafd.patch.complete',
                                              'Patching VMware Authentication Framework Daemon completed')))

@extend(Hook.OnSuccess)
def OnSuccess(ctx):
    utils = getUtil()

    #wait till vmafdd is in running state
    utils.wait_for_vmafdd_running_state()

    # Enable AutoRepoint only on Gateway Node
    if utils.is_gateway():
        utils.update_repoint_config(VMAFD_ENABLE_REPOINT_REG_VALUE)
