#!/usr/bin/env python
#
# Copyright 2017-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
#
"""
This module is the entry point for the Patch Runner.
"""

import logging
import os
import pathlib
import shutil
import sys
import traceback

from extensions import Hook, extend
from l10n import localizedString, msgMetadata as _T
from patch_errors import UserError
from patch_specs import PatchContext, DiscoveryResult
from reporting import getProgressReporter
from vcsa_utils import getComponentDiscoveryResult, getComponentDefintion, \
    isDisruptiveUpgrade
from fss_utils import getTargetFSS
from .vsphere_ui_memory_increase import try_increasing_max_heap
from .patch_util import get_applicable_patches, get_executed_patches, \
    get_applicable_patches_for_expand, get_expanded_patches
from .patches.file_util import add_or_update_property_in_file
from .patches.path_constants import VSPHERE_UI_ROOT_DIR, \
    VSPHERE_UI_UPGRADE_MARKER_FILE, VSPHERE_UI_UPGRADE_MARKER_FILE_TEMPLATE
from .version_util import mark_patch_as_executed, isVMCGateway, \
    VSPHERE_UI_SOURCE_VERSION_PROP, get_client_version_from_rpm, \
    VSPHERE_UI_PRE_UPGRADE_CONFIG_FILE, \
    VSPHERE_UI_SOURCE_CLN_PROP, mark_patch_as_expanded, \
    get_service_version_from_cisreg_file, \
    VSPHERE_UI_CISREG_VERSION_PROP

logger = logging.getLogger(__name__)

DEPENDENT_SERVICES = []
DEPENDENT_COMPONENTS = []
NDU_LIMITED_DOWNTIME_FSS = 'NDU_Limited_Downtime'


@extend(Hook.Discovery)
def discover(patch_context):
    """
    Determine whether the vSphere Client should be patched.
    The function always return valid {DiscoveryResult}. This means that the other
    patch hooks will always be invoked.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext

    :return: Always return {DiscoveryResult}.
    :rtype: DiscoveryResult
    """

    # Get the vSphere Client service name from services.json using
    # serviceId = "vsphere-ui"
    service_name = _get_service_name()

    logger.info('Retrieving the discovery result from services.json file')
    result = getComponentDiscoveryResult('vsphere-ui', patchServices=[service_name])

    if not isVMCGateway():
        DEPENDENT_SERVICES.append('vmware-vpxd')
        DEPENDENT_COMPONENTS.append('vpxd')

    # Append the required services to the result`s dependent services list
    for service in DEPENDENT_SERVICES:
        if service not in result.dependentServices:
            logger.info('Adding service %s to vsphere-ui component dependent services list' % service)
            result.dependentServices.append(service)

    # Append the required components to the result`s dependent components list
    for component in DEPENDENT_COMPONENTS:
        if component not in result.dependentComponents:
            logger.info('Adding component %s to vsphere-ui component dependent components list' % component)
            result.dependentComponents.append(component)

    pre_upgrade_cfg_file = os.path.join(str(patch_context.stageDirectory),
                                        VSPHERE_UI_PRE_UPGRADE_CONFIG_FILE)
    # create/cleanup the vsphere-ui pre-upgrade config file
    with open(pre_upgrade_cfg_file, 'w'):
        pass
    # save the source version of the vSphere UI
    source_client_version, source_client_cln = get_client_version_from_rpm()
    logger.info(
        'Saving source H5 client version %s and cln %s, to %s ' %
        (source_client_version, source_client_cln, pre_upgrade_cfg_file))
    add_or_update_property_in_file(pre_upgrade_cfg_file, VSPHERE_UI_SOURCE_VERSION_PROP,
                                   source_client_version)
    add_or_update_property_in_file(pre_upgrade_cfg_file, VSPHERE_UI_SOURCE_CLN_PROP,
                                   source_client_cln)

    # Create a marker file to detect the first post upgrade service start-up
    upgrade_marker_file = _create_upgrade_marker_file(patch_context)

    # Save the source cisreg service version of vsphere-ui
    source_cisreg_version = get_service_version_from_cisreg_file()
    if source_cisreg_version is not None:
        logger.info('Saving vsphere-ui source cisreg version %s to %s ' %
                    (source_cisreg_version, upgrade_marker_file))
        add_or_update_property_in_file(upgrade_marker_file,
                                       VSPHERE_UI_CISREG_VERSION_PROP,
                                       source_cisreg_version)

    # Replication configuration that is supposed to override the default one.
    # Used for RDU (NDU) upgrades.
    if not isDisruptiveUpgrade(patch_context):
        marker_file_template = os.path.join(str(patch_context.stageDirectory),
                                            VSPHERE_UI_UPGRADE_MARKER_FILE_TEMPLATE)
        logger.info(
            'Adding vsphere-ui upgrade marker template: %s to the replication config.' % marker_file_template)
        result.replicationConfig = {
            marker_file_template: VSPHERE_UI_UPGRADE_MARKER_FILE
        }

    return result


@extend(Hook.Expand)
def expand(patch_context):
    """ Expand the current source machine

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    if _is_rdu_upgrade(patch_context):
        _do_incremental_expand(patch_context)


@extend(Hook.Revert)
def revert(patch_context):
    """ Revert the expand phase

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    if _is_rdu_upgrade(patch_context):
        _do_incremental_revert(patch_context)


@extend(Hook.Contract)
def contract(patch_context):
    """
    Contract the target machine.
    Used for removing leftover files.
    All services are up and running and the downtime is over.
    No disruptive changes are allowed.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    if os.path.exists(VSPHERE_UI_UPGRADE_MARKER_FILE):
        logger.info(
            'Removing the upgrade marker file at %s ' % VSPHERE_UI_UPGRADE_MARKER_FILE)
        os.remove(VSPHERE_UI_UPGRADE_MARKER_FILE)
    if _is_rdu_upgrade(patch_context):
        _do_incremental_contract(patch_context)


@extend(Hook.Patch)
def do_patching(patch_context):
    """
    This code should be executed only for in-place (disruptive) upgrades
    or for RDU based, where NDU_LIMITED_DOWNTIME_FSS is OFF.

    1. Delete the virgo work folder.
    2. Increase vsphere-ui max heap, if needed.
    3. Perform incremental patching, if there are applicable patches.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """

    if (isDisruptiveUpgrade(patch_context)
            or (not getTargetFSS(NDU_LIMITED_DOWNTIME_FSS))):
        _delete_virgo_work_folder()

        '''
        When an B2B upgrade is made, that doesn't feature an update of the vsphere-ui rpm AND the size of the VC is L or XL,
        then this script is the only vsphere-ui code that is being executed during the upgrade. The visl-integration bundle
        will no doubtfully get updated, so any custom ram sizes will be reverted to the default. Since vsphere-ui
        relies on a custom ram size, we must ensure that for L or XL environments are properly supplied with memory.
        '''
        try:
            try_increasing_max_heap()
        except Exception as e:
            logger.exception('Error checking/increasing heap size', e)

        progress_reporter = getProgressReporter()
        user_error = _do_incremental_patching(progress_reporter, patch_context)

        if user_error:
            raise user_error


@extend(Hook.OnSuccess)
def on_success(patch_context):
    """
    At OnSuccess phase the patch phase has finished successfully and the
    appliance is running on the new version. This phase is intended to be used
    to do any clean up or post patch configuration changes, once all components
    are up-to-date.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """

    executed_patches = get_executed_patches(patch_context)

    for patch in executed_patches:
        try:
            patch.on_success(patch_context)
        except:
            exc_type, exc_value, exc_traceback = sys.exc_info()
            errtb = traceback.format_exception(exc_type, exc_value,
                                               exc_traceback)
            logger.error(
                'OnSuccess hook for patch with version %s failed with error:\n %s' % (
                    patch.get_version(), ''.join(errtb)))


def _do_incremental_expand(patch_context):
    """
    Incrementally expands all applicable expandable patches.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    logger.info('Getting applicable patches for expand.')
    expandable_patches = get_applicable_patches_for_expand(patch_context)
    if len(expandable_patches) == 0:
        logger.info('No applicable patches for expand.')
        return

    last_patch = expandable_patches[-1]
    logger.info(
        'The last patch that will be applied for expand is %s with version %s' % (
        last_patch.get_name(), last_patch.get_version()))

    for patch in expandable_patches:
        logger.info('Executing expand for patch %s with version %s' % (patch.get_name(), patch.get_version()))
        try:
            patch.do_expand(patch_context)
        except Exception as ex:
            logger.exception(
                'Failed to expand patch %s with version %s. Exception: %s' % (
                    patch.get_name(), patch.get_version(), ex))
            raise ex
        mark_patch_as_expanded(patch.get_name(), patch.get_version(), patch_context)
        logger.info('Successfully expanded patch.')

    logger.info('All applicable patches successfully expanded!')


def _do_incremental_revert(patch_context):
    """
    Incrementally reverts all expanded patches.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    expanded_patches = get_expanded_patches(patch_context)
    if len(expanded_patches) == 0:
        return
    logger.info('Starting the revert of expanded patches.')
    for patch in reversed(expanded_patches):
        logger.info('Executing revert for patch %s with version %s' % (patch.get_name(), patch.get_version()))
        try:
            patch.do_revert(patch_context)
        except Exception as ex:
            logger.exception(
                'Failed to revert patch %s with version %s. Exception: %s' % (
                    patch.get_name(), patch.get_version(), ex))
            raise ex
    logger.info('All expanded patches successfully reverted!')


def _do_incremental_contract(patch_context):
    """
    Incrementally contracts all expanded patches.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    expanded_patches = get_expanded_patches(patch_context)
    if len(expanded_patches) == 0:
        return
    logger.info('Starting the contract of expanded patches.')
    for patch in reversed(expanded_patches):
        logger.info('Executing contract for patch %s with version %s' % (patch.get_name(), patch.get_version()))
        try:
            patch.do_contract(patch_context)
        except Exception as ex:
            logger.exception(
                'Failed to contract patch %s with version %s. Exception: %s' % (
                    patch.get_name(), patch.get_version(), ex))
            raise ex
    logger.info('All expanded patches successfully contracted!')


def _do_incremental_patching(progress_reporter, patch_context):
    """
    Incrementally applies all applicable patches.

    :param progress_reporter: Progress reporter that is updated on every successfully
    applied patch
    :type progress_reporter: _ProgressReporter

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext

    :return: user_error: error occurred during patching
    :rtype: UserError
    """

    user_error = None

    logger.info('Get applicable patches')
    applicable_patches = get_applicable_patches(patch_context)
    if len(applicable_patches) == 0:
        logger.info('No applicable patches found.')
        return user_error

    last_patch = applicable_patches[-1]
    logger.info('The last patch that will be applied is %s with version %s' % (last_patch.get_name(), last_patch.get_version()))

    cur_progress = 0
    progress_step = int(100 / len(applicable_patches))

    for patch in applicable_patches:
        # Apply the patch
        try:
            logger.info('Executing patch %s with version %s' % (patch.get_name(), patch.get_version()))
            patch.do_patching(patch_context)
            cur_version = patch.get_version()
            mark_patch_as_executed(patch.get_name(), cur_version, patch_context)
            logger.info('Successfully applied patch.')

        except UserError as e:
            logger.exception('Failed to apply patch with version %s' %
                             patch.get_version())
            user_error = e
            break

        except Exception:
            logger.exception('Failed to apply patch with version %s' %
                             patch.get_version())
            cause = localizedString(_T(
                'vsphere.ui.patch.fail.generic',
                'Error when applying vSphere Client patch version $s'),
                patch.get_version())
            user_error = UserError(cause)
            break

        # Update the progress
        cur_progress += progress_step
        progress_reporter.updateProgress(cur_progress, patch.get_summary())

    if not user_error:
        logger.info('All patches applied successfully')
        progress_reporter.updateProgress(100, localizedString(_T(
            'vsphere.ui.patch.success',
            'vSphere Client was patched successfully.')))

    elif cur_progress != 0:
        logger.error('Not all patches were applied. Latest applied patch is %s' %
                     cur_version)
        progress_reporter.updateProgress(100, localizedString(_T(
            'vsphere.ui.patch.fail.partial',
            'vSphere Client was patched partially to version %s'), cur_version))

    else:
        logger.error('Failed to patch the vSphere Client')
        progress_reporter.updateProgress(100, localizedString(_T(
            'vsphere.ui.patch.fail.all',
            'vSphere Client patching failed')))

    return user_error


def _get_service_name():
    """
    Read the vSphere Client service name from services.json

    :return: vSphere Client service name:
        - vsphere-ui - on Windows
        - vsphere-ui - on Linux
    :rtype: str
    """

    logger.info('Retrieving service name...')
    comp_def = getComponentDefintion('vsphere-ui')
    service_name = comp_def.serviceName
    logger.info('serviceName = %s' % service_name)
    return service_name


def _delete_virgo_work_folder():
    """
    Delete the virgo work folder, if exists.
    """

    work_dir = os.path.join(VSPHERE_UI_ROOT_DIR, 'server', 'work')
    if os.path.exists(work_dir):
        logger.info('Deleting vSphere Client work folder: %s' % work_dir)
        shutil.rmtree(work_dir)
        logger.info('Done deleting vSphere Client work folder')
    else:
        logger.info('Work folder does not exist: %s' % work_dir)


def _is_rdu_upgrade(patch_context):
    """
    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext
    """
    return (getTargetFSS(NDU_LIMITED_DOWNTIME_FSS)
            and (not isDisruptiveUpgrade(patch_context)))


def _create_upgrade_marker_file(patch_context):
    """
    For in-place (disruptive) based upgrades, creates upgrade.marker file.
    For RDU based upgrades, creates upgrade.marker.template file.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext

    :returns: The full path of the created file
    """
    if isDisruptiveUpgrade(patch_context):
        logger.info(
            'Creating a marker file for vsphere-ui upgrade at %s ' % VSPHERE_UI_UPGRADE_MARKER_FILE)
        pathlib.Path(VSPHERE_UI_UPGRADE_MARKER_FILE).touch()
        return VSPHERE_UI_UPGRADE_MARKER_FILE
    else:
        marker_file_template = os.path.join(str(patch_context.stageDirectory),
                                            VSPHERE_UI_UPGRADE_MARKER_FILE_TEMPLATE)
        logger.info(
            'Creating a marker file template for vsphere-ui upgrade at %s ' % marker_file_template)
        pathlib.Path(marker_file_template).touch()
        return marker_file_template
