# Copyright (c) 2019-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
# coding: utf-8
"""
PatchRunner integration hook for vtsdb Service

This module integrates vtsdb Service patching scripts with the
PatchRunner Framework.
Reference: https://wiki.eng.vmware.com/VSphere2016/vSphere2016Upgrade/Inplace/Patch_Extensibility

"""
__author__ = 'VMware, Inc.'
__copyright__ = 'Copyright 2019-2020 VMware, Inc. All rights reserved.'

import logging
import os
import sys
import json

for entry in os.environ['VMWARE_PYTHON_PATH'].split(":"):
    if entry not in sys.path:
        sys.path.append(entry)

from cis.defaults import get_cis_install_dir
from cis.tools import run_command
from extensions import extend, Hook
from fss_utils import getTargetFSS
from vcsa_utils import isDisruptiveUpgrade
from patch_errors import UserError
from reporting import getProgressReporter
import rpm_utils
from vcsa_utils import getComponentDiscoveryResult
from vtsdb import utils

sys.path.append(os.path.join(os.path.dirname(__file__), "patches"))

sys.path.insert(0, os.sep.join([os.environ['VMWARE_POSTGRES_BASE'], 'share', 'python-modules']))
from vpostgres_cis.firstboot import GetPgVersion

"""
Update from version 0(or None - means vtsdb doesnt exist) to 1
means fresh installation of vtsdb so trigger vtsdb firstboot.
For incremental patching i.e update from version 1 to version 2, execute patch_01.py
and from 2 to 3 execute patch_02.py and so on.
patches = [
    ("2","patch_01"),
    ("3","patch_02"),
]
"""
patches = []

COMPONENT_NAME = 'vtsdb'

# vtsdb is behind the FSS VSTATS
FSS_NAME = 'VSTATS'

logger = logging.getLogger(__name__)


@extend(Hook.Discovery)
def discover(ctx):
    """
    First function which decides if vtsdb component is part of patching workflow.
    :param ctx:
    :return: DiscoveryResult
    """
    logger.info("Executing discovery hook for vtsdb service")

    # Check if FSS 'VSTATS' is enabled, if not then no need to patch vstdb
    if not getTargetFSS(FSS_NAME):
        logger.info("Skip patching vtsdb as FSS is not enabled")
        return None

    # If FSS is enabled then check vstdb version. If current version is not found
    # and getTargetFSS returned true, this is an install of vtsdb service
    # so vtsdb firstboot should be triggered during the patch. If vstdb is already
    # installed, then find all applicable patches.
    utils.preserveVersionFile(ctx.stageDirectory)
    current_version = utils.getSourceVersion(ctx.stageDirectory)

    if not current_version:
        logger.info("vtsdb current version is not found. \
        this is a fresh install")
    else:
        if patches:
            logger.info("Applicable incremental patches for vtsdb: {}".format(patches))
    return getComponentDiscoveryResult(COMPONENT_NAME)


@extend(Hook.Expand)
def expand(ctx):
    '''void expand(PatchContext ctx) throw UserError'''
    if not isDisruptiveUpgrade(ctx):
        current_version = utils.getSourceVersion(ctx.stageDirectory)
        progressReporter = getProgressReporter()
        progressReporter.updateProgress(0, 'Start vtsdb expand')
        logger.info("Prepare for the incremental patching of vtsdb service")
        _doIncrementalExpand(ctx)
        progressReporter.updateProgress(100, 'expand vtsdb completed')

@extend(Hook.Contract)
def contract(ctx):
    '''void patch(PatchContext ctx) throw UserError'''
    if not isDisruptiveUpgrade(ctx):
        progressReporter = getProgressReporter()
        progressReporter.updateProgress(0, 'Start vtsdb contract')
        # verify that additional disks required for vtsdb is added to appliance
        logger.info("Check if additional disks are available for vtsdb and vtsdb_log")
        _doIncrementalContract(ctx)
        progressReporter.updateProgress(100, 'contract vtsdb completed')


@extend(Hook.Patch)
def patch(ctx):
    """void patch(PatchContext ctx) throw UserError
    Main patch logic is executed here
    """

    if isDisruptiveUpgrade(ctx):
        # Get vtsdb config data
        cfg = utils.get_config_data()
        # verify that additional disks required for vtsdb is added to appliance
        logger.info("Check if additional disks are available for vtsdb and vtsdb_log")
        _validateDiskAddition()
        # Check if it is install or upgrade of vtsdb service
        current_version = utils.getSourceVersion(ctx.stageDirectory)

        try:
            if not current_version:
                # Current version is none, so trigger the firstboot
                logger.info("No vtsdb service is found, trigger firstboot of vtsdb.")
                _runFirstboot(ctx)
            else:
                """
                vtsdb instance already exists, perform the following steps
                1. Check if pg_upgrade is required and upgrade accordingly
                2. Check if there are any incremental patches
                    to be applied and update accordingly
                """
                logger.info("Check if pg_upgrade is required")
                if utils.PGUpgradeInplaceRequired(cfg):
                    logger.info("In place PG Upgrade is required")
                    res = utils.PGUpgradeInstanceInplace(cfg)
                    if res != 0:
                        raise Exception('Inplace pg_upgrade failed')

                logger.info("Prepare for the incremental patching of vtsdb service")
                doIncrementalPatching(ctx)
        finally:
            # Once the upgrade is complete, remove upgrade related packages
            # Note: This RPM list has to be appended
            #       with the corresponding RPM name
            # whenever there is a pg major version upgrade
            if GetPgVersion(cfg['VTSDB_DATA']) == cfg['PG_VERSION']:
                rpm_list = ['VMware-Postgres-upgrade-11',
                            'VMware-Postgres-upgrade-10',
                            'VMware-Postgres-upgrade-96',
                            'VMware-Postgres-upgrade-95',
                            'VMware-Postgres-upgrade-94',
                            'VMware-Postgres-upgrade-93']
                if len(rpm_list) != 0:
                    rpm_utils.removeRpms(rpm_list)


def _validateDiskAddition():
    """
    Verify if the additional disks required by vtsdb are present,
    if no then patching should fail
    :return:
    """
    verify_disk_cmd = ['lsblk', '-o', 'name,size,type', '-n', '-l', '--json']
    vtsdb_disk_count = 0
    rc, stdout, stderr = run_command(verify_disk_cmd)
    logger.info("Command return code: {}".format(rc))
    if rc != 0:
        logger.error("Command execution error: {}".format(stderr))
        cause = "vtsdb.disk.validation.fail" % str(stderr)
        raise UserError(cause)
    json_data = json.loads(stdout)

    """
    Sample json output:
    {
    "blockdevices": [
      {"name": "sda", "size": "12G", "type": "disk"},
      ...
      {"name": "vtsdb_vg-vtsdb", "size": "25G", "type": "lvm"},
      {"name": "vtsdblog_vg-vtsdblog", "size": "15G", "type": "lvm"},
     ]
    }
    Traverse the json output, verify vtsdb_vg-vtsdb & vtsdblog_vg-vtsdblog
    disks are found. If no vtsdb patching will fail
    """
    for obj in json_data['blockdevices']:
        if obj['name'] == 'vtsdb_vg-vtsdb' or obj['name'] == 'vtsdblog_vg-vtsdblog':
            vtsdb_disk_count += 1

    logger.info("Disk count is: {}".format(vtsdb_disk_count))
    if (vtsdb_disk_count != 2):
        logger.error("Required disks are not available, vtsdb patching will fail")
        cause = "vtsdb.disk.unavaialble"
        raise UserError(cause)
    return True


def _runFirstboot(ctx):
    logger.info("Starting the vtsdb firstboot...")
    for python_path in os.environ['VMWARE_PYTHON_PATH'].split(":"):
        if python_path not in sys.path:
            sys.path.append(python_path)
    # command for running vtsdb firstboot
    COMPONENT_HOME_FOLDER = '%s/vmware-%s' % (get_cis_install_dir(), COMPONENT_NAME)
    COMPONENT_FIRSTBOOT_PATH = os.path.join(COMPONENT_HOME_FOLDER,
        'firstboot',
        'vtsdb-firstboot.py')
    cmd = [os.environ['VMWARE_PYTHON_BIN'],
           COMPONENT_FIRSTBOOT_PATH,
           '--action',
           'firstboot',
           '--compkey',
           COMPONENT_NAME
           ]
    rc, _, stderr = run_command(cmd)
    logger.info("Firstboot return code: {}".format(rc))
    if rc != 0:
        logger.error("Firstboot error: {}".format(stderr))
        cause = 'vtsdb.patch.firstboot.fail' % str(stderr)
        raise UserError(cause)


def doIncrementalPatching(ctx):
    """
    Incrementally apply all applicable patches
    :param ctx:
    :param current_version:
    :return:
    """
    user_error = None
    init_version = int(utils.getSourceVersion(ctx.stageDirectory))

    # store the current version
    current_version = init_version

    for patch_version, modulePath in patches:
        logger.info("Checking if need patch %s on version %s"
                    % (modulePath, patch_version))
        try:
            mod = __import__(modulePath)
        except Exception as e:
            err_msg = "Failed to import vtsdb patch module" \
                      "%s! Error: %s."
            logger.error(err_msg % (modulePath, str(e)))
            cause = 'vtsdb.patch.import.fail: {} for {}'.format(err_msg, modulePath)
            user_error = UserError(cause)
            break
        if mod.is_patch_needed(current_version):
            logger.info("Patch %s needed" % modulePath)
            try:
                mod.doPatching(ctx)
                # Update the current patch version after
                # successful patching
                current_version = patch_version
            except Exception as e:
                err_msg = "Failed to apply patch: {} due to: {}".format(modulePath, str(e))
                logger.error(err_msg)
                cause = 'vtsdb.patch.incrementalPatching.fail: {}, {}, {}'.format(err_msg, patch_version, str(e))
                user_error = UserError(cause)
                break
            logger.info("Patch %s applied" % modulePath)

    if not user_error:
        logger.info("All patches applied successfully")
        return
    else:
        logger.error("Failed to patch vtsdb service")
        progress_reporter = getProgressReporter()
        progress_reporter.updateProgress(0,'vtsdb service patching failed')
    raise user_error

def _doIncrementalContract(ctx):
    init_version = int(utils.getSourceVersion(ctx.stageDirectory))

    # Nothing to contract if it is a new service
    if init_version is None:
        return

    for ver, modulePath in reversed(getApplicablePatches(init_version)):
        logger.info("Applying vtsdb contract %s on version %s" % (ver, modulePath))
        mod = __import__(modulePath)
        mod.doContract(ctx)
        logger.info("Contract %s applied" % (modulePath))

def _doIncrementalExpand(ctx):
    """
    Incrementally apply all applicable expands
    :param ctx:
    :param current_version:
    :return:
    """
    user_error = None
    current_version = int(utils.getSourceVersion(ctx.stageDirectory))

    for ver, modulePath in getApplicablePatches(current_version):
        logger.info("Applying expand %s on version %s" % (ver, modulePath))
        mod = __import__(modulePath)
        mod.doExpand(ctx)
        logger.info("Expand %s applied" % (modulePath))

    logger.info("All expands applied successfully")

def getApplicablePatches(sourceVersion):
    if sourceVersion is None:
        return patches
    return [(ver, modulePath) for (ver, modulePath) in patches if int(ver) > int(sourceVersion)]