#!/usr/bin/env python
#
# Copyright 2016-2018, 2021 VMware, Inc.  All rights reserved. -- VMware Confidential
#
"""
This module is the entry point for the Patch Runner.
Copied from NGC with slight modifications:
//depot/vsphere-client-modules/
   vmkernel-main/assembler/patch/vsphere_client/__init__.py
"""

import logging

from extensions import Hook, extend
from fss_utils import getTargetFSS
from l10n import localizedString
from l10n import msgMetadata as _T
from os_utils import isWindows
from patch_errors import UserError
from patch_specs import (DiscoveryResult, PatchContext, PatchInfo,
                         RequirementsResult, ValidationResult)
from reporting import getProgressReporter
from vcsa_utils import getComponentDiscoveryResult, isDisruptiveUpgrade

from .patch_util import get_applicable_patches
from .properties_util import delete_backup_files, revert_backup_files
from .version_util import (CONFIG_DIR, get_version, initialize,
                           remove_temp_version)

logger = logging.getLogger(__name__)

@extend(Hook.Discovery)
def discover(patch_context):
   """
   Determine whether the component should be patched.

   :param patch_context: Context given by the patch framework
   :type patch_context: PatchContext

   :return: 'None' if no patching should be done.
   :rtype: DiscoveryResult
   """

   initialize()
   if not _should_patch():
      logger.info('No need to patch vAPI Endpoint.')
      remove_temp_version()
      return None

   logger.info('Deleting old backup files')
   delete_backup_files(CONFIG_DIR)

   logger.info('Retrieving the discovery result from services.json file')
   result = getComponentDiscoveryResult('vapi-endpoint')
   return result


@extend(Hook.Patch)
def do_patching(patch_context):
   """
   Perform incremental patching

   :param patch_context: Context given by the patch framework
   :type patch_context: PatchContext
   """

   if isDisruptiveUpgrade(patch_context) or not getTargetFSS("NDU_Limited_Downtime"):
      logger.info('Executing Patch phase')

      progress_reporter = getProgressReporter()
      _do_incremental_patching(progress_reporter, patch_context)
      remove_temp_version()

@extend(Hook.Expand)
def expand(patch_context):
   """
   Perform incremental patching

   :param patch_context: Context given by the patch framework
   :type patch_context: PatchContext
   """

   if getTargetFSS("NDU_Limited_Downtime") and not isDisruptiveUpgrade(patch_context):
      logger.info('Executing Expand phase')

      progress_reporter = getProgressReporter()
      _do_incremental_patching(progress_reporter, patch_context)

@extend(Hook.Revert)
def revert(patch_context):
   """
   Revert changes made in expand hook

   :param patch_context: Context given by the patch framework
   :type patch_context: PatchContext
   """

   # There is a revert only for Non-disruptive upgrade(NDU) in case there is a
   # temp version file. In case this files doesn't exist means that there was no
   # expand so nothing to revert.
   if (getTargetFSS("NDU_Limited_Downtime") and not isDisruptiveUpgrade(patch_context) and get_version() != '1.0.0.0'):
      logger.info('Executing Revert phase')
      revert_backup_files(CONFIG_DIR)
      remove_temp_version()

def _should_patch():
   """
   :return: Return True if there is at least one patch that should be applied.
   :rtype: bool
   """

   cur_version = get_version()
   applicable_patches = get_applicable_patches(cur_version)
   return len(applicable_patches) != 0


def _do_incremental_patching(progress_reporter, patch_context):
   """
   Incrementally applies all applicable patches.

   :param progress_reporter: Progress reporter that is updated on every successfully applied patch
   :type progress_reporter: _ProgressReporter

   :param patch_context: Context given by the patch framework
   :type patch_context: PatchContext

   :return: version of the latest successfully applied patch
   :rtype: LooseVersion

   :raise UserError: Error that will be reported to the user.
   """

   logger.info('Get applicable patches')
   cur_version = get_version()
   applicable_patches = get_applicable_patches(cur_version)
   if len(applicable_patches) == 0:
      # This can happen only if this phase is invoked directly.
      logger.info('No applicable patches found.')
      return cur_version

   cur_progress = 0
   progress_step = int(100 / len(applicable_patches))
   user_error = None

   for patch in applicable_patches:
      # Apply the patch
      try:
         logger.info('Executing patch version %s', patch.get_version())
         patch.do_patching(patch_context)
         cur_version = patch.get_version()
         logger.info('Successfully applied patch.')

      except UserError as e:
         logger.exception('Failed to apply patch with version %s', patch.get_version())
         user_error = e
         break

      except Exception:
         logger.exception('Failed to apply patch with version %s', patch.get_version())
         cause = localizedString(_T(
               'vapi.endpoint.patch.fail.generic',
               'Error while applying vAPI Endpoint patch version %s'), patch.get_version())
         user_error = UserError(cause)
         break

      # Update the progress
      cur_progress += progress_step
      progress_reporter.updateProgress(cur_progress, patch.get_summary())

   if user_error:
      if cur_progress != 0:
         logger.error('Not all patches were applied. Latest applied patch is %s', cur_version)
         progress_reporter.updateProgress(100, localizedString(_T(
               'vapi.endpoint.patch.fail.partial',
               'vAPI Endpoint was patched partially to version %s'), cur_version))

      else:
         logger.error('Failed to patch vAPI Endpoint')
         progress_reporter.updateProgress(100, localizedString(_T(
            'vapi.endpoint.patch.fail.all',
            'vAPI Endpoint patching failed')))

      raise user_error

   else:
      logger.info('All patches applied successfully')
      progress_reporter.updateProgress(100, localizedString(_T(
            'vapi.endpoint.patch.success',
            'vAPI Endpoint was patched successfully.')))

   return cur_version
