# Copyright 2015 VMware, Inc.  All rights reserved. -- VMware Confidential
#
'''
Script responsible to setup environment and execute the patching hook of vCSA
components. The script calls only those components which have been discovered
in prior discovery phase. The patch phase respect the component dependencies
defined in @Discovery hook, and execute the component patch hook in the right
order as pay attention to start the dependent services, before the hook is
executed.
'''
import logging
import traceback

from extensions import Hook
from patch_specs import PatchContext
from service_manager import getServiceManager, ServiceStarttype
from rpm_utils import removeRpms
from l10n import msgMetadata as _T, localizedString as _
from reporting import getMessageReporter, getProgressReporter
from patch_errors import UserError, ComponentError
from vmware_b2b.patching.utils.phase_utils import executeComponentsHook

from vmware_b2b.patching.data.model import PatchPhaseContext
from vmware_b2b.patching.executor.execution_facade import executeComponentHook
from vmware_b2b.patching.utils.reporting_utils import setupReporting, configureReportingFactory, \
    FRAMEWORK_COMPONENT, reportProgressError, configureLocalization
from vmware_b2b.patching.utils.idempotent_utils import updateAndSavePhaseContext, \
    calculateCompletedComponentsForRepoting


logger = logging.getLogger(__name__)
CANNOT_REMOVE_RPM_ERROR_TEXT = _T('ur.rpm.removal.error.text',
                           'Cannot remove automatically deprecated rpms %s from the system.')
CANNOT_REMOVE_RPM_ERROR_RESOLUTION = _(_T('ur.rpm.removal.error.resolution',
                                 'There are deprecated rpms that cannot be removed automatically.'
                                 ' To ensure that there will be no problems with subsequence patches'
                                 ' remove the rpms manually.'))

def _startDependentServices(c):
    '''Start a component dependent services. Assume those services are either
    new or already patched.

    @param c: A component, which dependent services has to be started
    @type c: vmware.patching.data.model.Component
    '''
    depServices = c.discoveryResult.dependentServices
    logger.info("Start component %s dependent services - %s", c.name,
                ", ".join(depServices))

    serviceManager = getServiceManager()
    for depService in depServices:
        if serviceManager.getStarttype(depService) not in [ServiceStarttype.MANUAL, ServiceStarttype.DISABLED]:
            serviceManager.start(depService)
        else:
            logger.info('Skip starting service %s as its starttype is manual/disabled.', depService)


def _patchComponents(ctx, userData, reportingQueue):
    '''Execute the patching workflow. The patching workflow first start all
    component dependent services and then call the component Patch hook.

    @param ctx: Global CPO patch phase context.
    @type ctx: vmware.patching.data.model.PatchPhaseContext

    @param userData: Customer input, as result of component questions raised
      in @Discovery patching hook.
    @type userData: dict

    @param reportQueue: A queue to be used for reporting purposes by the patching
        hook
    @type reportQueue: multiprocessor.Queue

    @return: True if patch process complete properly, and False otherwise
    @rtype bool

    @raise Exception: if patch hook execution fails.
    '''
    executionContext = ctx.getPhaseExecution(Hook.Patch)

    for c in ctx.getPatchableComponents(sort=True):
        if executionContext.isComponentExecuted(c.name):
            continue
        logger.info("Running Patch for component %s", c.name)
        _startDependentServices(c)
        executeComponentHook(Hook.Patch, ctx, c, userData, reportingQueue)
        updateAndSavePhaseContext(ctx, executionContext, component=c.name)
    return True

def _runOnSuccess(ctx, userData, reportingQueue):
    '''Execute the patching workflow. The patching workflow first start all
    component dependent services and then call the component Patch hook.

    @param ctx: Global CPO patch phase context.
    @type ctx: vmware.patching.data.model.PatchPhaseContext

    @param userData: Customer input, as result of component questions raised
      in @Discovery patching hook.
    @type userData: dict

    @param reportQueue: A queue to be used for reporting purposes by the patching
        hook
    @type reportQueue: multiprocessor.Queue

    @return: True if patch process complete properly, and False otherwise
    @rtype bool

    @raise Exception: if patch hook execution fails.
    '''
    executionContext = ctx.getPhaseExecution(Hook.OnSuccess)
    for c in ctx.getPatchableComponents(sort=True):
        if executionContext.isComponentExecuted(c.name):
            continue
        try:
            # Start all dependent services before the patch hook is being called out
            executeComponentHook(Hook.OnSuccess, ctx, c, userData, reportingQueue)
        except ComponentError:
            logger.exception("Patch hook %s got ComponentWrapperError.", Hook.OnSuccess)
        except UserError:
            logger.exception("Patch hook %s got UserError.", Hook.OnSuccess)
        except Exception:  # pylint: disable=W0703
            logger.exception("PostPatch hook %s got unhandled exception.", Hook.OnSuccess)

        updateAndSavePhaseContext(ctx, executionContext, component=c.name)

    return True


def _removeDeprecatedRpms(ctx):
    ''' Remove deprecated rpms from the system

    @param ctx: Global CPO patch phase context.
    @type ctx: vmware.patching.data.model.PatchPhaseContext

    @param outputFile: File where the output of the command needs to be written
      to
    @type outputFile: str
    '''
    executionContext = ctx.getPhaseExecution("RPM_REMOVAL")
    if executionContext.isComponentExecuted("done"):
        return
    rpmList = set()
    for component in ctx.getAllComponents():
        rpmList = rpmList.union(component.discoveryResult.deprecatedProducts)
    rpmList = list(rpmList)  # list is needed for reporting
    failingRpms = removeRpms(rpmList)
    if failingRpms:
        getMessageReporter().postWarning(_(CANNOT_REMOVE_RPM_ERROR_TEXT, ' '.join(failingRpms)), CANNOT_REMOVE_RPM_ERROR_RESOLUTION)
    updateAndSavePhaseContext(ctx, executionContext, component="done")

def patch(stageDir, userData, outputFile):
    '''The entry point of patch CPO phase.

    The patch phase is responsible to setup environment and execute the patching
    hook of vCSA components. The patch phase calls only those components which
    have been discovered in prior discovery phase. The patch phase respect the
    component dependencies defined in @Discovery hook, and execute the component
    patch hook in the right order as pay attention to start the dependent
    services, before the hook is executed.

    @param stageDir: global CPO stage directory. All components stage
      directories will be created as sub-directories of global stage directory.
    @type stageDir: str

    @param userData: Customer input, as result of component questions raised
      in @Discovery patching hook.
    @type userData: dict

    @param outputFile: File where the output of the command needs to be written
      to
    @type outputFile: str

    @return: True if the patch phase succeed and False otherwise
    @rtype: bool
    '''
    statusAggregator = None
    try:
        ctx = PatchPhaseContext.load(stageDir)
        configureLocalization(ctx.locale)

        isInplace = ctx.upgradeType == PatchContext.DISRUPTIVE_UPGRADE
        if isInplace and ctx.isPhaseExecutionCompleted(Hook.OnSuccess):
            return True  # It is fine to finish here if it is done
        elif not isInplace and ctx.isPhaseExecutionCompleted(Hook.Patch):
            #TODO Remove once not running patch during RDU
            return True  # It is fine to finish here if it is done

        hooksToExecute = [Hook.Patch]
        if isInplace:
            hooksToExecute.append(Hook.Contract)
            hooksToExecute.append(Hook.OnSuccess)

        # System one should always finish
        completedReporters = calculateCompletedComponentsForRepoting(ctx, *hooksToExecute)

        # add extra component for the framework as we might throw exception from
        # framework bit and this will mess the reporting
        statusAggregator = setupReporting(outputFile,
                                          len(len(hooksToExecute) * ctx.getPatchableComponents()) + 1,
                                          completedReporters=completedReporters)
        configureReportingFactory(FRAMEWORK_COMPONENT, statusAggregator.reportingQueue)

        getProgressReporter().updateProgress(0)
        _patchComponents(ctx, userData, statusAggregator.reportingQueue)

        if isInplace:
            _removeDeprecatedRpms(ctx)
            executeComponentsHook(Hook.Contract, ctx, userData, statusAggregator.reportingQueue)
            _runOnSuccess(ctx, userData, statusAggregator.reportingQueue)

        getProgressReporter().success()
        return True
    except ComponentError as wrapperError:
        logger.exception("Patch hook %s got ComponentWrapperError.", Hook.Patch)
        if isinstance(wrapperError.baseError, UserError):
            reportProgressError(wrapperError.baseError, wrapperError.componentKey)
        else:
            reportProgressError(identifier=wrapperError.componentKey)
        return False
    except UserError as userError:
        logger.exception("Patch hook %s got UserError.", Hook.Patch)
        reportProgressError(userError)
        return False
    except Exception:  # pylint: disable=W0703
        logger.exception("Patch hook %s got unhandled exception.", Hook.Patch)
        reportProgressError(errorText=traceback.format_exc())
        return False
    finally:
        if statusAggregator:
            statusAggregator.stop()
