# Copyright (c) 2018-2021 VMware, Inc. All rights reserved.
# VMware Confidential

import os
import sys
import logging
import shutil
from pathlib import Path

import defusedxml.cElementTree as et
from extensions import extend, Hook
from patch_specs import DiscoveryResult
from l10n import msgMetadata as _T, localizedString as _
from vcsa_utils import isDisruptiveUpgrade

_PAYLOAD_DIR = os.path.dirname(os.path.realpath(__file__))
_PATCHES_DIR = os.path.join(_PAYLOAD_DIR, "patches")
sys.path.extend([_PAYLOAD_DIR, _PATCHES_DIR])

logger = logging.getLogger(__name__)

class MergeXML(object):
    def __init__(self, filenames):
        if (len(filenames) < 0):
            logger.error("No XML files to merge!")
            return
        # save all the roots, in order
        try:
           self.roots = [et.parse(f).getroot() for f in filenames]
        except Exception as e:
            logger.error("Exception while parsing XML file " + str(e))

    def combine(self):
        for r in self.roots[1:]:
            # combine each element with the first one, and update that
            self.combine_element(self.roots[0], r)
        return et.tostring(self.roots[0])

    def combine_element(self, one, other):
        """
        Updates the text or the children of an element if another
        element is found in `one`, or adds it from `other` if not found.
        """

        mapping = {}
        for e1 in one:
            mapping[e1.tag] = e1;
        for el in other:
            if len(el) == 0:
                try:
                    # Update the text
                    mapping[el.tag].text = el.text
                except KeyError:
                    # An element with this name is not in the mapping
                    mapping[el.tag] = el
                    one.append(el)
            else:
                try:
                    # Recursively process the element, and update it in the same way
                    self.combine_element(mapping[el.tag], el)
                except KeyError:
                    # Not in the mapping
                    mapping[el.tag] = el
                    one.append(el)



@extend(Hook.Discovery)
def discover(ctx):
    logger.info("Discovering statsmonitor service")

    replicationConfig = {"/etc/vmware/statsmonitor/statsMonitor.xml": None}
    return DiscoveryResult(displayName=_(_T("common.displayName",
                                            "vmware-statsmonitor")),
                           componentId="vmware-statsmonitor",
                           replicationConfig=replicationConfig)

@extend(Hook.Patch)
def merge(ctx):
    if isDisruptiveUpgrade(ctx):
        new_xml_file = "/etc/vmware/statsmonitor/statsMonitor.xml.rpmnew"
        if os.path.exists(new_xml_file):
            old_file = "/etc/vmware/statsmonitor/statsMonitor.xml"
            logger.info("Executing statsMonitor post-patch script.")
            r = MergeXML((new_xml_file, old_file)).combine()
            with open("/tmp/merge.xml", "wb") as f:
                f.write(r)
            shutil.copyfile("/tmp/merge.xml", old_file)
            os.system("/usr/bin/service-control --restart vmware-statsmonitor")
            logger.info("Successfully executed statsMonitor post-patch script.")
