# Copyright 2018-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
# coding: utf-8
"""
This module is the entry point for the Patch Runner.
Refer to https://wiki.eng.vmware.com/VSphere2016/vSphere2016Upgrade/Inplace/Patch_Extensibility
"""

import logging
import os
import sys
from . import utils
from distutils.version import LooseVersion

from extensions import Hook, extend
from patch_specs import PatchContext, DiscoveryResult, RequirementsResult, \
   PatchInfo, ValidationResult, Requirements
from patch_errors import UserError
from reporting import getProgressReporter
from vcsa_utils import getComponentDiscoveryResult, getComponentDefintion, \
   isDisruptiveUpgrade
from fss_utils import getTargetFSS
from os_utils import getCommandExitCode

SERVICE_NAME = 'vlcm'

logger = logging.getLogger(__name__)
MY_PAYLOAD_DIR = os.path.dirname(__file__)

# Import patches python directory where all individual python module resides
PATCHES_DIR = os.path.join(MY_PAYLOAD_DIR, "patches")
sys.path.append(PATCHES_DIR)

################################################################################
# Added a new entry here if you added a new patch.
# Note that the list must be ascending.
#
# PATCHES contains a list of tuples where the first element indicates the source
# version that the <second element>.py under patches dir is handling.
# For example, if an entry read like (0.0.1, "patch1"), then there must be
# a patch1.py inside patches dir which takes care of patching vLCM from
# version 0.0.1 to the next version
################################################################################
PATCHES = [
    ("0.0.3", "patch2"),
    ("0.0.4", "patch3"),
    ("0.0.5", "patch4"),
]


def getCurrentVersion(ctx):
   """Return the version from the staged version file inside ctx.stageDirectory,
   if any. Return stripped string or None if not found.

   :rtype: str
   """
   ret = utils.getSourceVersion(ctx.stageDirectory)
   if ret:
      ret = ret.strip()
   return ret


@extend(Hook.Discovery)
def discover(ctx):  # pylint: disable=W0613
   """
   The function always return valid {DiscoveryResult}. This means that the other
   patch hooks will always be invoked.

   :param ctx: Context given by the patch framework
   :type ctx: PatchContext
   :return: Always return {DiscoveryResult}.
   :rtype: DiscoveryResult
   """
   logger.info('Retrieving the discovery result from services.json file')

   # Preserve the original version.txt to stage dir
   logger.info("Preserving version file")
   utils.preserveVersionFile(ctx.stageDirectory)

   # report how many applicable patches before really going into patching
   current_version = getCurrentVersion(ctx)

   if not current_version:
       # Because of a bug in our rpm, there is a chance the version.txt is not present on the machine.
       # Since we don't know if which patches were applied, we'll assume none of them have.
       # Patch1, Patch2 and Patch3 are idempotent and can be executed multiple times.
       # This issue won't be faced from 0.0.4 onwards.
       # More info on the bug: https://confluence.eng.vmware.com/display/~belinovd/vLCM+RPM+bug
       logger.info("vLCM current version not found, will assume 0.0.1 during patch.")
   else:
      # here we have a current_version, so we check for applicable patches
      sorted_patches = _get_applicable_sorted_patches(current_version)
      logger.info(
         "Applicable patches for vLCMService: {}".format(sorted_patches))
   # Database is change without see it here so we always participate from now on
   result = getComponentDiscoveryResult(SERVICE_NAME)
   return result


@extend(Hook.Requirements)
def collect_requirements(ctx):  # pylint: disable=W0613
   """
   RequirementsResult collectRequirements(PatchContext ctx)

   :param ctx: Context given by the patch framework
   :type ctx: PatchContext
   :return: The summary of the latest patch that should be installed
   :rtype: RequirementsResult
   """
   logger.info("No special requirements")
   mismatches = []
   requirements = Requirements()
   patchInfo = PatchInfo()
   return RequirementsResult(requirements, patchInfo, mismatches)


@extend(Hook.Validation)
def validate(ctx):
   """ValidationResult validate(PatchContext sharedCtx)"""
   mismatches = []
   return ValidationResult(mismatches)


def _get_applicable_sorted_patches(current_version):
   """Filter PATCHES by using current_version, we simply sort all PATHCES based
   on its first element, which is the source version of that patch, then we
   filter this list so that only those patches with newer than or equal to
   current version stay. This list of patches are those that should be applied.

   :param current_version: current version to filter patches
   :type current_version: str
   :return: filtered PATCHES
   :rtype: list([str, str])
   """
   # find applicable patches
   # first we sort PATCHES by version
   sorted_patches = sorted(PATCHES, key=lambda x: LooseVersion(x[0]))
   if not current_version:
      logger.info("No current version provided, this means vLCM was not "
                  "installed, returning all patches. ")
      return sorted_patches
   # filter sorted patches to only contain patches that have version larger than
   # current version
   sorted_patches = [x for x in sorted_patches if LooseVersion(x[0]) >= LooseVersion(current_version)]
   return sorted_patches


def _do_incremental_patching(ctx):
   """
   Incrementally applies all applicable patches.
   """
   current_version = getCurrentVersion(ctx)

   if not current_version:
       # Because of a bug in our rpm, there is a chance the version.txt is not present on the machine.
       # Since we don't know if which patches were applied, we'll assume none of them have.
       # Patch1, Patch2 and Patch3 are idempotent and can be executed multiple times.
       # This issue won't be faced from 0.0.4 onwards.
       # More info on the bug: https://confluence.eng.vmware.com/display/~belinovd/vLCM+RPM+bug
       logger.info("No current version found, assuming 0.0.1")
       current_version = "0.0.1"

   logger.info("Start incremental patching on source version {}. ".format(
      current_version))

   sorted_patches = _get_applicable_sorted_patches(current_version)
   length = len(sorted_patches)

   progress_reporter = getProgressReporter()
   progress_reporter.updateProgress(0, "Start vLCM service patching, {} "
                                       "patches pending".format(length))

   logger.info("Applicable patches: {}".format(sorted_patches))
   for index in range(length):
      ver, module_path = sorted_patches[index]

      logger.info("Applying patch for source version {}".format(ver))
      mod = __import__(module_path)
      if hasattr(mod, "doPatching"):
         try:
            mod.doPatching(ctx)
         except Exception as e:
            logger.error("Failed to apply patch for source version {0}! "
                         "Error: {1}".format(ver, str(e)))
            raise UserError(str(e))
         progress_reporter.updateProgress(
            int((index+1)/len(sorted_patches)),
            "Applied patch for source version {0}, changes are {1}".format(
               ver, mod.getChanges() if hasattr(mod, "getChanges") else "undefined"
            )
         )
         logger.info("Applied patch for source version {}".format(ver))
      else:
         msg = "Patch {} does not define doPatching function".format(module_path)
         logger.error(msg)
         raise UserError(msg)

   progress_reporter.updateProgress(100, "Completed vLCM service patching")


@extend(Hook.Expand)
def expand(ctx):
   """ Expand the current source machine
   """
   if not isDisruptiveUpgrade(ctx):
      sys.path.append(os.getenv("VMWARE_PYTHON_PATH"))
      from .db_config import upgrade
      logger.info("Preparing to run DB expand. ")
      fssValues = {

      }
      upgrade.execute_expand(fssValues)

@extend(Hook.Revert)
def revert(ctx):
   """ Reverting the expand phase
   """
   if not isDisruptiveUpgrade(ctx):
      sys.path.append(os.getenv("VMWARE_PYTHON_PATH"))
      from .db_config import upgrade
      logger.info("Reverting DB expand. ")
      upgrade.execute_revert()

@extend(Hook.Contract)
def contract(ctx):
   """ Contract the target machine
   """
   if not isDisruptiveUpgrade(ctx):
      sys.path.append(os.getenv("VMWARE_PYTHON_PATH"))
      from .db_config import upgrade
      logger.info("Preparing to run DB contract. ")
      fssValues = {

      }
      upgrade.execute_contract(fssValues)

@extend(Hook.Patch)
def do_patching(ctx):
   """void patch(PatchContext sharedCtx) throw UserError

   :param ctx: Context given by the patch framework
   :type ctx: PatchContext
   """

   _fix_FSS()

   if isDisruptiveUpgrade(ctx):
      logger.info("Preparing to run DB upgrade. ")
      fssValues = {
      }
      sys.path.append(os.getenv("VMWARE_PYTHON_PATH"))
      from .db_config import upgrade
      upgrade.execute_expand(fssValues)
      upgrade.execute_contract(fssValues)

   logger.info("Prepare to do incremental patching. ")
   _do_incremental_patching(ctx)

def _fix_FSS():
   ''' During leaf service upgrade. If an FSS is True on the target but
   it is missing on the source the service won't see it. This appends it
   to the Py FSS file to fix that for anyone that uses it.
   '''
   logger.info('Fixing FSS')
   sys.path.append(os.environ['VMWARE_PYTHON_PATH'])
   import featureState
   fssFile = '/usr/lib/vmware/site-packages/featureState.py'

