# Copyright 2019-2021 VMware, Inc.
# All rights reserved. -- VMware Confidential
"""
VMDIR patching module following the B2B patching principles
"""
import sys
import os
import platform
import logging
import socket
import vcsa_utils
import time
import ldap3 as ldap

from . import utils
from . import ldap_wrapper
from patch_errors import InternalError
from patch_specs import DiscoveryResult, ValidationResult, Question, Mismatch, \
    Requirements, PatchInfo, RequirementsResult
from extensions import extend, Hook
from l10n import msgMetadata as _T, localizedString as _
from reporting import getProgressReporter
from typing import List
from vcsa_utils import isDisruptiveUpgrade
from fss_utils import getTargetFSS

log = logging.getLogger(__name__)
VMDIR_PAYLOAD_DIR = os.path.dirname(__file__)

# Import patches python directory where all individual python module reside
sys.path.append(os.path.join(VMDIR_PAYLOAD_DIR, "patches"))
patches = [
]

LEGACY_SCHEMA_DN = 'cn=aggregate,cn=schemacontext'
NDU_LIMITED_DOWNTIME = "NDU_Limited_Downtime"
SCHEMA_DELETE_TXT = "schemaDel.txt"
LDAP_NO_SUCH_OBJECT = 32

def getUtil():
    if os.name == "nt":
        return utils.PatchUtilsWin()
    else:
        return utils.PatchUtilsLin()

def getCurrentVersion():
    return getUtil().getSourceVersion()

def getLatestVersion():
    if len(patches) > 0:
        return patches[-1][0]
    return "6.5.0.0"

def getChangesSummary():
    latestPatchScript = patches[-1][1]
    mod = __import__(latestPatchScript)
    changesSummary = None
    if hasattr(mod, 'getChanges'):
        changesSummary = mod.getChanges()

    return changesSummary

@extend(Hook.Discovery)
def discover(ctx):
    '''DiscoveryResult discover(PatchContext sharedCtx) throw UserUpgradeError'''

    result = None
    try:
        result = vcsa_utils.getComponentDiscoveryResult("vmdird")
        if getTargetFSS(NDU_LIMITED_DOWNTIME) and (not isDisruptiveUpgrade(ctx)):
            result.replicationConfig = {"/storage/db/vmware-vmdir/data.mdb": "/storage/db/vmware-vmdir/data.mdb"}
    except ValueError:
        log.exception("Exception ValueError while doing discovery for vmdird. Will assume vmdird is not present.")
    return result


def check_psc_version(password: str) -> List[str]:
    host = 'localhost'
    obj = ldap_wrapper.LdapConnection(host, password, use_machine_account=True)
    conn = obj.open_ldap_connection(host)
    obj.ldap_search(conn, obj.domain_controllers_dn, "(objectClass=computer)", ldap.SUBTREE,
                    ["cn", "vmwPlatformServicesControllerVersion"])
    old_dc_list = []
    nodes = conn.response
    for node in nodes:
        if obj.get_attribute(node, "vmwPlatformServicesControllerVersion").startswith(("6.5")):
            old_dc = obj.get_attribute(node, "cn")
            log.info("Found DC %s with vsphere version less than 6.7", old_dc)
            old_dc_list.append(old_dc)
    conn.unbind()
    return old_dc_list

def is_legacy_schema_exists(password):
    obj = ldap_wrapper.LdapConnection('localhost', password, use_machine_account=True)
    conn = obj.open_ldap_connection('localhost')
    if obj.ldap_search(conn, "cn=aggregate,cn=schemacontext", "(objectClass=*)", ldap.SUBTREE, ldap.ALL_ATTRIBUTES):
        log.info("Legacy schema exists")
        return True
    log.info("Legacy schema does not exist")
    return False

@extend(Hook.Validation)
def validate(ctx): #pylint: disable=W0613
    '''ValidationResult validate(PatchContext sharedCtx)'''
    mismatches = []
    log.info('Vmdir validation Hook Called ')

    util = getUtil()
    dcAccount, passwordval = util.getMachineCredentials()
    log.info('Checking for VMDir legacy schema and mixed mode vCenter')
    old_dc_list = check_psc_version(passwordval)

    if getTargetFSS(NDU_LIMITED_DOWNTIME) and (not isDisruptiveUpgrade(ctx)):
    # Determine if legacy schema deletion is required for NDU/RDU and store the information in a file to use later.
        try:
            if is_legacy_schema_exists(passwordval):
                filePath = os.path.join(ctx.stageDirectory, SCHEMA_DELETE_TXT)
                if old_dc_list:
                    log.info("Not deleting legacy schema as 65 DC's {} are present".format(' ,'.join(old_dc_list)))
                    util.setSchemaDeleteVal(filePath, False)
                else:
                    util.setSchemaDeleteVal(filePath, True)
        except Exception:
            log.warning("Failed to check for schema deletion.")

    if old_dc_list and get_domain_functional_level(dcAccount, passwordval) == '4' and is_legacy_schema_exists(passwordval):
        old_dcs = ', '.join(old_dc_list)
        mismatches.append(
            Mismatch(text=_(_T("vmdir.mixedmode.error.text",
                               "VMDir Domain Functional Level 4 detected on vcenter server {}. This can only happen if the administrator"
                               " has manually changed the replication related scheme in vmdir. The manual change of VMDir schema"
                               " configuration is not supported and can lead to vCenter Server becoming inaccessible.".format(dcAccount))),
                     description=_(_T("vmdir.mixedmode.error.description",
                                      "Updating a vCenter server in Enhanced Linked mode with vCenter server"
                                      " of different version is not supported.")),
                     resolution=_(_T("vmdir.mixedmode.error.resolution",
                                     "Upgrade vCenter server(s): {} to the 6.7 version or higher. This is required due to VMDir"
                                     " schema mismatch between 6.5 and higher versions. If the problem persists please contact"
                                     " VMware support.".format(old_dcs))),
                     severity=Mismatch.ERROR,
                     relatedUserDataId='vmdir.mixedmode'))

    return ValidationResult(mismatches)

@extend(Hook.Prepatch)
def prePatch(ctx): #pylint: disable=W0613
    '''void prePatch(PatchContext sharedCtx) throw UserUpgradeError'''

def _postInstallScripts(ctx): #pylint: disable=W0613
    update_schema()
    update_vmdir()

def update_vmdir():
    # Windows not supported
    if platform.system().lower() == "windows":
        return

    util = getUtil()

    # Start service before update
    util.start_service()

    afdParamKeyPath = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters]'
    dirKeyPath = '[HKEY_THIS_MACHINE\\Services\\vmdir]'

    domainName = util.getRegValue(afdParamKeyPath, 'DomainName')
    dcAccountDN = util.getRegValue(dirKeyPath, 'dcAccountDN')
    samAccount = socket.getfqdn().rstrip('\n')

    dcAccount, password = util.getMachineCredentials()
    dcAccount_upn = '%s@%s' % (dcAccount, domainName)
    path = get_vdc_upgrade_path()
    args = [path, '-H', 'localhost', '-D', dcAccount_upn, '-d', dcAccountDN, '-s', samAccount]

    log.info("Performing vmdir update")

    (rc, stdout, stderr) = util.run_command_with_display(args, args, stdin=password)
    if rc != 0:
        log.error('Failed to run command : %s', args)
        log.error(stderr)
        log.error(stdout)
        raise InternalError("Vmware Directory Service patch failed" )
    log.info(stdout)

def get_vdc_upgrade_path():
    # Windows not supported
    if platform.system().lower() == "windows":
        return

    path = os.path.join(getUtil().getInstallDir(), 'bin', 'vdcupgrade')
    return path

def get_vmdird_path():
    # Windows not supported
    if platform.system().lower() == "windows":
        return

    path = os.path.join(getUtil().getInstallDir(), 'sbin', 'vmdird')
    return path

def update_schema():
    log.info("Performing schema upgrade")

    utils = getUtil()

    path = get_vmdird_path()

    datafile = os.path.join(utils.getConfigDir(), 'vmdirschema.ldif')

    args = [path, '-u', '-c', '-s', '-f', datafile]
    argsdisplay = args

    (rc, stdout, stderr) = getUtil().run_command_with_display(args, argsdisplay)
    if rc != 0:
        log.error('Failed to run command : %s', argsdisplay)
        log.error(stderr)
        log.error(stdout)
        raise InternalError("Vmware Directory Service schema update failed" )

    log.info(stdout)
    log.info('Schema update success')

def _doIncrementalPatching(ctx):
    utils = getUtil()
    currVersion = getCurrentVersion()
    log.info('Vmdir current version: %s', currVersion)
    patchCurrentVersion = currVersion is None

    for ver, modulePath in patches:
        compResult = utils.versionCompare(currVersion, ver)
        patchCurrentVersion |= compResult < 0

        if patchCurrentVersion:
            log.info('Applying patch %s on version %s' % (ver, modulePath))
            mod = __import__(modulePath)
            mod.doPatching(ctx)
            log.info('Patch %s applied' % (modulePath))

    utils.setSourceVersion(getLatestVersion())

def set_maintainance_mode():
    log.info('Set maintenance mode')

    vmdirKey='[HKEY_THIS_MACHINE\\Services\\vmdir]'
    valName='Arguments'

    utils = getUtil()

    if utils.is_service_running(utils.getVmdirComponentName()):
        log.error('service vmdir is running while patching')
        raise InternalError("VMware Directory Service patch failed: Service is running while patching")

    # Get current arguments as a list
    currentArgs = utils.getRegValue(vmdirKey, valName)
    argList = currentArgs.split()

    # add standalone args
    args = ['-m', 'standalone']

    for arg in args:
        argList.append(arg)

    # write back args
    newArgs = ' '.join(argList)

    utils.setRegValue(vmdirKey, valName, newArgs)

    log.info('Maintenance mode success')

'''
Wait for DomainID / DC value before patch finishes
These values are expected to be updated by CacheDCThread of VMAFD
VMAFD needs VMDIR running to retrieve these on Infra/Embedded node
This wait is needed as VMDIR can't be started before patching
because of schema change from 6.5)
'''
def wait_for_dc_reg_val_update():
    retryCount = 0
    domainGUID = ''
    affinitizedDCValue = ''
    maxRetry = 60       # 5 Minutes (5 seconds per try)

    utils = getUtil()
    vmafdParamsRegKey = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters]'
    vmafdDomainsRegKey = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters\\Domains]'

    while retryCount < maxRetry:
        retryCount = retryCount + 1
        try:
            domainGUID = utils.getRegValue(vmafdParamsRegKey, 'DomainGUID')
            if domainGUID is not None:
                vmafdDomainGUIDRegKey = vmafdDomainsRegKey.strip("]") + "\\" + domainGUID + "]"
                affinitizedDCValue = utils.getRegValue(vmafdDomainGUIDRegKey, 'AffinitizedDC')
                if affinitizedDCValue is not None:
                    log.info("Affinitized DC value set to %s" % affinitizedDCValue)
                    return
                else:
                    log.info("Affinitized DC value is not set yet. Retrying...")
            else:
                log.info("Domain ID value is not set yet. Retrying...")

        except Exception as e:
            log.info("Failed to read DomainID/Affinitized DC reg value. Retrying...")

        if retryCount >= maxRetry:
            log.exception("Reached maximum attempt to read Domain ID/Affinitized DC value from registry. Giving up...")
            raise Exception("Reached maximum attempt to wait for DC Name info")
        time.sleep(5)


def get_domain_functional_level(account: str, password: str) -> str:
    util = getUtil()
    args = ['/usr/lib/vmware-vmafd/bin/dir-cli', 'domain-functional-level', 'get', '--login', account]
    (rc, stdout, stderr) = util.run_command_with_display(args, args, stdin=password)
    if rc != 0:
        log.error("Error fetching Domain Functional level")
        log.error(stderr)
        log.error(stdout)
        raise InternalError("Vmware Directory Service failed to get domain functional level")
    dfl = (stdout.split('\n')[1]).split(' ')[3]
    log.info("Domain Functional Level is %s" % dfl)
    return dfl


def ldap_delete_tombstone(dc: str, tomstone_dn: str, admin_dn: str, password: str) -> None:
    util = getUtil()
    args = ['/opt/likewise/bin/ldapdelete', '-h', dc, '-x', '-D', admin_dn, '-w', password, tomstone_dn]
    args_display = ['/opt/likewise/bin/ldapdelete', '-h', dc, '-x', '-D', admin_dn, '-w', '****', tomstone_dn]
    (rc, stdout, stderr) = util.run_command_with_display(args, args_display)
    if rc != 0 and rc != LDAP_NO_SUCH_OBJECT:
        log.error("Failed to delete the legacy schema tombstone entry on DC %s", dc)
        log.error(stderr)
        log.error(stdout)
        raise InternalError("Vmware Directory Service failed to delete tombstone entry")


def delete_legacy_schema() -> None:
    """
    Deletes legacy schema entry and it's tombstone entry on the DC
    Before deleting the legacy schema make sure we don't have any 6.5 DC
    """
    util = getUtil()
    _, password = util.getMachineCredentials()
    old_dc_list = check_psc_version(password)
    if old_dc_list:
        log.info("Not deleting legacy schema as 65 DC's {} are present".format(' ,'.join(old_dc_list)))
        return

    host = 'localhost'
    attr_object_guid = 'objectGUID'
    obj = ldap_wrapper.LdapConnection(host, password, use_machine_account=True)
    conn = obj.open_ldap_connection(host)
    # delete legacy schema on localhost
    if obj.ldap_search(conn, LEGACY_SCHEMA_DN, "(objectClass=*)", ldap.SUBTREE, attr_object_guid):
        guid = obj.get_attribute(conn.response[0], attr_object_guid)
        tombstone_dn = "cn=aggregate#{}:{},cn=Deleted Objects,{}".format(attr_object_guid, guid, obj.ldap_domain_dn)
        log.info("Deleting the legacy schema from {}".format(host))
        if obj.ldap_delete(conn, LEGACY_SCHEMA_DN):
            ldap_delete_tombstone(host, tombstone_dn, obj.ldap_machine_account, password)
        else:
            conn.unbind()
            raise InternalError("Vmware Directory Service failed to delete legacy schema")
    else:
        log.info("No legacy schema stored for the domain")

    conn.unbind()


@extend(Hook.SwitchOver)
def switchover(ctx):
    '''void switchover(Switchover ctx) throw UserError'''

    if getTargetFSS(NDU_LIMITED_DOWNTIME):
        DoUpgrade(ctx)
    else:
        log.info("FSS is off, all work is done in Patch hook")


@extend(Hook.Patch)
def patch(ctx):
    '''void patch(PatchContext sharedCtx) throw UserUpgradeError'''

    if not getTargetFSS(NDU_LIMITED_DOWNTIME) or isDisruptiveUpgrade(ctx):
        DoUpgrade(ctx)
    else:
        log.info("All work is done in SwitchOver Hook")


def DoUpgrade(ctx):
    utils = getUtil()

    progressReporter = getProgressReporter()

    progressReporter.updateProgress(0, _(_T("vmdir.patch.begin",
                                            'Start VMware Directory Service patching')))
    # Start the lwsmd service not running already
    utils.start_lwsmd_service()

    # set vmdir in maintainance mode before patching
    set_maintainance_mode()

    # to fix B2B failure when certain path of upgrade
    # is followed as explained in PR 2619919
    if getTargetFSS(NDU_LIMITED_DOWNTIME) and (not isDisruptiveUpgrade(ctx)):
        # Determine if legacy schema deletion is required for NDU/RDU
        filePath = os.path.join(ctx.stageDirectory, SCHEMA_DELETE_TXT)
        if utils.getSchemaDeleteVal(filePath):
            utils.start_service()
            log.info("Deleting the legacy schema for RDU Upgrade.")
            delete_legacy_schema()
            utils.stop_service()
    else:
        utils.start_service()
        delete_legacy_schema()
        utils.stop_service()

    _doIncrementalPatching(ctx)
    _postInstallScripts(ctx)

    wait_for_dc_reg_val_update()

    progressReporter.updateProgress(100, _(_T("vmdir.patch.complete",
                                              'Patching VMware Directory Service completed')))

@extend(Hook.OnSuccess)
def OnSuccess(ctx):
    '''Call back for post upgrade cleanup'''
    try:
        vmdirScriptPath = '/usr/lib/vmware-vmdir/scripts/'
        sys.path.append(vmdirScriptPath)
        import vmdirResetMode

        log.info('Performing vmdir commit operation')
        vmdirResetMode.resetMode()
        log.info('vmdir commit complete')
    except:
        log.exception('vmdir commit failed')
        log.error(vmdirResetMode.LIN_FAIL_MESAGE)
