# Copyright 2015 - 2023  VMware, Inc.  All rights reserved.
# VMware Confidential
#

'''
Structures which models the patching data.
'''
import os
import logging
import json

from extensions import Hook
from patch_specs import PatchContext
from status_reporting_sdk.json_utils import JsonFileSerializer, Serializable

from vmware_b2b.patching.data.components_graph import GraphBuilder, sortTopologically

logger = logging.getLogger(__name__)


class Component(Serializable):
    '''Component model specification. The structure keeps reference to a component
    which want to take part of patching process.
    '''
    def __init__(self, name, patchScript, stageDir, discoveryResult,
                 requirementsResult):
        '''Creates new component.

        @param name: Component internal name. Component name will be used as
          identificator for distinguishing between components. The name is formed
          from the directory where its patching logic reside. For example:
          /foo/bar/patching/sso -> sso
        @type name: str

        @param patchScript: Holds information about where the patching component
          logic reside. Every component patch hook should be defined at
          <patchScriptPath>/__init__.py
        @type patchScript: str

        @param stageDir: Component specific stage directory. This is a unique
          per component directory, where a component can share data between
          the patching hooks.
        @type stageDir: str

        @param discoveryResult: Information about the component found after
          calling its discovery hook.
        @type discoveryResult: patch_specs.DiscoveryResult

        @param requirementsResult: Information about the component requirements.
        @type requirementsResult: patch_specs.RequirementsResult
        '''
        self.name = name
        self.patchScript = patchScript
        self.stageDir = stageDir
        self.discoveryResult = discoveryResult
        self.requirementsResult = requirementsResult

class Execution(Serializable):
    '''
    Will add the execution status of each component
    '''

    def __init__(self, completedComponents=None):
        '''
        Initialises the components argument to list
        '''
        self.completedComponents = completedComponents or []

    def executed(self, component):
        '''Marks the component as executed
        @param component: Name of the component which has been completed
        @type component: string
        '''
        self.completedComponents.append(component)


    def isComponentExecuted(self, component):
        '''checks the given component whether it is executed or not
        @param component: Name of the component which has been completed or not
        @type component: String

        @return: True if component is completed else false
        @rtype: bool
        '''
        return component in self.completedComponents

    def isStarted(self):
        ''' Returns if an execution is started
        @return: True if at least one component is completed else false
        @rtype: bool
        '''
        return len(self.completedComponents) > 0

class PatchPhaseContext(Serializable):
    '''Patch context shared between CPO executions. This is singleton object,
    same used through a single patching process. The context is created when
    CPO --discovery is executed, persisted on the system and the reloaded on
    every subsequent phase execution.
    '''
    # Location where patch phase context will be persisted/loaded to.
    CONTEXT_FILE = "patch_phase_context.json"

    def __init__(self,
                 stageDir,
                 components,
                 locale,
                 phaseExecutions=None,
                 upgradeType=PatchContext.DISRUPTIVE_UPGRADE):

        '''Creates PatchPhaseContext

        @param stageDir: global CPO stage directory. All components stage
          directories will be created as sub-directories of global stage directory.
        @type stageDir: str

        @param components: Components found during discvovery Hook execution
          which want to take part in the patching process.
        @type components: list of Component

        @param locale: Current locale. All user-facing messages should be
          translated to that locale.
        @type locale: str

        @param phasesExecutions: Dictionary containing information related to
        various phases running.
        @type phasesExecution: dict
        @param upgradeType: The type of the upgrade being performed. The allowed values
          are either PatchContext.DISRUPTIVE_UPGRADE or PatchContext.NONDISRUPTIVE_UPGRADE.
        '''
        self.stageDir = stageDir
        self.components = components
        self.locale = locale
        self.phaseExecutions = phaseExecutions or {}
        self.upgradeType = upgradeType

    def getPhaseExecution(self, phase):
        ''' Returns the State of the Phase
        @param phase:name of the phase
        @type phase: str

        @return: Returns the state of the phase
        @rtype: Execution or None if the state is not started
        '''
        if phase not in self.phaseExecutions:
            self.phaseExecutions[phase] = Execution()

        return self.phaseExecutions[phase]

    def isPhaseExecutionCompleted(self, phase):
        ''' Checks if given phase has being finished. It is tested against only
        the patchable components

        @param phase: Phase to check against. If it does not exists will create
        entry for it and return False
        @type phase: str

        @return if completed or nor
        @rtype bool
        '''
        execution = self.getPhaseExecution(phase)

        if not execution.isStarted():
            return False
        for c in self.getPatchableComponents(sort=False):
            if not execution.isComponentExecuted(c.name):
                return False
        return True

    def isPhaseExecutionStarted(self, phase):
        ''' Checks if given phase has being started.

        @param phase: Phase to check against. If it does not exists will create
        entry for it and return False
        @type phase: str

        @return if started or nor
        @rtype bool
        '''
        execution = self.getPhaseExecution(phase)
        return execution.isStarted()

    def isUpgradeStarted(self):
        ''' Returns if the upgrade is started. This check if we have execution
        of expand, systemprepare or prepatch hooks.
        @rtype: bool
        '''
        return self.isPhaseExecutionStarted(Hook.Expand) \
               or self.isPhaseExecutionStarted(Hook.SystemPrepare) \
               or self.isPhaseExecutionStarted(Hook.Prepatch)

    def _getComponent(self, componentId):
        '''Get a Component by its name.

        @param componentId: a component unique id
        @type componentId: str

        @return: A Component structure holding information about that component
        @rtype: Component

        @raise StopIteration: If a component cannot be found
        '''
        return next(c for c in self.components if c.discoveryResult.componentId == componentId)

    def getAllComponents(self, sort=False):
        '''Get all patchable and none-patchable components. The result could be
        sorted by component dependencies.

        @param sort: Specifies if the result has to be sorted.
        @type sort: bool

        @return: All components which take part in patch process
        @rtype: list of Component

        @raise ValueError: If a cyclic dependency is found
        '''
        if not sort:
            return self.components

        # Return components sorted by their dependency
        builder = GraphBuilder(self.components)

        componentsGraph = builder.buildDependenciesGraph()
        # Topological sort will raise exception if there is a cyclic dependency
        # in the dependency graph
        sortedNames = sortTopologically(componentsGraph)
        sortedComponents = [self._getComponent(name)  for name in sortedNames]
        return sortedComponents

    def getPatchableComponents(self, sort=False):
        '''Get all patchable components, and filter those which does not want to
        be patched. The result could be sorted by component dependencies.

        @param sort: Specifies if the result has to be sorted.
        @type sort: bool

        @return: All components which take part in patch process
        @rtype: list of Component

        @raise ValueError: If a cyclic dependency is found
        '''
        patchableComponents = []
        for c in self.getAllComponents(sort=sort):
            if not c.discoveryResult.patchable:
                logger.info("Skip patching component %s", c.name)
                continue
            patchableComponents.append(c)

        return patchableComponents

    def _getAnswerFileFormat(self):
        result = ""
        for c in self.components:
            for q in c.requirementsResult.requirements.questions:
                if result:
                    result += ",\n"
                result += '\t"%s" : "<ANSWER>"' % q.userDataId
        return "{\n%s\n}" % result

    def buildComponentExecutionContext(self, c, userData=None):
        '''Builds a component dedicated execution context amplified by the
        user answers in which the component is interested in.

        @param c: A component
        @type c: Component

        @param userData: All customer input, as result of component questions
        @type userData: dict

        @return: Component patch context, which could be given to any patch hook
        @rtype: patch_specs.PatchContext

        @raise ValueError: If a component question is not answered
        '''
        compUserData = {}
        userData = userData if userData is not None else {}
        for q in c.requirementsResult.requirements.questions:
            if q.userDataId not in userData:
                logger.info("Expect following userdata format: \n%s",
                            self._getAnswerFileFormat())
                error = "Question %s of component %s is not answered" % \
                                 (q.userDataId, c.name)
                logger.error(error)
                raise ValueError(error)
            compUserData[q.userDataId] = userData[q.userDataId]

        return PatchContext(c.stageDir, self.locale, compUserData, upgradeType=self.upgradeType)

    def persist(self, persistDir):
        '''Persists the context into the disk for later usages. This happens
        only at discovery phase

        @param persistDir: Directory where the patch context will be persisted to
        @type persistDir: str

        @raise Exception: If persist logic fails.
        '''
        ctxFile = os.path.join(persistDir, self.CONTEXT_FILE)
        fileSerializer = JsonFileSerializer(ctxFile)
        fileSerializer.serialize(self)

    @staticmethod
    def load(persistDir):
        '''Load patch context which has been preliminary persisted by the discovery
        phase

        @param persistDir: Directory where the patch context has been persisted to
        @type persistDir: str

        @return: Loaded patch context instance
        @rtype: PatchPhaseContext

        @raise Exception: If load logic fails.
        '''
        ctxFile = os.path.join(persistDir, PatchPhaseContext.CONTEXT_FILE)
        if not os.path.exists(ctxFile):
            raise ValueError('PatchPhaseContext file cannot be found at %s. '
                             'Did you forget to run discovery phase first?' % ctxFile)

        fileSerializer = JsonFileSerializer(ctxFile)
        ctxObj = fileSerializer.deserialize()

        if not isinstance(ctxObj, PatchPhaseContext):
            raise ValueError("Loaded object has type %s, expected PatchPhaseContext" %
                             type(ctxObj))

        rpmStageDir = None
        stagePathFile = "/storage/core/software-update/stage/stageDir.json"
        if os.path.exists(stagePathFile):
            with open(stagePathFile) as fp:
                rpmStageDir = json.load(fp)['StageDir']
        if rpmStageDir is not None and os.path.exists(rpmStageDir):
            # ctxObj.stageDir is updatedir which don't have rpms downloaded at stage phase.
            # replace ctxObj.stageDir with stagedir where rpms get downloaded at stage phase.
            # The staged directory is a directory. (Named patch_runner) inside rpmStageDir
            persistedStageDir = os.path.abspath(os.path.join(ctxObj.stageDir, os.pardir))
            if persistedStageDir != rpmStageDir:
                logger.info("Replacing updatedir %s with stagedir %s" % (persistedStageDir, rpmStageDir))
                ctxObj.stageDir = ctxObj.stageDir.replace(persistedStageDir, rpmStageDir, 1)
                for component in ctxObj.components:
                    # updating component specific patchscript dir which is child of ctxObj.stageDir.
                    component.patchScript = component.patchScript.replace(persistedStageDir, rpmStageDir, 1)
                    # updating component specific stage dir which is child of ctxObj.stageDir.
                    component.stageDir = component.stageDir.replace(persistedStageDir, rpmStageDir, 1)
        else:
            logger.warning('Cannot find the stage directory.')
        return ctxObj

