#!/usr/bin/env python
# Copyright 2016-2020 VMware, Inc.
# All rights reserved. -- VMware Confidential

"""
    Utility helpers shared by the day0 and b2b scripts.
"""
import sys
import os
import logging
import subprocess
import json
import tempfile
import time
import re
from time import sleep
from datetime import datetime
from os.path import join, exists, abspath, dirname, isdir

# Put the bundled patches/libs/feature-state directory on the import path so
# that the fss_utils import below can be resolved (assumed to live there).
sys.path.append(join(dirname(__file__), "patches", "libs", "feature-state"))

from fss_utils import getTargetFSS
import product_utils

logger = logging.getLogger(__name__)

# Build metadata of the running appliance; used to figure out the initial
# deployment options (e.g. whether this is a VMC Gateway appliance).
BUILD_INFO_FILE = '/etc/vmware/.buildInfo'

# Holds the deployment type of this node; used to pick the matching ignore list.
DEPLOYMENT_TYPE_FILE = "/etc/vmware/deployment.node.type"

# Name of the ignore list file; format it with the deployment type.
IGNORELISTPATTERN = "%s_ignore.lst"

# Parent of the directory containing this module; the manifests, the
# allowlist and services.json below are resolved relative to it.
DEPLOYMENT_ROOT = abspath(join(dirname(__file__), os.pardir))
RPM_MANIFEST_FILE = join(DEPLOYMENT_ROOT, 'rpm-manifest.json')
ALLOWLIST_FILE = join(DEPLOYMENT_ROOT, 'scripts', 'allowed_components.json')
MANIFEST_FILE = join(DEPLOYMENT_ROOT, 'manifest-latest.xml')
SERVICES_JSON = join(DEPLOYMENT_ROOT,
                     "scripts/patches/libs/sdk/config/services.json")

# Cached contents of RPM_MANIFEST_FILE; filled lazily by get_rpm_manifest().
_rpmManifestJson = None

def getSourceFSS(name):
    '''
    Tell whether the feature switch ``name`` is enabled on the source
    vCenter Server.

    :param name: name of the feature switch attribute on featureState
    :return: False when featureState has no such attribute, otherwise the
             attribute's value
    '''
    import featureState
    featureState.init(enableLogging=False)
    if not hasattr(featureState, name):
        return False
    return getattr(featureState, name)

def isB2BUpgrade():
    ''' Tell whether the bundle performs an upgrade (as opposed to patching).

    A gateway source is always upgraded. Otherwise the run is an upgrade when
    the VMWARE_B2B environment variable equals "1" or the target VMWARE_B2B
    feature switch is anything other than False.
    '''
    if isGateway():
        logger.info("On gateway it is aways upgrade.")
        return True

    envRequestsUpgrade = os.environ.get('VMWARE_B2B') == '1'
    upgrade = envRequestsUpgrade or getTargetFSS('VMWARE_B2B') != False
    logger.debug("Bundle will execute upgrade: %s", upgrade)
    return upgrade

def isGateway():
    ''' Indicate if the source is a VMC Gateway appliance or not.

    Looks for the Cloud Gateway CLOUDVM_NAME marker in BUILD_INFO_FILE.

    @return: True if the marker is present; False otherwise, including when
             BUILD_INFO_FILE does not exist (a warning is logged then).
    @rtype: bool
    '''
    if not os.path.isfile(BUILD_INFO_FILE):
        # Lazy %-args: the message is only formatted if it is emitted.
        logger.warning('File %s does not exist', BUILD_INFO_FILE)
        return False

    with open(BUILD_INFO_FILE, 'r') as fp:
        onGateway = 'CLOUDVM_NAME:VMware-vCenter-Cloud-Gateway' in fp.read()

    if onGateway:
        logger.info('Running on a VMC Gateway appliance.')
    else:
        logger.debug('Not running on a VMC Gateway appliance.')
    return onGateway

def get_rpm_manifest():
    ''' Return the parsed contents of RPM_MANIFEST_FILE.

    The file is read on the first call only; later calls return the object
    cached in the module-level _rpmManifestJson.
    '''
    global _rpmManifestJson
    if _rpmManifestJson is not None:
        return _rpmManifestJson
    with open(RPM_MANIFEST_FILE) as manifestFile:
        _rpmManifestJson = json.load(manifestFile)
    return _rpmManifestJson

def _searchRpm(rpmManifestJson, pattern):
    result = []
    for entry in list(rpmManifestJson.values()):
        if re.match(pattern, entry["relativepath"]):
            result.append(entry["relativepath"])
    return result

def isB2BComponentAllowed(componentName, allowlist, ignorelist):
    ''' Indicate if the component is allowed based on the filtering criteria
    applied on the bundle.

    A component is allowed when it is not on the ignore list and either no
    allow list is given (None) or it appears on the allow list.

    @param componentName: The exact component name as seen in the payload dir
    @type componentName: str
    @param allowlist: Allowlisted components for this execution; None allows all
    @type allowlist: list(str) or None
    @param ignorelist: Ignorelisted components for this execution; None ignores none
    @type ignorelist: list(str) or None
    @return: True if the component should be executed
    @rtype: bool
    '''
    if ignorelist is not None and componentName in ignorelist:
        return False
    if allowlist is None:
        return True
    return componentName in allowlist

def _backspaceStdout(l):
    """ Go back l characters in stdout and replace them with spaces
    """

    sys.stdout.write("%s%s%s" % (("\b" * l), (" " * l), ("\b" * l)))
    sys.stdout.flush()


def _getFormattedTime():
    """ Output the current time in a format consistent with
        the logging.
        example: 2014-07-24 08:42:37,084
    """

    d = datetime.now()
    return ("%04d-%02d-%02d %02d:%02d:%02d,%03d" %
            (d.year, d.month, d.day, d.hour, d.minute, d.second, d.microsecond / 1000))

def _getRpmsFromFilter(filterFunc):
    ''' Returns list of rpms that match the filter provided. The rpms are calculated
    based on what is in services.json file and the patch repository.

    Every "installFiles" entry of a selected service is treated as a literal
    file name in which the placeholder @version@ matches any text; it is
    matched (re.match, anchored at the start) against the "relativepath" of
    the entries in the rpm manifest's "files" section.

    @param filterFunc: Function that is used to decide if the rpms of given service
    in services.json should be added or not.
    @type filterFunc: lambda dict: bool

    @return: The manifest "relativepath" of every rpm that matches the filter.
    @rtype: list
    '''

    result = []
    with open(SERVICES_JSON) as f:
        servicesJson = json.load(f)

    rpmManifestJson = get_rpm_manifest()

    for name, service in servicesJson.items():
        if not filterFunc(service):
            continue
        files = service["appliance"]["installFiles"]
        logger.debug("Rpm matching filter - service %s, rpms %r", name, files)
        for fil in files:
            # Escape every regex metacharacter of the literal name (not only
            # '.'), so names such as "libstdc++-@version@.rpm" do not produce
            # broken or unintended patterns; only @version@ is a wildcard.
            pattern = '.*'.join(re.escape(part)
                                for part in fil.split('@version@'))
            logger.debug("File pattern %s", pattern)
            result.extend(_searchRpm(rpmManifestJson["files"], pattern))
    return result


def getRpmIgnoreList(componentAllowlist, componentIgnorelist):
    """Retrieves the rpm ignorelist based on the current deployment type and other
    filtering mechanisms.

    The result concatenates, in this order:
      * the non-blank lines of DEPLOYMENT_ROOT/ignorelists/<type>_ignore.lst,
        where <type> is read from DEPLOYMENT_TYPE_FILE (if either file exists);
      * the rpms of services whose target feature switch ("FSSname") is False;
      * the VMware-Postgres-upgrade rpms, when the vpostgres component is not
        allowed to run;
      * on a gateway, the rpms of every product that is not enabled.

    @param componentAllowlist: List of allowlisted components for this execution
    @type componentAllowlist: list(str)
    @param componentIgnorelist: List of ignorelisted components for this execution
    @type componentIgnorelist: list(str)

    @return: Rpms which should not be installed during the update workflow
    @rtype: list
    """
    deployType = ""
    if exists(DEPLOYMENT_TYPE_FILE):
        with open(DEPLOYMENT_TYPE_FILE) as f:
            deployType = f.read().strip()

    logger.debug("vCSA deployment Type: %s", deployType)
    ignoreFile = join(DEPLOYMENT_ROOT, "ignorelists",
                      IGNORELISTPATTERN % deployType)

    ignoreRpms = []
    if exists(ignoreFile):
        with open(ignoreFile) as f:
            ignoreRpms = [line.strip() for line in f if line.strip()]
    else:
        logger.debug("File %s does not exist", ignoreFile)

    # process feature switches in target services.json
    def fssFilter(service):
        return ("FSSname" in service
                and getTargetFSS(service["FSSname"]) is False)

    ignoreRpms.extend(_getRpmsFromFilter(fssFilter))

    if not isB2BComponentAllowed("vpostgres", componentAllowlist,
                                 componentIgnorelist):
        # When the vPostgres component is executed it takes care of those rpms as
        # they are used for it and at the end are removed. If the script is not
        # executed they should be put on the ignorelist as otherwise they get
        # installed and left on the machine. They are not part of the services.json
        # intentionally as they should only be present during B2B and not on fresh
        # install.
        rpmManifestJson = get_rpm_manifest()
        # Raw string: "\." is an invalid escape in a normal string literal.
        pattern = r"VMware-Postgres-upgrade-.*\.rpm"
        logger.debug("vPostgres pattern used %s", pattern)
        ignoreRpms.extend(_searchRpm(rpmManifestJson["files"], pattern))

    # if product not installed/enabled then ignore product rpms
    # (currently only HLM). This is only required when update is from
    # combined repo, if update is from dedicated platform/HLM repo then
    # rpms will be specific to the update.
    if isGateway():
        for product in product_utils.getAllProducts():
            if not product_utils.isProductEnabled(product):
                logger.info("Not %s update filtering its rpms", product)
                ignoreRpms.extend(getRpmsForProduct(product))

    return ignoreRpms


def getServiceIgnoreList():
    """Retrieves the list of services that are not enabled due to FSS.

    A service is listed when its services.json entry has an "FSSname" whose
    target feature switch (getTargetFSS) is False. Only services.json is
    read; the rpm manifest is not needed here.

    @return: services which should not be installed during the update workflow
    @rtype: list(string)
    """
    with open(SERVICES_JSON) as f:
        servicesJson = json.load(f)

    # process feature switches in target services.json
    return [name for name, service in servicesJson.items()
            if "FSSname" in service
            and getTargetFSS(service["FSSname"]) is False]


def getComponentsForProduct(product):
    ''' Returns the B2B components that belong to the given product.

    @param product: The product to look up, for example HLM
    @type product: str

    @return: The "b2bScripts" of every services.json entry whose "products"
        list contains product, concatenated in iteration order.
    @rtype: list
    '''
    with open(SERVICES_JSON) as servicesFile:
        services = json.load(servicesFile)

    components = []
    for entry in services.values():
        if product not in entry.get("products", []):
            continue
        components.extend(entry.get("b2bScripts", []))

    logger.debug("B2B components scripts %s are for product %s",
                 components, product)
    return components


def getAllProductComps():
    '''
    Collect the b2b components of every services.json entry that belongs to
    at least one product (i.e. has a non-empty "products" value).
    :return: list of b2b components
    '''
    with open(SERVICES_JSON) as servicesFile:
        services = json.load(servicesFile)

    productComps = []
    for entry in services.values():
        if not entry.get("products"):
            continue
        productComps += entry.get("b2bScripts", [])

    logger.debug('All product components %s', str(productComps))
    return productComps


def getGwPlatformComponents(patchRunnerComponentsDir):
    '''
    List the gateway platform b2b components, i.e. the component directories
    that do not belong to any product.
    :param patchRunnerComponentsDir: directory whose sub-directories are the
        b2b components; names containing 'first_component' or
        'last_component' are skipped
    :return: list of gateway platform b2b components
    '''
    def isRegularComponent(entry):
        if 'first_component' in entry or 'last_component' in entry:
            return False
        return isdir(join(patchRunnerComponentsDir, entry))

    allComps = [entry for entry in os.listdir(patchRunnerComponentsDir)
                if isRegularComponent(entry)]
    logger.debug('All components %s', str(allComps))

    productComps = getAllProductComps()
    platformComps = [comp for comp in allComps if comp not in productComps]

    logger.debug('All gateway platform components %s', str(platformComps))
    return platformComps


def getRpmsForProduct(product):
    ''' Returns the rpms of the services tagged with the given product.

    @param product: The product to look up, for example hlm
    @type product: str

    @return: Rpms of every services.json entry whose "products" contains
       product. Only services.json is consulted, thus Photon rpms are never
       returned.
    @rtype: list
    '''
    def belongsToProduct(service):
        return "products" in service and product in service['products']

    return _getRpmsFromFilter(belongsToProduct)

def runCommand(command, progress=False, message="", onSubcmdStart=None):
    """Run a command and waits until its completion.

    The command's stdout and stderr are redirected to temporary files under
    /var/log/vmware/applmgmt. The files are read back after the process exits
    and are deleted when this function returns.

    @param command: The command to run
    @type command: list

    @param progress: Flag indicating whether there should be a progress bar
        (never drawn when the B2B_no_stdout environment variable is set)
    @type progress: boolean

    @param message: The message displayed ahead of the progress bar
    @type message: str

    @param onSubcmdStart: Executed right after the subprocess has been started
    @type onSubcmdStart: Callback
      Example: def onSubcmdStart():

    @return: Tuple of (stdout, stderr, processExitCode); stdout and stderr are
        the bytes the command wrote (the temp files are opened in binary mode)
    @rtype: tuple
    """

    logger.debug("Running command: %s", command)

    with tempfile.NamedTemporaryFile(prefix='update-script-out', dir='/var/log/vmware/applmgmt') as out:
        with tempfile.NamedTemporaryFile(prefix='update-script-err', dir='/var/log/vmware/applmgmt') as err:
            logger.debug('You can find the output of the command in temp files '
                         'out %s, err %s', out.name, err.name)
            proc = subprocess.Popen(command, stdout=out,
                                    stderr=err, stdin=None)

            if onSubcmdStart:
                onSubcmdStart()

            if progress and not os.environ.get('B2B_no_stdout'):
                # This code creates a progress bar and polls to see if the
                # process has completed, once the progress bar has finished.
                # If the process has completed the return with the error code,
                # otherwise repeat the progress bar.
                progressbar_width = 4
                progressbar_time = 1.0
                # setup progress bar
                sys.stdout.write(" [%s] : %s%s" % (_getFormattedTime(), message,
                                                   (" " * progressbar_width)))
                sys.stdout.flush()
                while proc.poll() is None:
                    # Erase the previous dots, then draw progressbar_width
                    # dots spread over progressbar_time seconds.
                    _backspaceStdout(progressbar_width)
                    for _i in range(progressbar_width):
                        sleep(progressbar_time / progressbar_width)
                        sys.stdout.write(".")
                        sys.stdout.flush()
                sys.stdout.write("\n")
                sys.stdout.flush()

            rc = proc.wait()
            # Rewind both files to offset 0 before reading what the command
            # wrote (seek's first argument is the offset, not a whence flag).
            out.seek(0)
            err.seek(0)
            stdout = out.read()
            stderr = err.read()
            logger.debug("STDOUT: %s", stdout)
            logger.debug("STDERR: %s", stderr)
            return (stdout, stderr, rc)

def isRootPasswdExpired():
    '''
    Check if VCSA root password is expired before patching start.

    Reads root's shadow entry. The password counts as expired when both the
    maximum password age (sp_max) and the last-change date (sp_lstchg, in
    days since the epoch) are set, and more than sp_max days have passed
    since sp_lstchg.

    NOTE(review): the spwd module was deprecated in Python 3.11 and removed in
    3.13; it needs a replacement before moving to a newer interpreter.

    @return: True if the root password is expired, False otherwise
    @rtype: bool
    '''
    import spwd
    pwdCfg = spwd.getspnam('root')
    maxdays = pwdCfg.sp_max      # same field as pwdCfg[4]
    lastchg = pwdCfg.sp_lstchg   # same field as pwdCfg[2]
    # Empty shadow fields come back as -1. An unset maximum age or an unset
    # last-change date both mean password aging is disabled (shadow(5)), so
    # -1 must not be used as a date in the computation below.
    if maxdays == -1 or lastchg == -1:
        return False
    return time.time() > (maxdays + lastchg) * 86400
