# Copyright (c) 2017-2022 VMware, Inc.
# All rights reserved. -- VMware Confidential
"""
Integrate Rsyslog service with the B2B patching framework
"""
import os
import sys
import logging
import os_utils
from shutil import copyfile

from extensions import extend, Hook
from patch_specs import DiscoveryResult
from fss_utils import getTargetFSS
from l10n import msgMetadata as _T, localizedString as _
from vcsa_utils import isDisruptiveUpgrade
from reporting import getProgressReporter

_PAYLOAD_DIR = os.path.dirname(os.path.realpath(__file__))
_PATCHES_DIR = os.path.join(_PAYLOAD_DIR, "patches")
sys.path.extend([_PAYLOAD_DIR, _PATCHES_DIR])
import utils

logger = logging.getLogger(__name__)

# List of patch tuples: [(patch_version, patch_module)]
# Thus with the update of any patch_module, patch_version need to be updated.
patches = [
    ("7.0.1.1", "syslog_levels"),
]

RSYSLOG_CONF = '/etc/rsyslog.conf'
RSYSLOG_CONF_REPLICATE = '/etc/rsyslog.conf.replicate'
SYSLOG_CONF = '/etc/vmware-syslog/syslog.conf'
SYSLOG_CONF_REPLICATE = '/etc/vmware-syslog/syslog.conf.replicate'

@extend(Hook.Discovery)
def discover(ctx):
    logger.info("Discovering Rsyslog service")
    replicationConfig = {
        "/etc/rsyslog.conf.replicate": "/etc/rsyslog.conf",
        "/etc/vmware-syslog/syslog.conf.replicate": "/etc/vmware-syslog/syslog.conf",
        "/etc/vmware-syslog/version.txt" : "/etc/vmware-syslog/version.txt",
        "/etc/rsyslog.conf": None,
        "/etc/vmware-syslog/syslog.conf": None
    }
    return DiscoveryResult(displayName=_(_T("common.comp.displayName",
                                            "VMware Rsyslog")),
                           componentId="rsyslog",
                           replicationConfig=replicationConfig)


def _doIncrementalPatching(ctx):
    curVersion = getCurrentVersion()
    patchCurrentVersion = curVersion is None
    patchApplied = False

    for ver, modulePath in patches:
        patchCurrentVersion |= (curVersion != ver)
        mod = __import__(modulePath)

        if patchCurrentVersion:
            logger.info("Applying patch {0} from version {1}".format(
                        modulePath, ver))
            mod.doPatching(ctx)
            patchApplied = True
            logger.info("Patch {0} applied".format(modulePath))

        logger.info("Setting TLS1.2 for rsyslog")
        mod.setDefaultTls(ctx)
        patchApplied = True

    # creating replicate file for RDU when the patch path is not done.
    if not isDisruptiveUpgrade(ctx):
        logger.info("Creating replpicate files for syslog.")
        if not os.path.isfile(RSYSLOG_CONF_REPLICATE):
            copyfile(RSYSLOG_CONF, RSYSLOG_CONF_REPLICATE)
        if not os.path.isfile(SYSLOG_CONF_REPLICATE):
            copyfile(SYSLOG_CONF, SYSLOG_CONF_REPLICATE)
        logger.info("Successfully replpicated files for syslog.")
    utils.setSourceVersion(getLatestVersion())
    return patchApplied


def getCurrentVersion():
    return utils.getSourceVersion()


def getLatestVersion():
    return patches[-1][0]


def restartRsyslogService():
    restart_cmd = ['systemctl', 'restart', 'rsyslog']
    out, err, rc = os_utils.executeCommand(restart_cmd)
    if rc != 0:
        logger.error("Failed to restart Rsyslog after applying the patches."
                     " rc = %d, stdout: %s, stderr: %s" % (rc, out, err))
        raise ValueError("Failed during Rsyslog patching.")

@extend(Hook.Expand)
def expand(ctx):
    ''' Expand hook for rsyslog '''
    if (getTargetFSS("NDU_Limited_Downtime") and
        (not isDisruptiveUpgrade(ctx)) ):
        logger.info("Starting the expand hook for Rsyslog Service")
        progressReporter = getProgressReporter()
        progressReporter.updateProgress(0, _(_T("rsyslog.expand.begin",
                                                'Start rsyslog expand')))
        patchApplied = _doIncrementalPatching(ctx)
        if patchApplied != True:
            raise ValueError("Failed during Rsyslog patching.")
        logger.info("Expand hook completed for Rsyslog")
        progressReporter.updateProgress(100, _(_T("rsyslog.expand.complete",
                                                'Completed rsyslog expand')))
    else:
        logger.info("FSS is off or this is a disruptive upgrade, "
                    "all work is done in Patch hook")

@extend(Hook.Contract)
def contract(ctx):
    ''' Contract hook for rsyslog '''
    # Restarting rsyslog service after the config replication,
    # to reflect the changes for syslog.
    if (getTargetFSS("NDU_Limited_Downtime") and
        (not isDisruptiveUpgrade(ctx)) ):
        restartRsyslogService()

@extend(Hook.Revert)
def revert(ctx):
    ''' Revert hook for rsyslog '''
    if getTargetFSS("NDU_Limited_Downtime"):
        logger.info("Starting the revert hook for Rsyslog.")
        rsyslog_replicate_file = "/etc/rsyslog.conf.replicate"
        syslog_replicate_file = '/etc/vmware-syslog/syslog.conf.replicate'
        if os.path.isfile(syslog_replicate_file):
            os.remove(syslog_replicate_file)
        if os.path.isfile(rsyslog_replicate_file):
            os.remove(rsyslog_replicate_file)
        logger.info("Reverted the syslog patch.")
    else:
        logger.info("FSS is off, skipping revert hook")

@extend(Hook.Patch)
def patch(ctx):
    logger.info("Patching Rsyslog Service")
    progressReporter = getProgressReporter()
    progressReporter.updateProgress(0)
    if not getTargetFSS("NDU_Limited_Downtime") or isDisruptiveUpgrade(ctx):
        # In-place patch hook execution.
        needRestart = _doIncrementalPatching(ctx)
        if needRestart:
            progressReporter.updateProgress(50)
            restartRsyslogService()
    progressReporter.updateProgress(100)
