"""
Copyright 2016-2022 VMware, Inc.  All rights reserved. -- VMware Confidential

This file provides the hooks which will be invoked during patching process.
"""
import datetime
import logging
import os
import os_utils
import platform
import sys
import shutil
import subprocess

sys.path.append('/usr/lib/vmware-updatemgr/python/hcl')
sys.path.append('/usr/lib/vmware/site-packages')

from extensions import extend, Hook
from patch_specs import RequirementsResult, Mismatch
from fss_utils import getTargetFSS  # pylint: disable=E0401
from l10n import msgMetadata as _T, localizedString as _
from reporting import getProgressReporter
from hardware_discovery.services.utils.authentication import CertificateStore
from hardware_discovery.services.vc_service import getVCServiceFromCertificate
from pyVmomi import vim
from time import strftime

import vcsa_utils
from fss_utils import getTargetFSS

from . import utils

# Constants to apply.
OPTION_SDDC_DEPLOYED_TYPE = 'config.SDDC.Deployed.Type'
OPTION_SKIP_VLCM_PRECHECK = 'config.SDDC.VCUpgradeVLCMPrecheck.Skip'
BLOCKED_70U3_BUILDS = [18644231, 18825058, 18792741, 18831705, 18934053, 18935636, 18801418, 18733601, 19037457, 19063458]
VERSION_70U2 = "7.0.2"
MIN_70U2_BUILD_BLOCKED = "18426014" # inclusive

# BLOCKED_70U2_BUILDS is kept here for reference only.
# We use MIN_70U2_BUILD_BLOCKED to block ESXi 7.0 U2 hosts.
BLOCKED_70U2_BUILDS = [18426014, 18538813, 18836573, 18426997, 18999557, 19111464]

PSQL = '/opt/vmware/vpostgres/current/bin/psql'

VERSION_TXT = "version.txt"

VCINTEGRITY_NDU_MARKER_FILE = "ndu.cfg"
VCINTEGRITY_NDU_MARKER_FILE_TARGET_PATH = \
                    os.path.join("/usr/lib/vmware-updatemgr/config/",
                    VCINTEGRITY_NDU_MARKER_FILE)

# Map of replication override for NDU based upgrade.
NDU_REPLICATION_CONFIG = {
      "/usr/lib/vmware-updatemgr/config/updatemgr-config.attrs" : None,
      "/usr/lib/vmware-updatemgr/config/updatemgr-config.props" : None,
      "/usr/lib/vmware-updatemgr/bin/extension.xml" : None,
      "/usr/lib/vmware-updatemgr/bin/clients.xml" : None,
      "/usr/lib/vmware-updatemgr/bin/hwcompat.xml" : None,
      "/usr/lib/vmware-updatemgr/bin/vciInstallUtils_config.xml" : None,
      "/usr/lib/vmware-updatemgr/config/disabled_checks.json" : None,
      "/usr/lib/vmware-updatemgr/config/vvs-config.json" : None,
      "/etc/vmware-rhttpproxy/endpoints.conf.d/updatemgr-proxy.conf" : None,
      "/usr/lib/vmware-updatemgr/bin/jetty-vum.xml" : None,
      "/usr/lib/vmware-updatemgr/bin/jetty-vum-ssl.xml" :
      "/usr/lib/vmware-updatemgr/bin/jetty-vum-ssl.xml.ndusave",
      "/usr/lib/vmware-updatemgr/bin/vci-integrity.xml" :
      "/usr/lib/vmware-updatemgr/bin/vci-integrity.xml.ndusave",
      "/storage/updatemgr/patch-store" : "/storage/updatemgr/patch-store",
    }

logger = logging.getLogger(__name__)
MY_PAYLOAD_DIR = os.path.dirname(__file__)

TMP_DIR = "/tmp"
EXT_PY = ".py"
HOST_CHECK_FILE_BASE = "dual_driver_check"
HOST_CHECK_FILE = HOST_CHECK_FILE_BASE + EXT_PY
HOST_CHECK_SRC_PATH =  os.path.join(MY_PAYLOAD_DIR, HOST_CHECK_FILE)
HOST_CHECK_DST_PATH = os.path.join(TMP_DIR, HOST_CHECK_FILE)

EXT_TXT = ".txt"
EXT_LOG = ".log"
HOST_CHECK_LOG_DIR = "/var/log/vmware/applmgmt/"
FAULT_HOSTS_FILE_BASE = "dual_driver_check_faulty_hosts"
HOSTS_FOR_CHECK_FILE_BASE = os.path.join(TMP_DIR, "vlcm_hosts_for_check")

@extend(Hook.Discovery)
def discover(ctx):
    """
    Determine whether the component should be patched.

    :param ctx: Context given by the patch framework

    :return: 'None' if no patching should be done, otherwise component discovery
             result for VUM
    """
    discoveryResult = None
    if platform.system().lower() == "windows":
        return discoveryResult

    curVersion = utils.GetConfigVersion(utils.CONFIG_VERSION_FILE)
    if vcsa_utils.isDisruptiveUpgrade(ctx) or not \
        getTargetFSS("NDU_Limited_Downtime"):
        # need to save current version.txt file because it is
        # already replaced at Patch stage
        filePath = os.path.join(ctx.stageDirectory, VERSION_TXT)
        utils.SetConfigVersion(filePath, curVersion)
        discoveryResult = vcsa_utils.getComponentDiscoveryResult("updatemgr")

    if getTargetFSS("NDU_Limited_Downtime") and not \
        vcsa_utils.isDisruptiveUpgrade(ctx):
        discoveryResult = doNduDiscovery(ctx, curVersion)

    return discoveryResult


# Modify the patching hook to run only in case of non-disruptive upgrade.
@extend(Hook.Patch)
def patch(ctx):
    """
    Perform incremental patching for B2B upgrade framework.
    """
    progressReporter = getProgressReporter()
    progressReporter.updateProgress(0, _(_T("updatemgr.patch.begin",
                                            "Start VMware Update Manager"
                                            " patching")))
    if vcsa_utils.isDisruptiveUpgrade(ctx) or not \
        getTargetFSS("NDU_Limited_Downtime"):
            before = datetime.datetime.now()
            runB2BPatching(ctx, progressReporter)
            delta = datetime.datetime.now() - before
            logger.info("Time for B2B patch: %s" % (delta))
            logger.info("VMware Update Manager is successfully patched")

    progressReporter.updateProgress(100, _(_T("updatemgr.patch.complete",
                                              "VMware Update Manager patching"
                                              " completed")))

def _getVcServiceAndCertStore():
    vcService = None
    certStore = None
    try:
        certStore = CertificateStore()
        vcService = getVCServiceFromCertificate(certStore)
    except Exception as e:
        vcService = None
        logger.error("Unable to get VC service or cert store. Error: %s",
                     str(e))
    return vcService, certStore


def _getConfig(configKey):
    configValue = None
    if vcService:
        try:
            configValue = vcService.si.content.setting.QueryOptions(configKey)[0].value
        except vim.fault.InvalidName:
            logger.info("vCenter option: %s does NOT exist.", configKey)
        except Exception as e:
            logger.error("Unable to retrieve vCenter option: %s. Error: %s",
                         configKey, str(e))
        if configValue:
            logger.info("vCenter option: %s, value: %s", configKey, configValue)
    return configValue


def _setConfig(configKey, configValue):
    if vcService:
        option = [vim.Option.OptionValue(key=configKey, value=configValue)]
        try:
            vcService.si.content.setting.UpdateOptions(changedValue=option)
            logger.info("Set vCenter option: %s to %s", configKey, configValue)
        except Exception as e:
            logger.error("Unable to set vCenter option: %s to %s. Error: %s",
                         configKey, configValue, str(e))


def _strToBool(s):
    return s in {'true', 'True'}

def _isBoolean(s):
    return s in {'true', 'True', 'false', 'False'}

def _generatePostgresMistmatch(mismatches, err):
    logger.error("Failed to query hosts for precheck using psql. Error: %s",
                 err)
    mismatches.append(Mismatch(
        text=_(_T("com.vmware.vcIntegrity.db.error.text",
                  "Failed to query hosts in inventory")),
        description=_(_T("com.vmware.vcIntegrity.db.error.description",
                         "The underlying error is %s"), err),
        resolution=_(_T("com.vmware.vcIntegrity.db.error.resolution",
                        "Check the vPostgres logs for more information")),
        severity=Mismatch.ERROR))

def _executePsqlQuery(query):
    cmd = [PSQL, "-d", "VCDB", "-U", "postgres", "-q", "-t", "-c", query]
    logger.info("Executing command: %s", ' '.join(cmd))
    return os_utils.executeCommand(cmd)

def _checkAndReport70u3Hosts(mismatches):
   '''Return the names of 7.0 U3 or U3a hosts found'''
   u3_host_names = []
   query = ("SELECT vpx_entity.name FROM vpx_entity" +
            " JOIN vpx_host ON vpx_entity.id=vpx_host.id" +
            " WHERE vpx_host.product_build IN (" +
            ", ".join(["'{}'".format(e) for e in BLOCKED_70U3_BUILDS]) + ")")
   out, err, rc = _executePsqlQuery(query)
   if rc != 0:
      _generatePostgresMistmatch(mismatches, err)
      return u3_host_names

   res = out.strip()
   if res:
      u3_host_names = res.split()
      count = len(u3_host_names)
      logger.error("Found %d ESXi 7.0 U3 or U3a hosts in the vCenter "
                   "inventory. Hosts: %s", count, ", ".join(u3_host_names))
   return u3_host_names

def _checkAndReport70u2Hosts(mismatches):
   '''Return the names of 7.0 U2c or U2d hosts inside vLCM clusters found'''
   vlcm_u2_host_names = []

   # Find all the 7u2c and u2d hosts
   query = ("SELECT ID FROM vpx_host WHERE product_build >= '" +
            MIN_70U2_BUILD_BLOCKED + "' AND product_version = '" +
            VERSION_70U2 + "'")

   out, err, rc = _executePsqlQuery(query)
   if rc != 0:
      _generatePostgresMistmatch(mismatches, err)
      return vlcm_u2_host_names

   res = out.strip()
   if not res:
      # No u2 hosts found
      return vlcm_u2_host_names

   u2_hosts = res.split()
   count = len(u2_hosts)
   logger.info("Found %d ESXi 7.0 U2c or U2d hosts. Host Ids: %s", count,
               ", ".join(u2_hosts))

   # Find clusters from the hosts found
   query = ("SELECT vpx_compute_resource.id FROM vpx_compute_resource" +
            " JOIN vpx_entity ON vpx_compute_resource.id = vpx_entity.parent_id" +
            " WHERE vpx_entity.id IN" +
            " (" + ", ".join(["'{}'".format(e) for e in u2_hosts]) + ")" +
            " AND vpx_compute_resource.lifecycle_managed=1")

   out, err, rc = _executePsqlQuery(query)
   if rc != 0:
      _generatePostgresMistmatch(mismatches, err)
      return vlcm_u2_host_names

   res = out.strip()
   if not res:
      # hosts not part of any vLCM enabled clusters
      return vlcm_u2_host_names

   u2_clusters = set(res.split())
   count = len(u2_clusters)
   logger.info("Found %d vLCM clusters containing ESXi 7.0 U2c or U2d "
               "hosts. Cluster Ids: %s", count, ", ".join(u2_clusters))

   # Found 70u2 hosts inside vLCM cluster, Get host names.
   query = ("SELECT vpx_entity.name FROM vpx_compute_resource" +
            " JOIN vpx_entity ON vpx_compute_resource.id = vpx_entity.parent_id" +
            " WHERE vpx_entity.id IN" +
            " (" + ", ".join(["'{}'".format(e) for e in u2_hosts]) + ")" +
            " AND vpx_compute_resource.lifecycle_managed=1")
   out, err, rc = _executePsqlQuery(query)
   if rc != 0:
      _generatePostgresMistmatch(mismatches, err)
      return vlcm_u2_host_names

   res = out.strip()
   vlcm_u2_host_names = res.split()
   count = len(vlcm_u2_host_names)
   logger.error("Found %d ESXi 7.0 U2c or U2d hosts inside vLCM clusters. "
                "Hosts: %s", count, ", ".join(vlcm_u2_host_names))
   return vlcm_u2_host_names

def _startHostCheck(mismatches, timestamp, host_names):
    '''Execute dual driver check on hosts in the background'''

    # Write host names into a file and then pass it to dual_driver_check.py
    host_file = HOSTS_FOR_CHECK_FILE_BASE + '_' + timestamp + EXT_TXT

    try:
        with open(host_file, "w") as fp:
            # Append "\n" to every host name when writing it into the file
            fp.writelines(["{}\n".format(h) for h in host_names])
        logger.info("Saved host names into %s", host_file)

        # Copy dual_driver_check.py to the destination folder
        shutil.copy(HOST_CHECK_SRC_PATH, HOST_CHECK_DST_PATH)
        logger.info("Copied %s to %s", HOST_CHECK_SRC_PATH, HOST_CHECK_DST_PATH)
    except Exception as e:
        logger.error("Failed to access some file. Error: %s", str(e))
        mismatches.append(Mismatch(
            text=_(_T("com.vmware.vcIntegrity.file.error.text",
                      "Failed to access some file")),
            description=_(
                _T("com.vmware.vcIntegrity.file.error.description",
                   "The underlying error: %s"), str(e)),
            resolution=_(
                _T("com.vmware.vcIntegrity.file.error.resolution",
                   "Resolve the underlying error: %s and retry"), str(e)),
            severity=Mismatch.ERROR))
        return

    # Prepare a command to execute dual_driver_check.py
    cmd = ["/usr/bin/python", HOST_CHECK_DST_PATH,
           "--inB2B",
           "-t", timestamp,
           "-f", host_file]

    # Run dual_driver_check.py from the destination folder in the background
    try:
        proc = subprocess.Popen(cmd)
        logger.info("Executing background command: %s. pid: %s",
                    ' '.join(cmd), proc.pid)
    except Exception as e:
        logger.error("Failed to run %s. Error: %s", HOST_CHECK_DST_PATH, str(e))
        mismatches.append(Mismatch(
            text=_(_T("com.vmware.vcIntegrity.b2b.validation.error.text",
                      "Failed to start fine-grained host check")),
            description=_(
                _T("com.vmware.vcIntegrity.b2b.validation.error.description",
                   "The underlying error: %s"), str(e)),
            resolution=_(
                _T("com.vmware.vcIntegrity.b2b.validation.error.resolution",
                   "Resolve the underlying error: %s and retry"), str(e)),
            severity=Mismatch.ERROR))


# Register for the NDU specific hooks:
@extend(Hook.Requirements)
def collectRequirements(ctx):
    '''RequirementsResult collectRequirements(PatchContext ctx)'''
    mismatches = []

    # VMWARE_B2B flag being true means B2B is running in upgrade mode.
    # Since we know VMC runs in upgrade mode, while on-prem runs in
    # patching mode, we can check VMC environment this way.
    if getTargetFSS('VMWARE_B2B'):
        logger.info("VMware Update Manager running in VMC environment. "
                    "Skip host validation.")
        return RequirementsResult(mismatches=mismatches)

    global vcService
    vcService, certStore = _getVcServiceAndCertStore()

    # check if VCF or VCF-VxRail environment is being used
    deploymentType = _getConfig(OPTION_SDDC_DEPLOYED_TYPE)
    if deploymentType is not None:
        if deploymentType.startswith('VCF'):
                logger.info("VMware Update Manager running in VCF or "
                            "VCF-VxRail environment. "
                            "Skip host validation.")
                if certStore:
                    certStore.cleanup()
                return RequirementsResult(mismatches=mismatches)

    # user can skip precheck by setting config.SDDC.VCUpgradeVLCMPrecheck.Skip
    # to True. default is unset, so precheck is enabled by default unless VC is running
    # in VMC or VCF or VCF-VxRail environment as checked above
    isPrecheckSkipped = _getConfig(OPTION_SKIP_VLCM_PRECHECK)
    if isPrecheckSkipped is not None:
        if not _isBoolean(isPrecheckSkipped):
            logger.error("%s must be any one of [True, true, False, false]",
                         OPTION_SKIP_VLCM_PRECHECK)
        else:
            if _strToBool(isPrecheckSkipped):
                logger.info("%s is set True. Skip host validation.",
                            OPTION_SKIP_VLCM_PRECHECK)
                if certStore:
                    certStore.cleanup()
                return RequirementsResult(mismatches=mismatches)

    if certStore:
        certStore.cleanup()
    logger.info("VMware Update Manager started host validation.")

    # Check for u3 hosts
    u3_host_names = _checkAndReport70u3Hosts(mismatches)
    # Check for u2 hosts
    u2_host_names = _checkAndReport70u2Hosts(mismatches)

    count = len(u3_host_names) + len(u2_host_names)
    if count > 0:
        timestamp = strftime("%Y-%m-%d_%H-%M-%S")
        logger.info("Generated the timestamp for host check: %s", timestamp)
        fault_file_path = HOST_CHECK_LOG_DIR + FAULT_HOSTS_FILE_BASE \
                          + '_' + timestamp + EXT_TXT
        log_file_path = HOST_CHECK_LOG_DIR + HOST_CHECK_FILE_BASE \
                        + '_' + timestamp + EXT_LOG

        # _T() accepts at most one format parameter, but we have three, so
        # we have to format the text message for Mismatch by ourselves here.
        errText = ("{0} host(s) were found in the vCenter inventory, that are "
                   "potentially problematic for a vCenter upgrade. Initiating "
                   "a detailed check of the hosts to scan for driver conflicts."
                   " If problematic hosts are flagged by this detailed check, "
                   "they will be listed in {1} file. See {2} for the full scan "
                   "results. ".format(count, fault_file_path, log_file_path))

        mismatches.append(Mismatch(
            text=_(_T("com.vmware.vcIntegrity.b2b.validation.error.text", "%s"),
                      errText),
            description=_(
                _T("com.vmware.vcIntegrity.b2b.validation.error.description",
                   "vCenter upgrade is not allowed if there are any ESXi 7.0 "
                   "U2c or U2d host(s) inside Image managed cluster(s) or any "
                   "ESXi 7.0 U3 or U3a host(s) in the inventory.")),
            resolution=_(
                _T("com.vmware.vcIntegrity.b2b.validation.error.resolution",
                   "If problematic host(s) are flagged in the scan, these "
                   "hosts must be upgraded to ESXi 7.0 U3c or higher version. "
                   "These hosts can be upgraded either with a baseline created "
                   "from an ISO; or using an image based upgrade, if they are "
                   "in Image managed cluster. This host upgrade needs to be "
                   "completed before proceeding with the upgrade of vCenter "
                   "Server. Do not use Rollup based patch baselines. If no "
                   "problematic hosts were detected in the scan, restart "
                   "vCenter Server upgrade and the next pre-check will proceed "
                   "normally. NOTE: Once the scan is clean, it is recommended "
                   "to proceed to upgrade immediately. Prior to upgrade, do "
                   "NOT introduce additional hosts with potential driver "
                   "conflicts. Refer to KB 86447 for details.")),
            severity=Mismatch.ERROR))
        _startHostCheck(mismatches, timestamp, u3_host_names + u2_host_names)

    return RequirementsResult(mismatches=mismatches)

@extend(Hook.Validation)
def validate(ctx):
    '''ValidationResult validate(PatchContext ctx)'''
    pass

@extend(Hook.Expand)
def expand(ctx):
    '''void prepare(PatchContext ctx) throw UserError'''
    logger.info("VMware Update Manager expand done.")

@extend(Hook.Contract)
def contract(ctx):
    '''void patch(PatchContext ctx) throw UserError'''
    logger.info("VMware Update Manager contract hook done.")

@extend(Hook.Revert)
def revert(ctx):
    '''void onSuccess(PatchContext ctx) throw UserError'''
    logger.info("Start VMware Update Manager revert hook.")

# Non disruptive upgrade discovery
def doNduDiscovery(ctx, curVersion):
    '''
    Generate the replication configuration needed for NDU upgrade.
    '''
    # Create a marker file in the B2B stage directory to replcaite
    # to target VC indicating upgrade during service startup.
    markerFile = os.path.join(ctx.stageDirectory,
                              VCINTEGRITY_NDU_MARKER_FILE)

    with open(markerFile, 'w+') as fp:
        import configparser
        config = configparser.ConfigParser()
        config['NDU'] = dict()
        config['NDU']['version'] = str(curVersion)
        config.write(fp)

    # Add NDU upgrade marker file for replication.
    NDU_REPLICATION_CONFIG[markerFile] = \
                        VCINTEGRITY_NDU_MARKER_FILE_TARGET_PATH

    # Add any patch specific replication configuration needed for
    # service upgrade at startup on the target VC.
    _getNduReplicationConfig(curVersion)

    # The replication configuration
    discoveryResult = vcsa_utils.getComponentDiscoveryResult( \
                           "updatemgr", \
                           replicationConfig=NDU_REPLICATION_CONFIG)
    return discoveryResult

def _getNduReplicationConfig(curVersion):
    """
    Update the replication configuration for patches > curVersion.
    This will update the global replication map for discovery.
    """
    applicablePatches=list()
    utils.getApplicablePatches(curVersion, applicablePatches)

    # Import patches python directory where all individual python module reside
    sys.path.append(os.path.join(MY_PAYLOAD_DIR, "patches"))

    if len(applicablePatches) > 0 :
        for ver, modulePath in sorted(applicablePatches,
                                      key=lambda x: utils.Version(x[0])):
            mod = __import__(modulePath)
            if hasattr(mod, 'getReplicationConfig'):
                logger.info("Applying replication config  %s for version %s"
                            % (modulePath, ver))
                NDU_REPLICATION_CONFIG.update(
                            mod.getReplicationConfig(ctx))

# B2B in-place upgrade patching steps.
def runB2BPatching(ctx, progressReporter):
    _doIncrementalPatching(ctx, progressReporter)
    utils._updateIntegrityConfig()
    utils._upgradeDatabaseSchema()
    utils._upgradeSecurityConfig()
    utils._updateLSRegistrations()
    utils._fixupMetadataZips()

def _doIncrementalPatching(ctx, progressReporter):
    '''
    Runs the individual B2B patches.
    '''
    logger.info("Running Incremental Patching ...")
    applicablePatches = list()
    # Import patches python directory where all individual python module reside
    sys.path.append(os.path.join(MY_PAYLOAD_DIR, "patches"))

    # Read current version from saved version.txt
    filePath = os.path.join(ctx.stageDirectory, VERSION_TXT)
    curVersion = utils.GetConfigVersion(filePath)
    logger.info("Current version = %s", str(curVersion))

    # Get a list of applicable patches.
    utils.getApplicablePatches(curVersion, applicablePatches)

    if len(applicablePatches) == 0:
        return

    logger.info("Applying Patching ...")
    appliedPatchCount = 0
    totalPatch = len(applicablePatches)
    updatedVer = None
    for ver, modulePath in sorted(applicablePatches,
                                  key=lambda x: utils.Version(x[0])):
        logger.info("Applying patch %s for version %s" % (modulePath, ver))
        mod = __import__(modulePath)
        mod.doPatching(ctx)
        logger.info("Patch %s applied" % (modulePath))
        appliedPatchCount += 1
        updatedVer = ver
        progressReporter.updateProgress((appliedPatchCount*90)/totalPatch)

@extend(Hook.OnSuccess)
def onSuccess(ctx):
    """
    At OnSuccess phase the patch phase has finished successfully and the
    appliance is running on the new version. This phase is intended to be used
    to do any clean up or post patch configuration changes once all components
    are up-to-date.
    """
    logger.info("Starting Post-Upgrade Processing ...")

    global vcService
    vcService, certStore = _getVcServiceAndCertStore()

    # After a successful VC upgrade, always reset OPTION_SKIP_VLCM_PRECHECK
    _setConfig(OPTION_SKIP_VLCM_PRECHECK, str(False))

    if certStore:
        certStore.cleanup()