# Copyright (c) 2017-2022 VMware, Inc.  All rights reserved.
# -- VMware Confidential

"""
vpxd-svcs component module for B2B patching.
"""
import os
import sys
import shutil
import logging
import importlib
import vcsa_utils
from .constants import (VERSION_FILE_NAME, VERSION, CP_PROP_KEYS,
                        VPXD_SVCS_CONFIG_DIR, VPXD_SVCS_CONFIG_DIR_OLD,
                        DATASERVICES_PROP, DATASERVICES_PROP_OLD,
                        CPSPEC_PROP, CPSPEC_PROP_OLD, CPSPEC_PROP_FILE_NAME)
from extensions import extend, Hook
from l10n import msgMetadata as _T, localizedString as _
from patch_specs import Requirements, PatchInfo, RequirementsResult
from reporting import getProgressReporter
import fss_utils

VPXD_SVCS_PATCHES_DIR = os.path.dirname(__file__)
# Import patches directory where all patches reside.
sys.path.append(os.path.join(VPXD_SVCS_PATCHES_DIR, "patches"))

cisutils = None
logger = logging.getLogger(__name__)

# Earlier versions of vpxd-svcs do not have the vpxdsvcs.version property in
# dataservice.proprties file. In such a case the installed_version method will
# return 0.

# Patch version 7 should not be used.
# It was introduced for AppDefense shim work which now has been pulled back.
patches = [
    (1, "patch_01"),
    (2, "patch_02"),
    (4, "patch_04"),
    (5, "patch_05"),
    (6, "patch_06"),
    (8, "patch_08"),
    (11, "patch_11"),
    (12, "patch_12")
]


def installed_version(ctx):
    """
    Returns the vpxd-svcs version currently installed. We return 0 if there is
    no version field in the dataservice.properties file.
    """
    ds_prop_file = os.path.join(ctx.stageDirectory, VERSION_FILE_NAME)
    version = cisutils.readprop(ds_prop_file, VERSION)
    cp_prop = cisutils.readprop(ds_prop_file, CP_PROP_KEYS[0])
    if not version:
        if not cp_prop:
            return 0
        else:
            # Certain builds have Compute Policy related properties in the
            # dataservice.properties file but miss the vpxd-svcs.version
            # property.
            return 1
    return int(version)


def getChangesSummary():
    curr_patch_script = patches[-1][1]
    mod = __import__(curr_patch_script)
    return mod.getChanges()


@extend(Hook.Discovery)
def discover(ctx):
    """
    DiscoveryResult discover(PatchContext sharedCtx) throw
    UserUpgradeError.
    """
    # Not applicable if no older versions are present
    if not os.path.exists(VPXD_SVCS_CONFIG_DIR):
       if not os.path.exists(VPXD_SVCS_CONFIG_DIR_OLD):
          logger.info("vpxd-svcs config not found. Patch not applicable.")
          return None

    if fss_utils.getTargetFSS("NDU_Limited_Downtime") \
        and not vcsa_utils.isDisruptiveUpgrade(ctx):
        # Info backup & restore to NOT copy version specific files [PR 2741375]
        replicationConfig = {
            VPXD_SVCS_CONFIG_DIR: None
        }
        return vcsa_utils.getComponentDiscoveryResult(
            "vpxd-svcs",
            displayName=_(_T("vpxd-svcs.displayName",
                            "VMware vCenter-Services")),
            replicationConfig=replicationConfig)

    return vcsa_utils.getComponentDiscoveryResult(
        "vpxd-svcs",
        displayName=_(_T("vpxd-svcs.displayName", "VMware vCenter-Services")))


@extend(Hook.Requirements)
def collectRequirements(ctx):
    """RequirementsResult collectRequirements(PatchContext sharedCtx)"""
    requirements = Requirements()
    mismatches = []
    patchInfo = PatchInfo(summary=getChangesSummary())
    return RequirementsResult(requirements, patchInfo, mismatches)


@extend(Hook.Prepatch)
def prePatch(ctx):
    """void prePatch(PatchContext sharedCtx) throw UserError"""
    dataservicesProp = DATASERVICES_PROP
    if not os.path.exists(dataservicesProp):
        dataservicesProp = DATASERVICES_PROP_OLD
    shutil.copy(dataservicesProp, ctx.stageDirectory)

    sys.path.append(os.environ['VMWARE_PYTHON_PATH'])
    cisutils = importlib.import_module('cis.utils')
    if cisutils.readprop(dataservicesProp, VERSION):
        # If ComputePolicy is not enabled on source make sure to run patch-02
        # and patch-03 by resetting the vpxd-svcs.version property to 1.
        import featureState
        featureState.init()
        if (hasattr(featureState, 'ComputePolicy') and
            not featureState.getComputePolicy()):
            logger.info("Resetting vpxd-svcs version")
            cisutils.replace_properties_in_file(
                os.path.join(ctx.stageDirectory, VERSION_FILE_NAME),
                {VERSION: patches[0][0]})

    if os.path.exists(CPSPEC_PROP):
        shutil.copy(CPSPEC_PROP, ctx.stageDirectory)
    elif os.path.exists(CPSPEC_PROP_OLD):
        shutil.copy(CPSPEC_PROP_OLD, ctx.stageDirectory)


def _doIncrementalPatching(ctx):
    installed = installed_version(ctx)

    patch_installed = False
    for ver, module_path in patches:
        if ver <= installed:
            continue
        patch_module = __import__(module_path)
        logger.info("Applying Compute Policy patch version {}".format(ver))
        patch_module.doPatching(ctx)
        patch_installed = True
        logger.info("Patch %s applied" % (module_path))

    if not patch_installed:
        logger.info("Upgrade between same versions of vpxd-svcs.")


@extend(Hook.Patch)
def patch(ctx):
    """void patch(PatchContext ctx) throw UserError"""
    global cisutils
    if fss_utils.getTargetFSS("NDU_Limited_Downtime") \
        and not vcsa_utils.isDisruptiveUpgrade(ctx):
        logger.info("This is a case of NDU upgrade."
                    "Upgrade will take place before service startup.")
        return None

    # We import cis.utils in the Patch hook to ensure it is
    # patched to work with python 3.
    sys.path.append(os.environ['VMWARE_PYTHON_PATH'])
    cisutils = importlib.import_module('cis.utils')

    if not os.path.exists(DATASERVICES_PROP):
        if os.path.exists(os.path.join(ctx.stageDir, VERSION_FILE_NAME)):
            shutil.copy(os.path.join(ctx.stageDir, VERSION_FILE_NAME),
                        DATASERVICES_PROP)

    if not os.path.exists(CPSPEC_PROP):
        if os.path.exists(os.path.join(ctx.stageDir, CPSPEC_PROP_FILE_NAME)):
            shutil.copy(os.path.join(ctx.stageDir, CPSPEC_PROP_FILE_NAME),
                        CPSPEC_PROP)

    progressReporter = getProgressReporter()
    progressReporter.updateProgress(0, _(_T("vpxd-svcs.patch.begin",
                                            "Start vpxd-svcs patching")))
    _doIncrementalPatching(ctx)
    progressReporter.updateProgress(100, _(_T("vpxd-svcs.patch.complete",
                                              "patching vpxd-svcs completed")))
