# Copyright 2015-2023 VMware, Inc.
# All rights reserved. -- VMware Confidential

import logging
import re
import platform
import time
import os_utils
import os
import sys
import subprocess
import vcsa_utils
from pwd import getpwnam

log = logging.getLogger(__name__)

PY2 = sys.version_info[0] == 2
if not PY2:
    unicode = str

class PatchUtils:

    versionRegName = 'Version'

    @staticmethod
    def _toUnicode(noneUniCodeStr):
        '''Method to encode string to unicode. If it cannot achieve that it will
        return the same string

        @param noneUniCodeStr: String to encode to unicode
        @type str
        '''
        if noneUniCodeStr is None or \
            (PY2 and not isinstance(noneUniCodeStr, str)) or \
            (not PY2 and isinstance(noneUniCodeStr, str)):
            return noneUniCodeStr

        # Try with file system encoding, if it is not that, try with couple more
        # and after that give up
        try:
            return noneUniCodeStr.decode(sys.getfilesystemencoding())
        except (UnicodeDecodeError, LookupError):
            pass

        # valid ascii is valid utf-8 so no point of trying ascii
        try:
            return noneUniCodeStr.decode(DEF_ENCODING)
        except (UnicodeDecodeError, LookupError):
            pass

        # cannot decode, return same string and leave the system to fail
        log_error('Tried to decode a string to unicode but it wasn\'t successful.'
                  'Expecting system failures')
        return noneUniCodeStr

    @staticmethod
    def _toFileSystemEncoding(obj):
        '''Method to encode unicode to file system encoding. If it cannot achieve
        that it will return the same string

        @param obj: Unicode object to encode using file system encoding
        @type unicode
        '''
        # This is wrapping unicode in file system encoding so that Popen can read it
        if obj is not None and isinstance(obj, unicode):
            obj = obj.encode(sys.getfilesystemencoding())
        return obj

    def run_command(self, cmd, stdin=None):
        ''' Method to run commands
        '''
        process = subprocess.Popen(cmd,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   stdin=subprocess.PIPE)
        if not PY2:
            # Need to be bytes the stdin
            stdin = PatchUtils._toFileSystemEncoding(stdin)

        stdout, stderr = process.communicate(stdin)
        rc = process.wait()

        if not PY2:
            # Return type is bytes for Python 3+ thus need to cast to str
            stdout = PatchUtils._toUnicode(stdout)
            stderr = PatchUtils._toUnicode(stderr)

        if rc == 0:
            stdout = stdout.rstrip()

        if stderr is not None:
            stderr = stderr.rstrip()
        return rc, stdout, stderr

    # Version should be in the form of 'w.x.y.z'
    def normalize_version(self, version):
        if version == None:
            return[0,0,0,0]

        subversions = [int(v) for v in version.split(".")]
        while len(subversions) > 4:
            subversions.pop()
        while len(subversions) < 4:
            subversions.append(0)
        return subversions

    def versionCompare(self, verA, verB):
        if verA == None:
            if verB == None:
                return 0
            return -1
        if verB == None:
            return 1;

        normVerA = self.normalize_version(verA)
        normVerB = self.normalize_version(verB)

        for a, b in zip(normVerA, normVerB):
            if a != b:
                return a - b
        return 0

    # Registry-based versioning
    def getSourceVersion(self):
        ''' Loads current component version from registry.
        '''
        dirKeyPath = '[HKEY_THIS_MACHINE\\Services\\vmafd]'

        return self.getRegValue(dirKeyPath, 'Version')

    def setSourceVersion(self,version):
        ''' Persists the latest applied patch version.
        @param version: The latest patch version
        '''
        keyPath = '[HKEY_THIS_MACHINE\\Services\\vmafd]'

        self.setRegValue(keyPath, 'Version', version)

    def getRequiredDiskSpace(self):
        raise NotImplementedError()

    def getDomainName(self):
        domainName = ''
        cmd = ['/usr/lib/vmware-vmafd/bin/vmafd-cli', 'get-domain-name', '--server-name', 'localhost']
        (stdout, stderr, rc) = os_utils.executeCommand(cmd)
        if rc == 0:
            domainName = stdout.rstrip()
            log.info("Domain name is %s" % domainName)
        else:
            log.error(stdout)
            log.error(stderr)
            raise Exception("Failed to get DomainName: "
                            "Couldn't run command get-domain-name")
        return domainName

    def getDCNameEx(self):
        dcName = ''
        dcNameExCmd = ['/usr/lib/vmware-vmafd/bin/vmafd-cli', 'get-dc-name-ex', '--server-name', 'localhost']
        (stdout, stderr, rc) = os_utils.executeCommand(dcNameExCmd)
        if rc == 0:
            dcName = stdout.rstrip()
        else:
            log.info(rc)
            log.error(stderr)
            raise Exception("Failed to get DC Name!: "
                            "Couldn't run command get-dc-name-ex")

        log.info("DC Name is %s" % dcName)
        return dcName

    def getDCName(self):
        dcName = ''
        dcNameCmd = ['/usr/lib/vmware-vmafd/bin/vmafd-cli', 'get-dc-name', '--server-name', 'localhost']
        (stdout, stderr, rc) = os_utils.executeCommand(dcNameCmd)
        if rc == 0:
            dcName = stdout.rstrip()
        else:
            log.info(rc)
            log.error(stderr)
            raise Exception("Failed to get DC Name!: "
                            "Couldn't run command get-dc-name")

        log.info("DC Name is %s" % dcName)
        return dcName

    def getAdminUPN(self):
        domainName = self.getDomainName()
        adminUPN = 'administrator@' + domainName
        return adminUPN.lower()

    def validatePassword(self, ssoPassword):
        ''' Validates the password provided by the user
        '''
        serverName = self.getDCName()
        adminUPN = self.getAdminUPN()

        cmd = ['ldapsearch',
               '-h',
               serverName,
               '-s',
               'base',
               '-Y',
               'SRP',
               '-U',
               adminUPN,
               '-y',
               '/dev/fd/0']
        (rc, out, err) = self.run_command(cmd, stdin=ssoPassword)
        if rc != 0:
            log.error('Could not validate password due to: %s %s', out, err)
        return rc == 0

class PatchUtilsLin(PatchUtils):
    __IS_MGMT_NODE = False
    __IS_GATEWAY = False
    __LW_BIN = '/opt/likewise/bin'
    __LW_SBIN = '/opt/likewise/sbin'
    __LW_REG_BIN = __LW_BIN + '/lwregshell'
    __LW_SM_BIN =  __LW_BIN + '/lwsm'
    __LW_SMD_BIN = __LW_SBIN + '/lwsmd'
    __VMAFD_COMPONENT_NAME = 'vmafd'
    __VMAFD_SERVICE_NAME = 'VMWareAfdService'
    __VMAFD_PRODUCT_NAME = 'VMware Afd Services'
    __VMAFD_COMPANY_NAME = "VMware, Inc."

    def __init__(self):
        nodeType = vcsa_utils.getDeploymentType()
        if not nodeType:
            log.error('Node type recevied is invalid')
            raise Exception("Failed to get node type")
        if nodeType == 'management':
            self.__IS_MGMT_NODE = True
        elif nodeType == 'embedded':
            if not os.path.exists('/usr/lib/vmware-vmdir/sbin/vmdird'):
                self.__IS_GATEWAY = True
        return

    def is_mgmt_node(self):
        return self.__IS_MGMT_NODE

    def is_gateway(self):
        return self.__IS_GATEWAY

    def is_infra_or_embedded_node(self):
        return (not self.__IS_MGMT_NODE) and (not self.__IS_GATEWAY)

    def update_repoint_config(self, val):
        """
        This function set/unset registry value that controls repointing
        """
        LwIsStarted = self.start_lwsmd()
        try:
            vmafdRegKey = '[HKEY_THIS_MACHINE\\Services\\vmafd]'
            regVal = 'EnableAutoRepoint'
            rc = self.setRegValue(vmafdRegKey, regVal, val, regType='REG_DWORD')
            if rc != 0:
                raise Exception("Failed to update repoint configuration")
        finally:
            if LwIsStarted:
                self.stop_lwsmd()

    def start_service(self, service=None, component=None):
        """
        This functions starts a Service
        """
        if service is None:
            service=self.__VMAFD_SERVICE_NAME
        if component is None:
            component=self.__VMAFD_COMPONENT_NAME

        log.info('Starting service [%s]' % component)

        # we ignore return code
        command = 'service %sd start' % component

        os_utils.executeCommand(command.split(' '))

        # wait for 5 minutes maximum
        command = '%s/lwsm status %s' % (self.__LW_BIN , component)
        for n in range(1, 600):
            (stdout, stderr, rc) = os_utils.executeCommand(command.split(' '))
            output = stdout.rstrip()
            log.info(output)
            if 'running' in output:
                log.info('Service [%s] started succesfully' % component)
                break
            time.sleep(0.5)
        else:
            raise Exception("Fail to start service : %s" % service)

    def stop_service(self, service=None, component=None):
        """
        This functions stops a Service
        """
        if service is None:
            service=self.__VMAFD_SERVICE_NAME
        if component is None:
            component=self.__VMAFD_COMPONENT_NAME

        log.info('Stopping service [%s]' % component)

        # we ignore return code
        command = 'service %sd stop' % component

        os_utils.executeCommand(command.split(' '))

        # wait for 5 minutes maximum
        command = '%s/lwsm status %s' % (self.__LW_BIN , component)
        for n in range(1, 600):
            (stdout, stderr, rc) = os_utils.executeCommand(command.split(' '))
            output = stdout.rstrip()
            log.info(output)
            if 'stopped' in output:
                log.info('Service [%s] stopped succesfully' % component)
                break
            time.sleep(0.5)
        else:
            raise Exception("Failed to stop service : %s" % service)

    #Returns the Destination machine disk space (in GB)
    def getRequiredDiskSpace(self):
        return "0.01"

    # Check if reg exists under given keyroot.
    def _has_reg_val_lin(self, valueName, keyRoot):
        command = [self.__LW_REG_BIN, 'list_values', keyRoot]

        (stdout, stderr, rc) = os_utils.executeCommand(command)

        if rc == 0:
            output = stdout.splitlines()
            for line in output:
                if line.find(valueName) != -1:
                    return True
        return False

    def setRegValue(self, key, valueName, value, regType='REG_SZ'):
        rc = 0
        if self._has_reg_val_lin(valueName, key):
            command = [self.__LW_REG_BIN, 'set_value', key, valueName, value]
            (stdout, stderr, rc) = os_utils.executeCommand(command)

            if rc != 0:
                log.error('Error setting key')
                return rc
        else:
            command = [self.__LW_REG_BIN, 'add_value', key, valueName, regType, value]
            (stdout, stderr, rc) = os_utils.executeCommand(command)

            if rc != 0:
                log.error('Error creating key')
                return rc
        return rc

    def getRegValue(self,key,valueName):

        command = [self.__LW_REG_BIN, 'list_values', key]

        (stdout, stderr, rc) = os_utils.executeCommand(command)

        if rc == 0:
            return self._getRegValue(stdout.splitlines(), valueName)
        else:
            return None

    def _getRegValue(self, output, valueName):
        for line in output:
            if line.find(valueName) != -1:
                line = line.replace("+","")
                return re.findall(r'"(.*?)"',line)[1]

    def deleteRegTree(self, keyName, ignoreError=True):
        rc = 0
        command = [self.__LW_REG_BIN, 'delete_tree', keyName]
        (stdout, stderr, rc) = os_utils.executeCommand(command)
        if rc != 0:
            if ignoreError:
                log.info("Unable to delete registry tree for key %s, Error ignored" % keyName)
                rc = 0
            else:
                log.error("Unable to delete registry tree for key %s" % keyName)
        return rc

    def is_service_running(self, serviceName):
        rc = 0
        command = ['service-control', '--status',  serviceName]
        (stdout, stderr, rc) =  os_utils.executeCommand(command)
        if rc != 0:
            log.error('Failed to get status of service %s ' % serviceName)
            raise Exception('Failed to get service status')

        output = stdout.rstrip()
        if 'Running' in output:
            return True
        return False

    def start_svc(self, serviceName):
        rc = 0
        command = ['service-control', '--start',  serviceName]
        (stdout, stderr, rc) =  os_utils.executeCommand(command)
        if rc != 0:
            log.error('Failed to start service %s ' % serviceName)
            raise Exception('Failed to start service %s ' % serviceName)

        log.info('Service %s started succesfully' % serviceName)

    def restart_svc(self, serviceName):
        rc = 0
        command = ['service-control', '--restart',  serviceName]
        (stdout, stderr, rc) =  os_utils.executeCommand(command)
        if rc != 0:
            log.error('Failed to restart service %s ' % serviceName)
            raise Exception('Failed to restart service %s ' % serviceName)

        log.info('Service %s restarted succesfully' % serviceName)

    def cleanup_vmdns_rpms(self):
        command = ['rpm', '-e',  'vmware-dns-server', '--noscripts']
        os_utils.executeCommand(command)

        command = ['rpm', '-e',  'vmware-dns-client', '--noscripts']
        os_utils.executeCommand(command)

    def start_lwsmd(self):
        log.info('Starting service [lwsmd]')

        # See if service is already running...
        command = [self.__LW_SM_BIN, 'status',  'lwreg']
        (stdout, stderr, rc) =  os_utils.executeCommand(command)
        output = stdout.rstrip()
        if 'ERROR_FILE_NOT_FOUND' in output:
            # Service was not running so start it
            command = [self.__LW_SMD_BIN, '--start-as-daemon']
            (stdout, stderr, rc) =  os_utils.executeCommand(command)
            # Make sure it starts
            command = [self.__LW_SM_BIN, 'status', 'lwreg']
            for n in range(1, 60):
                time.sleep(5)
                (stdout, stderr, rc) = os_utils.executeCommand(command)
                output = stdout.rstrip()

                if 'running' in output.lower():
                    log.info('Service lwsmd started succesfully')
                break
            else:
                raise Exception("Failed to start lwsmd")
            return True
        else:
            log.info('Service lwsmd already running')
            return False

    def start_lwsmd_by_service_control(self):
        log.info('Starting service [lwsmd] by service control')

        # Call service control unconditionally, it doesn't matter if lwsmd is already running or not.
        self.start_svc('lwsmd')

    def stop_lwsmd(self):
        log.info('Stopping service [lwsmd]')

        command = [self.__LW_SM_BIN, 'shutdown']
        os_utils.executeCommand(command)

        command = [self.__LW_SM_BIN, 'list']
        for n in range(1, 60):
            time.sleep(5)
            (stdout, stderr, rc) = os_utils.executeCommand(command)
            output = stdout.rstrip()

            if 'ERROR_FILE_NOT_FOUND' in output.upper():
                break
        else:
            raise Exception("Failed to stop lwsmd")

    ''' Block till vmafdd status is set to Running '''
    def wait_for_vmafdd_running_state(self):
        command = ['/usr/lib/vmware-vmafd/bin/vmafd-cli', 'get-status', '--server-name', 'localhost']
        stdout = None
        stderr = None
        for count in range(1, 300): #5 min retry
            (stdout, stderr, rc) = os_utils.executeCommand(command)
            if rc == 0:
                status = stdout.rstrip()
                if status == "Running":
                    log.info('vmafdd service is running')
                    break
            time.sleep(1)
        else:
            log.error("vmafdd service running state check timeout")
            log.error(stdout)
            log.error(stderr)

    def recursive_ownership(self, path, mode, uid, gid):
        for dirpath, dirnames, filenames in os.walk(path):
            os.chmod(dirpath, mode)
            os.chown(dirpath, uid, gid)
            for filename in filenames:
                try:
                    os.chmod(os.path.join(dirpath, filename), mode)
                    os.chown(os.path.join(dirpath, filename), uid, gid)
                #Until PR2912801 is resolved, it will be necessary to ignore the broken symbolic links
                except FileNotFoundError :
                    log.warn("Broken symbolic link detected %s" % os.path.join(dirpath, filename))

    def update_path_ownership(self):
        vmafd_user = "vmafdd-user"
        vmafd_uid = getpwnam(vmafd_user).pw_uid
        vmafd_gid = getpwnam(vmafd_user).pw_gid
        root_uid = 0

        vmafdd_data_path = "/var/lib/vmware/vmafdd_data"
        if os.path.exists(vmafdd_data_path):
            os.chmod(vmafdd_data_path, 0o700)
            os.chown(vmafdd_data_path, vmafd_uid, vmafd_gid)

        machine_ssl_crt_path = "/var/lib/vmware/vmafdd_data/machine-ssl.crt"
        if os.path.exists(machine_ssl_crt_path):
            os.chmod(machine_ssl_crt_path, 0o600)
            os.chown(machine_ssl_crt_path, vmafd_uid, vmafd_gid)

        machine_ssl_key_path = "/var/lib/vmware/vmafdd_data/machine-ssl.key"
        if os.path.exists(machine_ssl_key_path):
            os.chown(machine_ssl_key_path, vmafd_uid, vmafd_gid)
            os.chmod(machine_ssl_key_path, 0o600)

        etc_ssl_certs_path = "/etc/ssl/certs"
        etc_ssl_certs_path_mode = 0o775
        if os.path.exists(etc_ssl_certs_path):
            self.recursive_ownership(etc_ssl_certs_path, etc_ssl_certs_path_mode, root_uid, vmafd_gid)

        krb5_lotus_conf_path = "/var/lib/vmware/vmafdd_config/krb5/krb5.lotus.conf"
        if os.path.exists(krb5_lotus_conf_path):
            os.chown(krb5_lotus_conf_path, vmafd_uid, vmafd_gid)
            os.chmod(krb5_lotus_conf_path, 0o644)

        krb5_keytab_path = "/var/lib/vmware/vmafdd_config/krb5/krb5.keytab"
        if os.path.exists(krb5_keytab_path):
            os.chown(krb5_keytab_path, vmafd_uid, vmafd_gid)
            os.chmod(krb5_keytab_path, 0o770)

        certool_cfg_path = "/var/lib/vmware/vmca_config/certool.cfg"
        vmcad_user = "vmcad-user"
        vmcad_uid = getpwnam(vmcad_user).pw_uid
        vmcad_gid = getpwnam(vmcad_user).pw_gid
        if os.path.exists(certool_cfg_path):
            os.chown(certool_cfg_path, vmcad_uid, vmcad_gid)
            os.chmod(certool_cfg_path, 0o755)


    def upgrade_vmafd_reg_tree(self):
        status = 0
        command = [self.__LW_REG_BIN, 'upgrade', '/usr/lib/vmware-vmafd/share/config/vmafd.reg']
        (stdout, stderr, status) = os_utils.executeCommand(command)
        if status != 0:
            raise Exception("Unable to upgrade registry tree , Error: %s", stderr)

    def set_vmafd_lwreg(self, key, value):
        param = "[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters]"
        current_value = self.getRegValue(param, key)
        log.info("current %s value set to %s " % (key, current_value))
        if current_value != value:
            log.info("setting %s value to %s " % (key, value))
            status = self.setRegValue(param, key, value)
            if status != 0:
                raise Exception("Updating %s value failed." % key)
