# Copyright 2020-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
#
'''
Holder of functions which are being used across some the phases
'''

import logging
import traceback
from contextlib import contextmanager

from patch_errors import UserError, ComponentError

from vmware_b2b.patching.executor.execution_facade import executeComponentHook
from vmware_b2b.patching.utils.reporting_utils import setupReporting, configureReportingFactory, \
    FRAMEWORK_COMPONENT, reportProgressError, configureLocalization
from vmware_b2b.patching.utils.idempotent_utils import updateAndSavePhaseContext, \
    calculateCompletedComponentsForRepoting
from os_utils import executeCommand

logger = logging.getLogger(__name__)


@contextmanager
def setupStatusAggregator(ctx, outputFile, hooks):
    '''Sets up the statusAggregator according to the context
    and the given hooks

    @param ctx: Global CPO patch phase context.
    @type ctx: vmware.patching.data.model.PatchPhaseContext

    @param outputFile: File where the output of the command needs to be written
      to
    @type outputFile: str

    @param hooks: The hooks that will be executed
    @type hook: [str]
    '''
    # setup context
    configureLocalization(ctx.locale)

    # setup statusAggregator
    completedReporters = calculateCompletedComponentsForRepoting(ctx, *hooks)
    statusAggregator = setupReporting(outputFile,
                                      len(len(hooks) * ctx.getPatchableComponents()),
                                      completedReporters=completedReporters)
    configureReportingFactory(FRAMEWORK_COMPONENT, statusAggregator.reportingQueue)
    try:
        yield statusAggregator
    finally:
        statusAggregator.stop()


def executeComponentsHook(hook, ctx, userData, reportingQueue):
    '''Executes the given hook for each discovered component.

    @param hook: The hook that is being executed
    @type hook: str

    @param ctx: Global CPO patch phase context.
    @type ctx: vmware.patching.data.model.PatchPhaseContext

    @param userData: Customer input, as result of component questions raised
      in @Discovery patching hook.
    @type userData: dict

    @param reportingQueue: A queue to be used for reporting purposes by the patching
        hook
    @type reportingQueue: multiprocessor.Queue
    '''
    executionContext = ctx.getPhaseExecution(hook)
    for c in ctx.getPatchableComponents(sort=True):
        if executionContext.isComponentExecuted(c.name):
            continue
        logger.info("Running %s hook for component %s", hook, c.name)
        executeComponentHook(hook, ctx, c, userData, reportingQueue)
        updateAndSavePhaseContext(ctx, executionContext, component=c.name)


def executeHook(hook,
                ctx,
                userData,
                statusAggregator):
    '''
    This is a generic entry point for upgrade phases which
    are executed after the discovery phase.
    It executes the given hook for each discovered component and
    does proper exception handling.

    @param hook: The hook that is being executed
    @type hook: str

    @param ctx: Global CPO patch phase context.
    @type ctx: vmware.patching.data.model.PatchPhaseContext

    @param userData: Customer input, as result of component questions raised
      in @Discovery patching hook.
    @type userData: dict

    @param statusAggregator: StatusAggregator which controls the reporting
    @type statusAggregator: status_reporting_sdk.statusAggregator.`

    @return: True if the phase succeed and False otherwise
    @rtype: bool
    '''
    try:
        # start execution of the hook
        executeComponentsHook(hook,
                              ctx,
                              userData,
                              statusAggregator.reportingQueue)

        return True
    except ComponentError as wrapperError:
        logger.exception("Patch hook %s got ComponentError.", hook)
        if isinstance(wrapperError.baseError, UserError):
            reportProgressError(wrapperError.baseError, wrapperError.componentKey)
        else:
            reportProgressError(identifier=wrapperError.componentKey)
        return False
    except UserError as userError:
        logger.exception("Patch hook %s got UserError.", hook)
        reportProgressError(userError)
        return False
    except Exception:  # pylint: disable=W0703
        logger.exception("Patch hook %s got unhandled exception.", hook)
        reportProgressError(errorText=traceback.format_exc())
        return False

def ensureVMdirParametersAreNotPresent():
    '''
    Removes the vmdir install parameters
    '''
    for key in ['vmdir.password']:
        logger.info("Ensuring %s is not present", key)
        cmd = ['/bin/install-parameter', '-s', key]

        # the stdout is not used as not to leak the credentials if the command fails
        _, stderr, exitCode = executeCommand(cmd)
        if exitCode != 0:
            logger.debug("Failed to remove %s parameter %s", key, stderr)
            raise Exception("Failed to remove %s parameter", key)
