# Copyright (c) 2019-2022 VMware, Inc.  All rights reserved. -- VMware Confidential
# coding: utf-8
"""
PatchRunner integration hook for vstats Service

This module integrates vstats service patching scripts with the
PatchRunner Framework.
Reference: https://wiki.eng.vmware.com/VSphere2016/vSphere2016Upgrade/Inplace/Patch_Extensibility

"""
__author__ = 'VMware, Inc.'
__copyright__ = 'Copyright 2019-2020 VMware, Inc. All rights reserved.'

import logging
import os
import sys

for entry in os.environ['VMWARE_PYTHON_PATH'].split(":"):
    if entry not in sys.path:
        sys.path.append(entry)

from cis.defaults import get_cis_install_dir
from cis.tools import run_command
from extensions import extend, Hook
from fss_utils import getTargetFSS
from vcsa_utils import isDisruptiveUpgrade
from patch_errors import UserError
from vcsa_utils import getComponentDiscoveryResult
from reporting import getProgressReporter
from . import utils

sys.path.append(os.path.join(os.path.dirname(__file__), "patches"))

"""
Update from version 0(or None - means vstats doesnt exist) to 1 is
fresh installation of vstats so trigger the firstboot scripts.
For incremental patching i.e update from version 1 to version 2, execute patch_01.py
and from 2 to 3 execute patch_02.py and so on.
patches = [
    ("2","patch_01"),
    ("3","patch_02"),
]
"""
patches = [
    (2,"patch_01"),
    (3,"patch_02"),
    (4,"patch_03"),
    (5,"patch_04")
]

COMPONENT_NAME = 'vstats'

# vstats is behind the FSS VSTATS
FSS_NAME = 'VSTATS'

logger = logging.getLogger(__name__)

@extend(Hook.Discovery)
def discover(ctx):
    """
    First function which decides if vstats component is part of patching workflow.
    :param ctx:
    :return: DiscoveryResult
    """
    logger.info("Executing discovery hook for vstats service")

    # Check if FSS 'VSTATS' is enabled, if not then no need to patch vstats
    if not getTargetFSS(FSS_NAME):
        logger.info("Skip patching vstats as FSS is not enabled")
        return None

    # If FSS is enabled, check vstats version. If current version is not found
    # and getTargetFSS returned true, this is an install of vstats service
    # so firstboot should be triggered during the patch. If vstats is already
    # installed, then find all applicable patches. If no applicable patches
    # found then no need to patch
    utils.preserveVersionFile(ctx.stageDirectory)

    current_version = utils.getSourceVersion(ctx.stageDirectory)
    if not current_version:
        logger.info("vstats current version is not found. \
        this is a fresh install")
    else:
        if not patches:
            logger.info("No applicable patches for vstats: patching is not required")
            return None
        logger.info("Applicable patches for vstats: {}".format(patches))
    return getComponentDiscoveryResult(COMPONENT_NAME)

@extend(Hook.Expand)
def expand(ctx):
    '''void expand(PatchContext ctx) throw UserError'''
    if not isDisruptiveUpgrade(ctx):
        current_version = utils.getSourceVersion(ctx.stageDirectory)
        progressReporter = getProgressReporter()
        progressReporter.updateProgress(0, 'Start vstats expand')
        logger.info("Prepare for the incremental patching of vstats service")
        _doIncrementalExpand(ctx)
        progressReporter.updateProgress(100, 'expand vstats completed')

@extend(Hook.Contract)
def contract(ctx):
    '''void patch(PatchContext ctx) throw UserError'''
    if not isDisruptiveUpgrade(ctx):
        progressReporter = getProgressReporter()
        progressReporter.updateProgress(0, 'Start vstats contract')
        _doIncrementalContract(ctx)
        progressReporter.updateProgress(100, 'contract vstats completed')


@extend(Hook.Patch)
def patch(ctx):
    """void patch(PatchContext ctx) throw UserError
    Main patch logic is executed here
    """
    if isDisruptiveUpgrade(ctx):
        # Check if it is install or update of vstats service
        current_version = utils.getSourceVersion(ctx.stageDirectory)
        if not current_version:
            # Current version is none, so trigger the firstboot
            logger.info("No vstats service found, trigger firstboot of vstats.")
            _runFirstboot(ctx)
        else:
            logger.info("Prepare for the incremental patching of vstats service")
            doIncrementalPatching(ctx)

def _runFirstboot(ctx):
   logger.info("Starting the vstats firstboot...")
   for python_path in os.environ['VMWARE_PYTHON_PATH'].split(":"):
       if python_path not in sys.path:
           sys.path.append(python_path)
       # command for running vtsdb firstboot
       COMPONENT_HOME_FOLDER = '%s/vmware-%s' % (get_cis_install_dir(), COMPONENT_NAME)
       COMPONENT_FIRSTBOOT_PATH = os.path.join(COMPONENT_HOME_FOLDER,
                                               'firstboot',
                                               'vstats-firstboot.py')
       cmd = [os.environ['VMWARE_PYTHON_BIN'],
              COMPONENT_FIRSTBOOT_PATH,
              '--action',
              'firstboot',
              '--compkey',
              COMPONENT_NAME
              ]
   rc, _, stderr = run_command(cmd)
   logger.info("Firstboot return code: {}".format(rc))
   if rc != 0:
       logger.error("Firstboot error: {}".format(stderr))
       cause = 'vstats.patch.firstboot.fail: %s' % [str(stderr)]
       raise UserError(cause)

def doIncrementalPatching(ctx):
    """
    Incrementally apply all applicable patches
    :param ctx:
    :param current_version:
    :return:
    """
    user_error = None
    init_version = int(utils.getSourceVersion(ctx.stageDirectory))
    # store the current version
    current_version = init_version

    for patch_version, modulePath in patches:
        logger.info("Checking if need patch %s on version %s"
                    % (modulePath, patch_version))
        try:
            mod = __import__(modulePath)
        except Exception as e:
            err_msg = "Failed to import vstats patch module %s! Error: %s."
            logger.error(err_msg % (modulePath, str(e)))
            cause = 'vstats.patch.import.fail: {} for {}'.format(err_msg, modulePath)
            user_error = UserError(cause)
            break
        if mod.is_patch_needed(current_version):
            logger.info("Patch %s needed" % modulePath)
            try:
                mod.doPatching(ctx)
                # update the current patch version after successful patching
                current_version = patch_version
            except Exception as e:
                err_msg = "Failed to apply patch: {} due to: {}".format(modulePath, str(e))
                logger.error(err_msg)
                cause = 'vstats.patch.incrementalPatching.fail: {}, {}, {}'.format(err_msg, patch_version, str(e))
                user_error = UserError(cause)
                break
            logger.info("Patch %s applied" % modulePath)

    if not user_error:
        logger.info("All patches applied successfully")
        return
    else:
        logger.error("Failed to patch vstats service")
        progress_reporter = getProgressReporter()
        progress_reporter.updateProgress(0, 'vstats service patching failed')
    raise user_error

def _doIncrementalExpand(ctx):
    """
    Incrementally apply all applicable expands
    :param ctx:
    :param current_version:
    :return:
    """
    user_error = None
    current_version = int(utils.getSourceVersion(ctx.stageDirectory))

    for ver, modulePath in getApplicablePatches(current_version):
        logger.info("Applying vstats expand %s on version %s" % (ver, modulePath))
        mod = __import__(modulePath)
        mod.doExpand(ctx)
        logger.info("Expand %s applied" % (modulePath))

    logger.info("All expands applied successfully")

def _doIncrementalContract(ctx):
    init_version = int(utils.getSourceVersion(ctx.stageDirectory))

    # Nothing to contract if it is a new service
    if init_version is None:
        return

    for ver, modulePath in reversed(getApplicablePatches(init_version)):
        logger.info("Applying contract %s on version %s" % (ver, modulePath))
        mod = __import__(modulePath)
        mod.doContract(ctx)
        logger.info("Contract %s applied" % (modulePath))

    logger.info("All contracts applied successfully")

def getApplicablePatches(sourceVersion):
    if sourceVersion is None:
        return patches
    return [(ver, modulePath) for (ver, modulePath) in patches if int(ver) > int(sourceVersion)]
