# Copyright 2016 VMware, Inc.  All rights reserved. -- VMware Confidential
#
''' Module containing status aggregator functions used in reporting
'''
import logging
import os
import json
from multiprocessing import Queue
from time import sleep
from collections import OrderedDict

from status_reporting_sdk.statusAggregator import StatusAggregator
from status_reporting_sdk.componentStatus import ComponentsExecutionStatusInfo, MessageInfo, \
    ErrorInfo, ProgressData
from l10n import msgMetadata as _T, localizedString as _
from reporting.reporting_factory import MessageType

logger = logging.getLogger(__name__)

PULLING_INTERVAL = 1  # 1 seconds pulling interval

B2B_RESUME_MESSAGE = _(_T('patch.internal.resume.progress',
                          'In-place upgrade is being resumed.'))

def _createInternalDataStorageForResume(reportingQueue,
                                        statusFile,
                                        completedReporters):
    ''' Creates an InternalDataStorage that has the information from the
    previous run already populated

    @param reportingQueue: A queue that is used to dispatch progress,
            message question messages from the message dispatchers.
    @type reportingQueue: queue

    @param statusFile: File where the status is to be dumped. If there is already
        status there it will load it
    @type statusFile: str

    @param completedReporters: Names of already completed reporters in order
        to be able to resume form the same point.
    @type completedReporters: list(str)

    @rtype: _InternalDataStorage
    '''
    messages = []
    if os.path.exists(statusFile):
        logger.info('Loading previous messages from status file %s', statusFile)
        with open(statusFile) as f:
            data = json.load(f)
            for m in data.get('info', []):
                messages.append((m, MessageInfo.Severity.INFO))
            for m in data.get('warning', []):
                messages.append((m, MessageInfo.Severity.WARNING))

    logger.info('Pre-populated with progress data in order to start '
                'from the same point as before.')
    progressData = OrderedDict()
    doneProgress = ProgressData(ProgressData.State.SUCCESS,
                                percentage=100,
                                progress_message=B2B_RESUME_MESSAGE)
    for name in completedReporters:
        progressData[name] = doneProgress

    return _InternalDataStorage(reportingQueue,
                               oldProgressData=progressData,
                               oldMessages=messages)

def getStatusAggregator(trackedReportingProducers, statusFile,
                        reportingQueue=None,
                        completedReporters=None):
    ''' Returns a StatusAggregator who aggregates the different reporting messages
        received on the queue and reports them in the file provided. If completedReporters
        is provided it starts progress reporting from not 0 and also tries to reload
        old messages from the status file

    @param trackedReportingProducers: Number of tracked reporters
    @type trackedReportingProducers: int

    @param statusFile: File where the status is to be dumped
    @type statusFile: str

    @param reportingQueue: A queue that is used to dispatch progress,
            message question messages from the message dispatchers.
    @type reportingQueue: queue

    @param completedReporters: Names of already completed reporters in order
        to be able to resume form the same point.
    @type completedReporters: list(str)

    @rtype: StatusAggregator
    '''
    logger.debug('Creating StatusAggregator')
    reportingQueue = reportingQueue or Queue()
    progressAggregator = ProgressAggregator(trackedReportingProducers)

    if completedReporters:
        logger.debug('Pre-populating internal storage as this is resume.')
        internalDataStorage = _createInternalDataStorageForResume(reportingQueue,
                                                                  statusFile,
                                                                  completedReporters)
    else:
        internalDataStorage = _InternalDataStorage(reportingQueue)

    internalStatusReader = InternalStatusReader(internalDataStorage)

    class _StatusAggregator(StatusAggregator):
        def __init__(self):
            self.reportingQueue = reportingQueue
            super(_StatusAggregator, self).__init__(PULLING_INTERVAL, statusFile,
                                   progressAggregator,
                                   inputFun=internalStatusReader,
                                   replyHandler=None)
        def stop(self):
            # Allow for any data that is still being written to the queue to be flush
            sleep(PULLING_INTERVAL)
            super(_StatusAggregator, self).stop()

    return _StatusAggregator()

class _QuestionInfo(object):
    ''' Internal structure keeping together the question and sender'''
    def __init__(self, question, sender):
        self.question = question
        self.sender = sender

class _InternalDataStorage(object):
    '''This is a class that provides in-memory caching capabilities and reading
    from reporting queues.
    '''
    def __init__(self, reportingQueue, oldProgressData=None, oldMessages=None):
        '''
        @param reportingQueue: A queue that is used to dispatch progress,
            message question messages from the message dispatchers.
        @type reportingQueue: queue

        @param oldProgressData: Pre-populated dict of oldProgressData that is
            used to initialize this internal storage in order to start reporting
            from different point then 0
        @type oldProgressData: OrderedDict

        @param oldMessages: Pre-populated messages that are used to initialize
            the internal storage in order to have messsages from the start.
        @type oldMessages: list of Messages
        '''
        self.reportingQueue = reportingQueue
        self.progressData = OrderedDict()
        self.questionMsg = []
        self.messages = []

        if oldProgressData:
            self.progressData.update(oldProgressData)

        if oldMessages:
            self.messages.extend(oldMessages)

    def _progressReceiver(self, message):
        ''' Process message of type progress
        '''
        oldProgressData = self.progressData.get(message['metadata']['identifier'])
        currentProgressData = message['payload']

        # If there is no new progress message reuse old one
        if not currentProgressData.progress_message and oldProgressData:
            currentProgressData.progress_message = oldProgressData.progress_message

        # If error is reporter preserve old percentage as error progress data
        # comes with progress of 0
        if currentProgressData.status == ProgressData.State.ERROR and oldProgressData:
            currentProgressData.percentage = oldProgressData.percentage

        self.progressData[message['metadata']['identifier']] = currentProgressData

    def _questionsReceiver(self, message):
        ''' Process message of type question
        '''
        self.questionMsg.append(_QuestionInfo(message['payload'], message['metadata']['identifier']))

    def _messagesReceiver(self, message):
        ''' Process message of type message
        '''
        self.messages.append(message['payload'])

    def _unknownReceiver(self, message):
        ''' Process message of any unknown type
        '''
        logger.warning('Got unknown message. %s', message)

    def _configReceiver(self, message):
        ''' Process message of type configuration
        '''
        # Only framework allowed to change configuration
        if message['metadata']['identifier'] == 'B2B-patching':
            if message['payload'] == 'RESET':  # Allows to reset data externally
                self.progressData = {}
                self.questionMsg = []
                self.messages = []

    # This provides switch functionality
    messageDispatcher = {
            MessageType.UNKNOWN : _unknownReceiver,
            MessageType.MESSAGE_REPORTING : _messagesReceiver,
            MessageType.QUESTION_REPORTING : _questionsReceiver,
            MessageType.PROGRESS_REPORTING : _progressReceiver,
            MessageType.CONFIGURATION_CHANGE : _configReceiver
    }

    def synchronize(self):
        ''' This method synchronize the data with the external data producers
            and stores the received data in memory for later use
        '''
        while not self.reportingQueue.empty():
            msg = self.reportingQueue.get_nowait()
            # Switching message process based on message type, handles new types
            unknownReceiver = self.messageDispatcher[MessageType.UNKNOWN]
            receiver = self.messageDispatcher.get(msg['metadata']['type'], \
                                                    unknownReceiver)
            receiver(self, msg)

    def getNextQuestion(self):
        ''' Returns the next cached question if any otherwise None.
            This method is not thread safe
        @rtype: _QuestionInfo
        '''
        if self.questionMsg:
            return self.questionMsg.pop()
        return None

class InternalStatusReader(object):
    ''' This class is responsible to aggregate reported messages so far
    '''
    def __init__(self, internalDataStorage):
        self.currentQuestion = None
        self.internalDataStorage = internalDataStorage

    def _updateQuestion(self):
        ''' Updates the internal status with next question asked
        '''
        if not self.currentQuestion:
            # Only single question can be processed at a given time so till the
            # question is processing any other question won't be processed.
            # Once the current question is answered then the next question can
            # start processing.
            self.currentQuestion = self.internalDataStorage.getNextQuestion()

    def _generateErrorInfo(self, messages):
        ''' Generates ErrorInfo from the error messages
        '''
        errorInfo = None
        for error in [e[0] for e in messages\
                      if e[1] == MessageInfo.Severity.ERROR]:
            if not errorInfo:
                errorInfo = ErrorInfo(error.detail,
                                      componentKey=error.componentKey,
                                      resolution=error.resolution,
                                      problemId=error.problemId)
            else:
                for cause in error.detail:
                    if cause not in errorInfo.detail:
                        errorInfo.appendErrorDetail(cause)
        return errorInfo

    def _updateMessages(self, status):
        ''' Updates the internal status with all messages posted so far
        '''
        messages = self.internalDataStorage.messages

        # Aggregate Info-messages and Warnings
        for (message, sevirity) in messages:
            if sevirity == MessageInfo.Severity.INFO:
                status.appendInfo(message)
            elif sevirity == MessageInfo.Severity.WARNING:
                status.appendWarning(message)

        # Aggregate Errors
        if status.error:
            return

        errorInfo = self._generateErrorInfo(messages)
        if errorInfo:
            # Set the error with aggregated error's detail
            # This always will be the first error in the queue
            status.setError(errorInfo)

    def _updateProgress(self, status):
        ''' Updates the internal status with latest progress data
        '''
        progresses = self.internalDataStorage.progressData
        for component in progresses:
            if progresses[component].status == ProgressData.State.ERROR:
                status.setError(progresses[component].progress_message)
            status.updateComponentProgress(component, progresses[component])

    def __call__(self):

        executionStatus = ComponentsExecutionStatusInfo()

        self.internalDataStorage.synchronize()

        self._updateMessages(executionStatus)
        self._updateProgress(executionStatus)
        self._updateQuestion()

        # Populate question in status until it is not answered
        if self.currentQuestion:
            # common-sdk requires to set None for the question once it has been
            # answered in order to remove reply file. If not pass None we are
            # going to leave the reply.json forever.
            executionStatus.setQuestion(self.currentQuestion.question)

            # Once the question is answered allow to process the next question
            if not self.currentQuestion.question:
                self.currentQuestion = None

        # Return snapshot of the upgrade process status
        return executionStatus

class ProgressAggregator(object):
    def __init__(self, progressProducers):
        ''' Create an instance of ProgressAggregator class responsible for
        aggregating progress data from all components
        @param progressProducers: how many progress producers are tracked by
        this aggregator
        @type progressProducers: int
        '''
        self.allSteps = progressProducers

    def __call__(self, internalStatus):
        ''' Calculates the aggregated progress based on internal status.
        '''
        # Sum of all reported states in percentage
        currentProgress = 0
        # Number of succeeded states
        succeededStates = 0

        overallState = ProgressData.State.RUNNING
        for progressData in internalStatus.allProgress.values():
            # Calculate percentage
            currentProgress += progressData.percentage
            # Calculate state
            if progressData.status == ProgressData.State.ERROR:
                overallState = ProgressData.State.ERROR
            elif progressData.status == ProgressData.State.SUCCESS:
                succeededStates += 1

        # Check for success
        if overallState == ProgressData.State.RUNNING and\
                succeededStates >= self.allSteps:
            # All are succeeded
            overallState = ProgressData.State.SUCCESS

        # Aggregate
        totalPercentage = 100 * self.allSteps
        overallProgress = int((1.0 * currentProgress / totalPercentage) * 100)
        # Guarantee that overall progress will be in range [0, 100]
        if overallProgress < 0:
            overallProgress = 0
        if overallProgress > 100:
            overallProgress = 100
        logger.debug('===============================')
        logger.debug('Overall Progress: %s%%', overallProgress)
        logger.debug('Executed steps %s out of %s', currentProgress, totalPercentage)
        logger.debug('Operation: %s', internalStatus.lastProgressMessage)
        logger.debug('===============================')

        # Return aggregated progress-data
        return ProgressData(percentage=overallProgress,
                            status=overallState,
                            progress_message=internalStatus.lastProgressMessage)
