# Copyright (c) 2017-2020 VMware, Inc.  All rights reserved.
# -- VMware Confidential

"""VCDB VMODL In-Place Upgrade (B2B) Patch Sequencer

This module defines a mechanism to apply the VCDB VMODL patches in their
correct sequence. Take a look at the unit tests in bora/vim/test/vcdbVmodlB2B.
"""

import enum
import xml.parsers.expat
from translator import AnyTranslator


class _Logger(object):
    """Null logger to use when we don't need logging"""

    def noop(self, *args, **kwargs):
        """A placeholder for every method in this class"""
        pass

    def __getattribute__(self, name):
        """Hook 'noop' for 'info', 'warning', etc"""
        return object.__getattribute__(self, 'noop')


class VersionParser(object):
    """A parser to parse only the version information of a VMODL XML

    The relevant version information is in the 'versionId' and 'xmlns' tags.

    Note that this discovery mechanism may be rendered unnecessary by future
    changes.
    """

    def __init__(self):
        """Init the underlying expat parser"""

        self.versionId = None
        self.xmlns = None
        self.parser = xml.parsers.expat.ParserCreate()
        self.parser.StartElementHandler = self.Start

    class StopParser(Exception):
        """An exception to stop the parser with"""
        pass

    def Start(self, _, attrs):
        """This parser is interested only in the root attributes."""

        # The second argument "name" is unused.
        self.versionId = attrs['versionId']
        self.xmlns = attrs['xmlns']
        raise self.StopParser()

    def Parse(self, text):
        """Main parser API"""

        try:
            self.parser.Parse(text)
        except self.StopParser:
            pass


def GetLtsSet(pyVmomi):
    if hasattr(pyVmomi.VmomiSupport, 'ltsVersions'):
        return pyVmomi.VmomiSupport.ltsVersions
    else:
        return pyVmomi.VmomiSupport.publicVersions


def GetVimVersion(pyVmomi):
    return pyVmomi.VmomiSupport.newestVersions.GetWireId('vim').split('/')[1]


def GetLtsVimVersion(pyVmomi):
    return GetLtsSet(pyVmomi).GetWireId('vim').split('/')[1]


def IsUnstable(wireId):
    return wireId.startswith("u")


def Deserialize(pyVmomi, text):
    try:
        return pyVmomi.SoapAdapter.Deserialize(text)
    except xml.parsers.expat.ExpatError as origExcept:
        raise Exception('Invalid XML (%s): %s' % (origExcept, text))


def ConvertToLts(pyVmomi, text, xmlNs):
    version = GetLtsSet(pyVmomi).GetNameW(xmlNs)
    versionId = pyVmomi.VmomiSupport.versionIdMap[version]

    value = Deserialize(pyVmomi, text)

    # Note that the returned value by Serialize is a b'binary string'.
    text = pyVmomi.SoapAdapter.Serialize(
        value, version=version).decode("utf-8")

    return (text, versionId)


class BranchType(enum.Enum):
    main = 0
    productBranch = 1


def GetBranchType(pyVmomi):
    return BranchType.main if IsUnstable(GetVimVersion(pyVmomi)) \
        else BranchType.productBranch


class PatchSequencer(object):
    """Provides API for translating a VMODL XML

    This class takes into account the breaking changes between the source and
    the desitnation pyVmomi and applies the appropriate patches in the correct
    order.
    """

    def __init__(self, ImportSrcPyVmomi, ImportDstPyVmomi, patches=None,
                 logger=None, checkDstPyVmomiFn=lambda x: None):
        """Initializes both pyVmomis and orders the breaking change patches

        For each breaking change a patched copy of the source pyVmomi is made
        to handle this particular breaking change. Also, a translator is made
        to translate to this particular pyVmomi from the previous one.
        """

        self.logger = logger if logger else _Logger()
        self.patches = []  # [Patch]
        self.srcPyVmomi = ImportSrcPyVmomi()
        self.dstPyVmomi = ImportDstPyVmomi()

        # Check if the new pyVmomi is OK.
        # This is the current pyVmomi integration test.
        checkDstPyVmomiFn(self.dstPyVmomi)

        # patches - {'branch': {'vmodlNamespace': [Patch]}}
        patches = patches if patches else {}

        # breakingChanges - {'branch': {'vmodlNamespace': count}}
        try:
            breakingChanges = self.srcPyVmomi.VmomiSupport.GetBreakingChanges()
        except AttributeError:  # GetBreakingChanges might be missing.
            breakingChanges = {}

        for branch, patchesPerBranch in patches.items():
            for namespace, patchesPerNamespace in patchesPerBranch.items():
                bcCount = breakingChanges.get(branch, {}).get(namespace, 0)
                for patchesPerSequence in patchesPerNamespace[bcCount:]:
                    self.patches.extend(patchesPerSequence)

        numberOfPatches = len(self.patches)

        pyVmomi = [self.srcPyVmomi]
        pyVmomi.extend(ImportSrcPyVmomi() for _ in range(0, numberOfPatches))

        for i in range(1, numberOfPatches + 1):
            for j in range(0, i):
                self.patches[j](pyVmomi[i])

        self.translators = [patch.Translator(pyVmomi[i], pyVmomi[i + 1])
                            for i, patch in enumerate(self.patches)]
        self.translators.append(AnyTranslator(pyVmomi[-1], self.dstPyVmomi))
        self.logger.info("Newest source vim version: %s",
                         GetVimVersion(self.srcPyVmomi))
        self.sourceBranchType = GetBranchType(self.srcPyVmomi)
        sourceLtsVersion = GetLtsVimVersion(self.srcPyVmomi)
        self.isSourceM8 = sourceLtsVersion == "6.9.1"
        self.logger.info("Upgrading from LTS %s to LTS %s", sourceLtsVersion,
                         GetLtsVimVersion(self.dstPyVmomi))
        self.logger.info(
            "Patches:{0} Breaking changes:{1}".format(
                len(self.patches), self.HasBreakingChanges()))

        srcNewestVersions = self.srcPyVmomi.VmomiSupport.newestVersions
        self.logger.info("Source      Unstable Versions:{0}".format(
            sorted(srcNewestVersions.EnumerateWireIds())))
        dstNewestVersions = self.dstPyVmomi.VmomiSupport.newestVersions
        self.logger.info("Destination Unstable Versions:{0}".format(
            sorted(dstNewestVersions.EnumerateWireIds())))

        srcLtsVersions = GetLtsSet(self.srcPyVmomi)
        self.logger.info("Source      LTS Versions:{0}".format(
            sorted(srcLtsVersions.EnumerateWireIds())))
        dstLtsVersions = GetLtsSet(self.dstPyVmomi)
        self.logger.info("Destination LTS Versions:{0}".format(
            sorted(dstLtsVersions.EnumerateWireIds())))

    def IsUpgradeNeeded(self):
        srcNewestVersions = self.srcPyVmomi.VmomiSupport.newestVersions
        dstNewestVersions = self.dstPyVmomi.VmomiSupport.newestVersions
        srcLtsVersions = GetLtsSet(self.srcPyVmomi)
        dstLtsVersions = GetLtsSet(self.dstPyVmomi)

        return self.HasBreakingChanges()\
            or \
            sorted(srcNewestVersions.EnumerateWireIds()) != \
            sorted(dstNewestVersions.EnumerateWireIds()) \
            or \
            sorted(srcLtsVersions.EnumerateWireIds()) != \
            sorted(dstLtsVersions.EnumerateWireIds())

    def GetSourceBranchType(self):
        return self.sourceBranchType

    def HasBreakingChanges(self):
        return self.isSourceM8 or self.patches

    def NoBreakingChanges(self):
        return not self.HasBreakingChanges()

    def NormalizeXML(self, text):
        """We have invalid XMLs in VCDB. See Bug 1784639."""

        _XSINamespace = 'xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"'
        if text.find(_XSINamespace) != -1:
            return text

        self.logger.debug("Error -> adding missing XSI namespace.")

        result = text.split(' ')
        result.insert(1, _XSINamespace)
        return ' '.join(result)

    def TranslateText(self, text):
        """Translates the text applying the breaking change patches"""

        text = self.NormalizeXML(text)

        parser = VersionParser()
        parser.Parse(text)

        srcVersionId = parser.versionId

        srcXmlNs = parser.xmlns.split(":", 1)[-1]  # Drop "urn:".

        # srcPyVmomi unstable -> srcPyVmomi LTS
        # TODO: Remove 'srcVersionId == "x"' after we guarantee that VCDB is
        #       valid prior starting B2B.
        if IsUnstable(srcVersionId) or (
                self.GetSourceBranchType() == BranchType.main and
                srcVersionId == "x"):
            text, srcVersionId = ConvertToLts(self.srcPyVmomi, text, srcXmlNs)
        # else:  # No need for trimming of LTS APIs.

        # Convert to destination LTS
        dstVmomiSupport = self.dstPyVmomi.VmomiSupport
        dstVersion = dstVmomiSupport.ltsVersions.GetNameW(srcXmlNs)
        dstVersionId = dstVmomiSupport.versionIdMap[dstVersion]

        # Do a simple string replace if the translations will be trivial.
        if self.NoBreakingChanges():
            # Note that this replace doesn't work on binary strings.
            return text.replace('versionId="{0}"'.format(srcVersionId),
                                'versionId="{0}"'.format(dstVersionId), 1)

        # At least the destination translator is in the list.
        assert len(self.translators) >= 1

        value = Deserialize(self.srcPyVmomi, text)

        # srcPyVmomi stable -> midPyVmomi stable -> dstPyVmomi stable
        for translator in self.translators:
            value = translator.TranslateAny(value)

        # Note that at least one translation was done.
        assert srcVersionId != dstVersionId

        dstSerialize = self.dstPyVmomi.SoapAdapter.Serialize

        # Note that Serialize returns a b'binary string' and *not* utf-8.
        return dstSerialize(value, version=dstVersion).decode("utf-8")
