# Copyright (c) 2020-2024 VMware, Inc.  All rights reserved.
# All rights reserved. -- VMware Confidential
"""
The module contains common patching functionality for VMware vCenter services.
The component should be executed as last component in patching workflow.
"""
import os
import logging
import json

from patch_specs import (
    DiscoveryResult, ValidationResult, Requirements, PatchInfo,
    RequirementsResult)
from extensions import extend, Hook
from l10n import msgMetadata as _T, localizedString as _
from os_utils import executeCommand, getCommandExitCode
from reporting import getMessageReporter
from vmware_b2b.patching.config import settings
from patch_errors import UserError
from vcsa_utils import isDisruptiveUpgrade

logger = logging.getLogger(__name__)
MY_PAYLOAD_DIR = os.path.dirname(__file__)

SENDMAIL_CONF_FILE="/usr/bin/sendmail_config.py"
RETRY_FAILED_TO_START_SERVICES_TEXT = _(_T('ur.services.retry.start.text',
                                 'Failed to start all services, will retry operation.'))

FAILED_TO_START_SERVICES_TEXT = _(_T('ur.services.start.text',
                                 'Failed to start all services after successful patching.'))

BUILD_INFO_FILE = '/etc/vmware/.buildInfo'

@extend(Hook.Discovery)
def discover(ctx):  # pylint: disable=W0613
    '''DiscoveryResult discover(PatchContext sharedCtx) throw UserUpgradeError
    '''
    # There is special handling for last-patching-component component,
    # guaranteed
    # That all other components will be executed before this one
    return DiscoveryResult(displayName=_(_T("common.comp.displayName",
                                            "VMware vCenter Server Patcher")),
                           componentId="last-patching-component")


@extend(Hook.Requirements)
def collectRequirements(ctx):  # pylint: disable=W0613
    '''RequirementsResult collectRequirements(PatchContext sharedCtx)'''
    mismatches = []
    requirements = Requirements()
    patchInfo = PatchInfo()

    return RequirementsResult(requirements, patchInfo, mismatches)


@extend(Hook.Validation)
def validate(ctx):  # pylint: disable=W0613
    '''ValidationResult validate(PatchContext sharedCtx)'''
    mismatches = []

    return ValidationResult(mismatches)


def _executePreInstallScript(scriptLocation, args):
    '''Executes party pre-isntall script. The script contains update logic
    necessary to be executed just before the rpms are installed. It could also
    contain shared patching logic, e.g. update the OS patcher rpm, up front the
    update of all rpms.

    It also stops the affected services

    @param scriptLocation: Location to the script which needs to be executed
    @type scriptLocation: str

    @param args: Arguments which have to be passed to the script
    @type args: list
    '''
    logger.info("Running pre-install script %s", scriptLocation)
    stdout, stderr, exitCode = executeCommand(['/bin/python', scriptLocation] + args)
    logger.info("pre-install script completed. Stdout=%s, Stderr=%s,"
                "exit-code=%s", stdout, stderr, exitCode)
    if exitCode:
        raise Exception("Unable to execute script %s" % scriptLocation)

def _getRootStagedDirectory(ctx):
    '''Gets location of the root staged directory. Each of its
    sub-directories are a component specific ones.

    @param ctx: Patch context shared among all patching hooks
    @type ctx: PatchContext

    @return: The full path location on the disk
    @rtype: str
    '''
    # Stage directory points to $STAGE_DIR/patch_runner/<compName> defined in
    # update_script.py
    stageDir = ctx.stageDirectory
    result = os.path.abspath(os.path.join(stageDir, os.pardir))
    return result

def _getUpgradeDownloadDir(ctx):
    ''' Get upgrade download directory location
    '''
    stageDir = ctx.stageDirectory
    return os.path.abspath(os.path.join(stageDir, os.pardir, os.pardir))

def _getStagedScript(ctx, scriptName):
    '''Gets location of the script downloaded to staged directory.

    @param ctx: Patch context shared among all patching hooks
    @type ctx: PatchContext

    @param scriptName: Name of the script
    @type scriptName: str

    @return: The full path location on the disk
    @rtype: str
    '''
    # Stage directory points to $STAGE_DIR/patch_runner/<compName> defined in
    # update_script.py
    result = os.path.abspath(os.path.join(_getUpgradeDownloadDir(ctx),
                                          "scripts", scriptName))
    return result


@extend(Hook.Prepatch)
def prePatch(ctx):  # pylint: disable=W0613
    '''void prePatch(PatchContext sharedCtx) throw UserUpgradeError'''

    if isDisruptiveUpgrade(ctx):
        preInstallScript = _getStagedScript(ctx, "pre-install.py")
        rootStageDir = _getRootStagedDirectory(ctx)
        _executePreInstallScript(preInstallScript, [rootStageDir])
    else:
        logger.info("Pre-patch hook doesn't have nondisruptve logic.")


def _updateSendmailConfigs(scriptLocation):
    '''Updates Sendmail configuration. The script update Sendmail.cf by retrieving
    sendmail server details from VCDB.

    @param scriptLocation: Location to the script which needs to be executed
    @type scriptLocation: str
    '''
    logger.info("Running sendmail script %s", scriptLocation)

    if not os.path.exists('/etc/mail/auth/auth-info'):
        os.mkdir('/etc/mail/auth', 0o700)
        os.mknod('/etc/mail/auth/auth-info')

    logger.info("Updating sendmail.cf file by retrieving mail-server details from VCDB.")
    stdout, stderr, exitCode = executeCommand(['/bin/python', scriptLocation])

    if exitCode:
        logger.error("Failed to update sendmail.cf. Manually configure sendmail."
                    "Stdout=%s, Stderr=%s, exit-code=%s", stdout, stderr, exitCode)
    else:
        logger.info("Successfully updated sendmail config file on target. Stdout=%s,Stderr=%s,"
                "exit-code=%s", stdout, stderr, exitCode)


def _executePostPatchScript(scriptLocation, args):
    '''Executes post-patch script. The script contains update logic necessary to
    be executed just after all components patching hook complete successfully.

    @param scriptLocation: Location to the script which needs to be executed
    @type scriptLocation: str

    @param args: Arguments which have to be passed to the script
    @type args: list
    '''
    logger.info("Running post-patch script %s", scriptLocation)
    stdout, stderr, exitCode = executeCommand([scriptLocation] + args)
    logger.info("post-patch script completed. Stdout=%s,Stderr=%s,"
                "exit-code=%s", stdout, stderr, exitCode)
    if exitCode:
        raise Exception("Unable to execute script %s" % scriptLocation)

def _isLeafServiceUpgrade(ctx):
    '''
    '''
    rpmManifest = os.path.join(_getUpgradeDownloadDir(ctx), "rpm-manifest.json")
    if not os.path.exists(rpmManifest):
        return False
    with open(rpmManifest) as fp:
        manifest = json.load(fp)
        header = manifest.get("header", {})
        leafService = "true" == header.get("leaf_service", "false")
        return leafService
    return False

def _getPostPatchScript(ctx):
    ''' Finds the correct patch script for the current execution. This
    Can be the post-patch.sh script or the leaf-post-patch.sh
    '''
    postPatchScript = _getStagedScript(ctx, "post-patch.sh")
    leafPatchScript = _getStagedScript(ctx, "leaf-post-patch.sh")

    rpmManifest = os.path.join(_getUpgradeDownloadDir(ctx), "rpm-manifest.json")

    if _isLeafServiceUpgrade(ctx):
        return leafPatchScript
    return postPatchScript

def _startAllVMwareServices():
    '''Starts all enabled VMware services.

    @return: True if the operation succeed and False otherwise
    @rtype: bool
    '''
    command = ['/bin/service-control',
                   '--start',
                   '--all']

    logDir = settings.loggingData.directory
    outFileName = os.path.join(logDir, 'startAllServices.out.log')
    errFileName = os.path.join(logDir, 'startAllServices.err.log')

    logger.info("Starting all VMware services... The immediate command stdout "
                "is redirected to file %s and stderr is redirected to file %s",
                outFileName, errFileName)
    exitCode = getCommandExitCode(command,
                                  localStdoutFile=outFileName,
                                  localStderrFile=errFileName)
    with open(outFileName) as fp:
        out = fp.read()

    with open(errFileName) as fp:
        err = fp.read()

    logger.info("Start All VMware services: command=%s, exit-code=%s, "
                "stdout=%s, stderr=%s", command, exitCode, out, err)

    if exitCode != 0:
        raise UserError(FAILED_TO_START_SERVICES_TEXT)

def _perfromStartAllVmwareServices():
    ''' Performs same function as startAllVmwareServices but retries if we fail
    the first time
    '''
    try:
        _startAllVMwareServices()
    except UserError:
        # In some situation the start all fails because some services takes
        # long time to start, next start all starts them usually
        getMessageReporter().postWarning(RETRY_FAILED_TO_START_SERVICES_TEXT)
        _startAllVMwareServices()

def isGateway():
    ''' Indicate if the source is gateway or not
    '''
    if os.path.isfile(BUILD_INFO_FILE):
        with open(BUILD_INFO_FILE, 'r') as fp:
            if 'CLOUDVM_NAME:VMware-vCenter-Cloud-Gateway' in fp.read():
                logger.info('Running on a VMC Gateway appliance.')
                return True
            logger.info('Not running on a VMC Gateway appliance.')
            return False
    else:
        logger.warning('File %s does not exist', BUILD_INFO_FILE)
        return False

@extend(Hook.Patch)
def patch(ctx):  # pylint: disable=W0613
    '''void patch(PatchContext sharedCtx) throw UserUpgradeError'''
    if isDisruptiveUpgrade(ctx):
        postPatchScript = _getPostPatchScript(ctx)
        rootStageDir = _getRootStagedDirectory(ctx)
        _executePostPatchScript(postPatchScript, [rootStageDir])
        _perfromStartAllVmwareServices()
    else:
        # Note here start all is handled from outside
        logger.info("Patch hook doesn't have nondisruptve logic.")


@extend(Hook.Contract)
def contract(ctx):
    '''void contract(Contract ctx) throw UserError'''
    logger.info("Running last component contract hook.")
    if not isDisruptiveUpgrade(ctx):
        _updateSendmailConfigs(SENDMAIL_CONF_FILE)
        logger.info("Successfully finished last component contract hook.")

@extend(Hook.OnSuccess)
def onsuccess(ctx):
    '''void onsuccess(ctx)
    '''
    logger.info("Running last component onsuccess hook.")
