# Copyright (c) 2017-2023 VMware, Inc.  All rights reserved.
# All rights reserved. -- VMware Confidential
"""
Rsyslog patch module for patching syslog conf files
"""
import logging
import re
from fss_utils import getTargetFSS
from vcsa_utils import isDisruptiveUpgrade
from l10n import msgMetadata as _T, localizedString as _

logger = logging.getLogger(__name__)

def getChanges():
    return _(_T("rsyslog.patches.syslog.summary",
                "Patch rsyslog configurations."))

def setDefaultTls(ctx):
    logger.info("Setting the rsyslog TLS to TLS1.2")
    SYSLOG_CONF = '/etc/rsyslog.conf'

    tls_12 = ['+VERS-TLS1.2']
    tls_10_tls_11 = ['+VERS-TLS1.0', '+VERS-TLS1.1',
                   '-VERS-TLS1.0', '-VERS-TLS1.1']
    tls11_cipher = ['+AES-256-CBC', '+AES-128-CBC', '+SHA1']
    system_none = ['NONE']

    with open(SYSLOG_CONF, 'r') as f:
        data = f.readlines()

    #Setting the TLS1.2 for rsyslog,
    for index,line in enumerate(data):
        if line.strip().startswith('gnutlsprioritystring'):
            split_list = line.split("=")
            cipher_and_tls = split_list[1].strip()[1:-1].split(":")
            protocols = list(set(cipher_and_tls) & set(tls_12))
            all_ciphers = list(set(cipher_and_tls) - set(protocols)
                               - set(tls_10_tls_11) - set(tls11_cipher)
                               - set(system_none))
            default = system_none[0] + ":" +  ':'.join(all_ciphers + protocols)
            line = split_list[0] + "=\"" + default + "\"\n"
            data[index] = line

    if not getTargetFSS("NDU_Limited_Downtime") or isDisruptiveUpgrade(ctx):
        with open(SYSLOG_CONF, 'w') as f:
            f.writelines(data)
    elif getTargetFSS("NDU_Limited_Downtime"):
        logger.info("Writing the rsyslog config with default TLSv1.2 "
                    "for replication in RDU.")
        RSYSLOG_CONF_REPLICATE = '/etc/rsyslog.conf.replicate'
        with open(RSYSLOG_CONF_REPLICATE, 'w') as f:
            f.writelines(data)

def doPatching(ctx):
    logger.info("Rsyslog config files patching being executed {0}".format(ctx))
    patchSyslogConf(ctx)
    patchRsyslogConf(ctx)

def patchSyslogConf(ctx):
    LOG_LEVEL_REGEX = ('\*.info|\*.notice|\*.warn|\*.error|\*.crit|\*.alert|'
                       '\*.emerg')

    SYSLOG_RELAY_CONF = '/etc/vmware-syslog/syslog.conf'
    SYSLOG_RELAY_CONF_ORIG = '/etc/vmware-syslog/syslog.conf.orig'

    with open(SYSLOG_RELAY_CONF, 'r') as f:
        output = f.read()

    with open(SYSLOG_RELAY_CONF_ORIG, 'w') as f:
        f.write(output)

    # if the ouptut endswith `& stop` already remove it
    output = '\n'.join(list(filter(lambda x: x!= '', output.split('\n'))))
    if output.endswith('& stop'):
        output = '\n'.join(output.split('\n')[:-1])

    output = re.sub(LOG_LEVEL_REGEX, '*.*', output)

    AUTH_PRIV_REGEX = ('authpriv.*\n')
    output = re.sub(AUTH_PRIV_REGEX, '', output)

    log_level_pattern = re.compile(r'\*\.\*')
    pattern = re.compile(r'\$IncludeConfig\ \/etc\/vmware\-syslog\/vmware\-services\-\*\.conf')

    if log_level_pattern.search(output) and (not pattern.search(output)):
        logger.info("Rsyslog imfile configuration will be enabled as remote log server is set")
        replace = '$IncludeConfig /etc/vmware-syslog/vmware-services-*.conf\n*.*'
        LOG_LEVEL_REGEX = ('\*\.\*')
        output = re.sub(LOG_LEVEL_REGEX, replace, output, 1)

    if not getTargetFSS("NDU_Limited_Downtime") or isDisruptiveUpgrade(ctx):
        with open(SYSLOG_RELAY_CONF, 'w') as f:
            f.write(output)
    elif getTargetFSS("NDU_Limited_Downtime"):
        SYSLOG_CONF_REPLICATE = '/etc/vmware-syslog/syslog.conf.replicate'
        with open(SYSLOG_CONF_REPLICATE, 'w') as f:
            f.write(output)

def patchRsyslogConf(ctx):
    SYSLOG_CONF = '/etc/rsyslog.conf'
    SYSLOG_CONF_ORIG = '/etc/rsyslog.conf.orig'

    with open(SYSLOG_CONF, 'r') as f:
        output = f.read()

    with open(SYSLOG_CONF_ORIG, 'w') as f:
        f.write(output)

    MODLOAD_REGEX = '\$ModLoad'
    pattern = re.compile(r'load="imfile"')
    if not pattern.search(output) :
        logger.info("Updating rsyslog.conf to enable imfile mode.")
        replace = ('module(load="imfile" mode="inotify")\n'
                   '$WorkDirectory /var/log/vmware/rsyslogd\n$ModLoad')
        output = re.sub(MODLOAD_REGEX, replace, output, 1)

    auth_priv_pattern = re.compile(r'authpriv\.\*')
    if not auth_priv_pattern.search(output):
        logger.info("Updating rsyslog.conf with authpriv config.")
        authpriv_replace = ('$IncludeConfig /etc/vmware-syslog/syslog.conf\n\n'
                            '$FileCreateMode 0640\n'
                            'authpriv.*   /var/log/audit/sshinfo.log\n'
                            '$FileCreateMode 0644\n')
        include_regex = '\$IncludeConfig\ \/etc\/vmware\-syslog\/syslog\.conf'
        output = re.sub(include_regex, authpriv_replace, output)

    fileCreateModeMatch = re.search(r'\$FileCreateMode [0-9]+', output)
    if not fileCreateModeMatch:
        logger.info("Updating rsyslog.conf with filecreatemode config.")
        filecreatemode_replace = ('$FileCreateMode 0640\n'
                                  'authpriv.*   /var/log/audit/sshinfo.log\n'
                                  '$FileCreateMode 0644\n')
        include_regex = 'authpriv.*   /var/log/audit/sshinfo.log'
        output = re.sub(include_regex, filecreatemode_replace, output)

    # Changing the Max Message size from default 8k to 64k to avoid
    # log truncation.
    WORKDIR_REGEX = '\$WorkDirectory'
    MaxMsgSizePattern = re.compile(r'MaxMessageSize')
    MaxMsgSize64Pattern = re.compile(r'MaxMessageSize 64k')
    MaxMsgSize128Pattern = re.compile(r'MaxMessageSize 128k')
    if not MaxMsgSizePattern.search(output) :
        logger.info("Setting MaxMessageSize in rsyslog.conf to 63k")
        replace = ('$MaxMessageSize 63k\n$WorkDirectory')
        output = re.sub(WORKDIR_REGEX, replace, output, 1)
    elif MaxMsgSize64Pattern.search(output):
        logger.info("Updating rsyslog.conf to set MaxMessageSize to 63k")
        replace = ('MaxMessageSize 63k\n')
        output = re.sub(MaxMsgSize64Pattern, replace, output, 1)
    elif MaxMsgSize128Pattern.search(output):
        logger.info("Updating rsyslog.conf to set MaxMessageSize to 63k")
        replace = ('MaxMessageSize 63k\n')
        output = re.sub(MaxMsgSize128Pattern, replace, output, 1)

    # Syslog TLS config changes to cater latest RPM.

    #Reading the TLS port set by user.
    TlsPortMatch = re.search(r'\$InputTCPServerRun [0-9]+', output)
    if TlsPortMatch:
        TLS_PORT = (TlsPortMatch.group(0).split(' ')[1])
    #Setting the default port for TLS.
    else:
        TLS_PORT = '1514'

    GnuTlsRegEx = re.compile(r'gnutlsprioritystring')
    if not GnuTlsRegEx.search(output) :
        logger.info("updating rsyslog.conf with new TLS configs")
        replace = ('module( load="imtcp"\n'
            '\tstreamdriver.name="gtls"\n'
            '\tstreamdriver.mode="1"\n'
            '\tstreamdriver.authmode="anon"\n'
            '\tgnutlsprioritystring="NONE:+AES-128-GCM:+ECDHE-RSA:+ECDHE-ECDSA:\
+AEAD:+SHA384:+SHA256:+COMP-NULL:+VERS-TLS1.2:+SIGN-RSA-SHA224:+SIGN-RSA-SHA256:\
+SIGN-RSA-SHA384:+SIGN-RSA-SHA512:+SIGN-DSA-SHA224:+SIGN-DSA-SHA256:\
+SIGN-ECDSA-SHA224:+SIGN-ECDSA-SHA256:+SIGN-ECDSA-SHA384:+SIGN-ECDSA-SHA512:\
+CURVE-SECP256R1:+CURVE-SECP384R1:+CURVE-SECP521R1:+CTYPE-OPENPGP:+CTYPE-X509:\
-CAMELLIA-256-CBC:-CAMELLIA-192-CBC:-CAMELLIA-128-CBC:-CAMELLIA-256-GCM:-CAMELLIA-128-GCM"\n'
            '\t)\ninput(type="imtcp" port="'+TLS_PORT+'")\n$ModLoad')
        output = re.sub(MODLOAD_REGEX, replace, output, 1)
    else:
        logger.info("Removing the insecure TLS CBC padding in rsyslog.conf")
        replace = ('gnutlsprioritystring="NONE:+AES-128-GCM:+ECDHE-RSA:+ECDHE-ECDSA:\
+AEAD:+SHA384:+SHA256:+COMP-NULL:+VERS-TLS1.2:+SIGN-RSA-SHA224:+SIGN-RSA-SHA256:\
+SIGN-RSA-SHA384:+SIGN-RSA-SHA512:+SIGN-DSA-SHA224:+SIGN-DSA-SHA256:\
+SIGN-ECDSA-SHA224:+SIGN-ECDSA-SHA256:+SIGN-ECDSA-SHA384:+SIGN-ECDSA-SHA512:\
+CURVE-SECP256R1:+CURVE-SECP384R1:+CURVE-SECP521R1:+CTYPE-OPENPGP:+CTYPE-X509:\
-CAMELLIA-256-CBC:-CAMELLIA-192-CBC:-CAMELLIA-128-CBC:-CAMELLIA-256-GCM:-CAMELLIA-128-GCM"')
        output = re.sub('gnutlsprioritystring.*', replace, output, 1)

    MODLOAD_TLS_REGEX = '\$ModLoad imtcp\.so  \# TLS'
    MODLOAD_TLS_PATTERN = re.compile('\$ModLoad imtcp\.so  \# TLS')
    if MODLOAD_TLS_PATTERN.search(output) :
        output = re.sub(MODLOAD_TLS_REGEX, '', output, 1)

    TLS_STREAM_DRIVER_MODE = '\$InputTCPServerStreamDriverMode 1'
    TlsStreamDriverPattern = re.compile('\$InputTCPServerStreamDriverMode 1')
    if TlsStreamDriverPattern.search(output) :
        output = re.sub(TLS_STREAM_DRIVER_MODE, '', output, 1)

    TLS_STREAM_AUTH = '\$InputTCPServerStreamDriverAuthMode anon'
    TlsStreamAuthPattern = re.compile('\$InputTCPServerStreamDriverAuthMode anon')
    if TlsStreamAuthPattern.search(output) :
        output = re.sub(TLS_STREAM_AUTH, '', output, 1)

    TLS_SERVER_PORT = '\$InputTCPServerRun [0-9]+'
    TlsServerPortPattern = re.compile('\$InputTCPServerRun [0-9]+')
    if TlsServerPortPattern.search(output) :
        output = re.sub(TLS_SERVER_PORT, '', output, 1)

    # PR:[3096783] Enhance log server confguration in vCenter,
    # to avoid duplicate logging.
    IncomingLogSrvPattern = re.compile('\*\.\* \?esxLoc\;esxFmt')
    if IncomingLogSrvPattern.search(output) :
        EnhancedLogSrvConfig = ('if ($hostname != $$myhostname ) then ?esxLoc;esxFmt')
        output = re.sub(IncomingLogSrvPattern, EnhancedLogSrvConfig, output)

    if not getTargetFSS("NDU_Limited_Downtime") or isDisruptiveUpgrade(ctx):
        with open(SYSLOG_CONF, 'w') as f:
            f.write(output)
    elif getTargetFSS("NDU_Limited_Downtime"):
        RSYSLOG_CONF_REPLICATE = '/etc/rsyslog.conf.replicate'
        with open(RSYSLOG_CONF_REPLICATE, 'w') as f:
            f.write(output)
