# Copyright 2015-2021 VMware, Inc.
# All rights reserved. -- VMware Confidential
import os
import socket
import subprocess
import logging
import re
import time
import platform
import sys

log = logging.getLogger(__name__)
WIN_ENCODING = 'mbcs'
DEF_ENCODING = 'utf-8'

PY2 = sys.version_info[0] == 2
if not PY2:
    unicode = str

class PatchUtils(object):

    versionRegName = 'Version'

### Encoding utility methods ###
    @staticmethod
    def _encode(obj, encodingFunction):
        '''
        @param obj: Object on which to execute the provided encoding function, if
        it is a dict or list, executes it on the elements inside

        @param encodingFunction: Function to execute on the provided object
        @type function
        '''
        if isinstance(obj, dict):
            result = {}
            for k, v in obj.iteritems():
                result[k] = PatchUtils._encode(v, encodingFunction)
            return result
        elif isinstance(obj, list):
            return [PatchUtils._encode(element, encodingFunction) for element in obj]
        else:
            return encodingFunction(obj)

    @staticmethod
    def _toUnicode(noneUniCodeStr):
        '''Method to encode string to unicode. If it cannot achieve that it will
        return the same string

        @param noneUniCodeStr: String to encode to unicode
        @type str
        '''
        if noneUniCodeStr is None or \
            (PY2 and not isinstance(noneUniCodeStr, str)) or \
            (not PY2 and isinstance(noneUniCodeStr, str)):
            return noneUniCodeStr

        # Try with file system encoding, if it is not that, try with couple more
        # and after that give up
        try:
            return noneUniCodeStr.decode(sys.getfilesystemencoding())
        except (UnicodeDecodeError, LookupError):
            pass

        # if it is windows its more likely to succeed with mbcs, however if
        # that fails we will try with utf-8 and then give up
        if os.name == 'nt':
            try:
                return noneUniCodeStr.decode(WIN_ENCODING)
            except (UnicodeDecodeError, LookupError):
                pass

        # valid ascii is valid utf-8 so no point of trying ascii
        try:
            return noneUniCodeStr.decode(DEF_ENCODING)
        except (UnicodeDecodeError, LookupError):
            pass

        # cannot decode, return same string and leave the system to fail
        log_error('Tried to decode a string to unicode but it wasn\'t successful.'
                  'Expecting system failures')
        return noneUniCodeStr

    @staticmethod
    def _toFileSystemEncoding(obj):
        '''Method to encode unicode to file system encoding. If it cannot achieve
        that it will return the same string

        @param obj: Unicode object to encode using file system encoding
        @type unicode
        '''
        # This is wrapping unicode in file system encoding so that Popen can read it
        if obj is not None and isinstance(obj, unicode):
            obj = obj.encode(sys.getfilesystemencoding())
        return obj



    # Version should be in the form of 'w.x.y.z'
    def normalize_version(self, version):
        if version == None:
            return[0,0,0,0]

        subversions = [int(v) for v in version.split(".")]
        while len(subversions) > 4:
            subversions.pop()
        while len(subversions) < 4:
            subversions.append(0)
        return subversions

    def versionCompare(self, verA, verB):
        if verA == None:
            if verB == None:
                return 0
            return -1
        if verB == None:
            return 1;

        normVerA = self.normalize_version(verA)
        normVerB = self.normalize_version(verB)

        for a, b in zip(normVerA, normVerB):
            if a != b:
                return a - b
        return 0

    # Registry-based versioning
    def getSourceVersion(self):
        ''' Loads current component version from registry.
        '''
        # Windows not supported
        if platform.system().lower() == "windows":
            return None

        dirKeyPath = '[HKEY_THIS_MACHINE\\Services\\vmdir]'

        return self.getRegValue(dirKeyPath, 'Version')


    def setSourceVersion(self,version):
        ''' Persists the latest applied patch version.
        @param version: The latest patch version
        '''
        # Windows not supported
        if platform.system().lower() == "windows":
            return

        keyPath = '[HKEY_THIS_MACHINE\\Services\\vmdir]'

        self.setRegValue(keyPath, 'Version', version)

    def getAdminDn(self):
        keyPath = '[HKEY_THIS_MACHINE\\Services\\vmafd\\Parameters]'

        # build up the Administrator DN
        domainName =  self.getRegValue(keyPath, 'DomainName')

        domainSplit = domainName.split('.')

        adminDn = 'cn=administrator,cn=users'

        for dc in domainSplit:
            adminDn += ',dc=' + dc

        return adminDn

    def validatePassword(self, ssoPassword):
        ''' Validates the password provided by the user
        '''
        cmd = ['/usr/bin/python', os.path.join(os.path.dirname(__file__), 'vmdir_validator.py')]
        ssoPassword = PatchUtils._encode(ssoPassword, PatchUtils._toUnicode)
        ssoPassword = PatchUtils._encode(ssoPassword, PatchUtils._toFileSystemEncoding)
        rc, out, err = self.run_command(cmd, stdin=ssoPassword)
        if rc > 1:
            log.error('Could not validate password due to: %s %s', out, err)
        return rc == 0

    def _validatePassword(self, ssoPassword):
        ''' Validates the password using libvmidentityclientpython available on source
        Don't call that method directly but via the validatePassword, to ensure
        that different pythons are handled correctly
        '''
        sys.path.append("/opt/vmware/lib64")
        import libvmidentityclientpython

        if ssoPassword is None or not isinstance(ssoPassword, (str, unicode)):
            return False

        adminDn = self.getAdminDn()
        ssoPassword = PatchUtils._encode(ssoPassword, PatchUtils._toUnicode)
        ssoPassword = PatchUtils._encode(ssoPassword, PatchUtils._toFileSystemEncoding)

        try:
            ret = libvmidentityclientpython.VmDeployValidateVMDirCredentials_Secure('localhost',
                                                                                adminDn,
                                                                                ssoPassword)
        except libvmidentityclientpython.vmdeploy_exception as e:
            return False
        return ret

    def getRequiredDiskSpace(self):
        raise NotImplementedError()

    def _run_command(self, cmd, stdin=None):
        ''' Method to run commands
        '''
        process = subprocess.Popen(cmd,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   stdin=subprocess.PIPE)
        if not PY2:
            # Need to be bytes the stdin
            stdin = PatchUtils._toFileSystemEncoding(stdin)

        stdout, stderr = process.communicate(stdin)
        rc = process.wait()

        if not PY2:
            # Return type is bytes for Python 3+ thus need to cast to str
            stdout = PatchUtils._toUnicode(stdout)
            stderr = PatchUtils._toUnicode(stderr)

        if rc == 0:
            stdout = stdout.rstrip()

        if stderr is not None:
            stderr = stderr.rstrip()
        return rc, stdout, stderr

    def run_command_with_display(self, cmd, cmddisplay, stdin=None):
        """
        execute a command with the given input and return the return code and
        output use this function if cmd contains sensitive information which
        should not be logged.
        """
        log.info("Running command: " + str(cmddisplay))
        return self._run_command(cmd, stdin)

    def run_command(self, cmd, stdin=None):
        """
        execute a command with the given input and return
        the return code and output
        """
        log.info("Running command: " + str(cmd))
        return self._run_command(cmd, stdin)

class PatchUtilsWin(PatchUtils):

    def __init__(self):
        return

class PatchUtilsLin(PatchUtils):
    __VMDIR_CONFIG_PATH = '/usr/lib/vmware-vmdir/share/config'
    __LW_BIN = '/opt/likewise/bin'
    __LW_SBIN = '/opt/likewise/sbin'
    __LW_REG_BIN = __LW_BIN + '/lwregshell'
    __LW_SM_BIN =  __LW_BIN + '/lwsm'
    __LW_SMD_BIN = __LW_SBIN + '/lwsmd'
    __VMDIR_COMPONENT_NAME = 'vmdir'
    __VMDIR_SERVICE_NAME = 'VMwareDirectoryService'
    __VMDIR_PRODUCT_NAME = 'VMware Directory Services'
    __VMDIR_COMPANY_NAME = "VMware, Inc."

    def __init__(self):
         host_name = socket.getfqdn().rstrip('\n')
         return
    def getConfigDir(self):
        return os.path.normpath(self.__VMDIR_CONFIG_PATH)

    def getInstallDir(self):
        install_dir = "/usr/lib/vmware-%s" % self.__VMDIR_COMPONENT_NAME
        return os.path.normpath(install_dir)

    def getVmdirComponentName(self):
        return self.__VMDIR_COMPONENT_NAME

    #Returns the Destination machine disk space (in GB)
    def getRequiredDiskSpace(self):
        return "0.01"

    # Check if reg exists under given keyroot.
    def _has_reg_val_lin(self, valueName, keyRoot):
        command = [self.__LW_REG_BIN, 'list_values', keyRoot]

        (rc, stdout, stderr) = self.run_command_with_display(command, command)

        if rc == 0:
            output = stdout.splitlines()
            for line in output:
                if line.find(valueName) != -1:
                    return True
        return False

    def setRegValue(self, key, valueName, value, regType='REG_SZ'):

        if self._has_reg_val_lin(valueName, key):
            command = [self.__LW_REG_BIN, 'set_value', key, valueName, value]

            (rc, stdout, stderr) = self.run_command_with_display(command, command)

            if rc != 0:
                log.error('Error setting key')
                return

        else:
            command = [self.__LW_REG_BIN, 'add_value', key, valueName, regType, value]

            (rc, stdout, stderr) = self.run_command_with_display(command, command)

            if rc != 0:
                log.error('Error creating key')
                return

        command = [self.__LW_SM_BIN, 'refresh']
        (rc, stdout, stderr) = self.run_command_with_display(command, command)

        if rc != 0:
            log.error('Error refreshing lwsm')
            return

    def getMachineAccountDN(self):
        keyPath = '[HKEY_THIS_MACHINE\\Services\\vmdir]'
        return self.getRegValue(keyPath, 'dcAccountDN')

    def getMachineCredentials(self):
        cmd = ["/opt/likewise/bin/lwregshell",
               "list_values",
               "[HKEY_THIS_MACHINE\\Services\\vmdir]"]
        (rc, stdout, stderr) = self.run_command_with_display(cmd, cmd)
        if rc != 0:
            raise Exception('Unable to get Machine Credentials')

        for line in stdout.splitlines():
            if "REG_SZ" not in line:
                continue
            words = line.split("REG_SZ")
            key = words[0].split("+")[-1].strip()
            value = words[1].strip()
            key = key[1:-1]
            value = value[1:-1]
            if "dcAccount" == key:
                username = value
            if "dcAccountPassword" == key:
                pwd = value
                # special character handling for " and \
                # lwregshell add's extra / before each of them
                pwd = pwd.replace('\\"', '"')
                pwd = pwd.replace('\\\\', '\\')
        return username, pwd

    def getRegValue(self,key,valueName):

        command = [self.__LW_REG_BIN, 'list_values', key]

        (rc, stdout, stderr) = self.run_command_with_display(command, command)

        if rc == 0:
            return self._getRegValue(stdout.splitlines(), valueName)
        else:
            return None

    def _getRegValue(self, output, valueName):
        for line in output:
            if line.find(valueName) != -1:
                line = line.replace("+","")
                return re.findall(r'"(.*?)"',line)[1]

    def start_service(self):
        """
        This functions starts a Service
        """
        log.info('Starting service [%s]' % self.__VMDIR_COMPONENT_NAME)

        # we ignore return code
        command = 'service %sd start' % self.__VMDIR_COMPONENT_NAME

        self.run_command(command.split(' '))

        # wait for 5 minutes maximum
        command = '%s/lwsm status %s' % (self.__LW_BIN , self.__VMDIR_COMPONENT_NAME)
        for n in range(1, 60):
            time.sleep(5)

            (rc, stdout, stderr) = self.run_command(command.split(' '))
            output = stdout.rstrip()
            log.info(output)
            if 'running' in output:
                log.info('Service [%s] started succesfully' % self.__VMDIR_COMPONENT_NAME)
                break
        else:
            raise Exception("Failed to start service : %s" % self.__VMDIR_SERVICE_NAME)

    def stop_service(self):
        """
        Stops the Service and waits for the service to stop or time out happens.
        """
        log.info('Stopping service [%s]' % self.__VMDIR_COMPONENT_NAME)

        command = '%s/lwsm status %s' % (self.__LW_BIN , self.__VMDIR_COMPONENT_NAME)

        (rc, stdout, stderr) = self.run_command(command.split(' '))
        output = stdout.rstrip()
        log.info(output)
        if 'running' not in output:
            log.info('Service [%s] is not running and we are trying to stop it' %
                     self.__VMDIR_COMPONENT_NAME)

        # we ignore return code
        command = 'service %sd stop' % self.__VMDIR_COMPONENT_NAME

        self.run_command(command.split(' '))

        # wait for 5 minutes maximum
        command = '%s/lwsm status %s' % (self.__LW_BIN , self.__VMDIR_COMPONENT_NAME)
        for n in range(1, 60):
            time.sleep(5)

            (rc, stdout, stderr) = self.run_command(command.split(' '))
            output = stdout.rstrip()
            log.info(output)
            if 'stopped' in output:
                break
        else:
            raise Exception("Failed to stop service : %s" % self.__VMDIR_SERVICE_NAME)
        return

    def start_lwsmd(self):
        log.info('Starting service [lwsmd]')

        # See if service is already running...
        command = [self.__LW_SM_BIN, 'status',  'lwreg']
        (rc, stdout, stderr) =  self.run_command(command)
        output = stdout.rstrip()
        if 'ERROR_FILE_NOT_FOUND' in output:
            # Service was not running so start it
            command = [self.__LW_SMD_BIN, '--start-as-daemon']
            (rc, stdout, stderr) =  self.run_command(command)
            # Make sure it starts
            command = [self.__LW_SM_BIN, 'status', 'lwreg']
            for n in range(1, 60):
                time.sleep(5)
                (rc, stdout, stderr) = self.run_command(command)
                output = stdout.rstrip()

                if 'running' in output.lower():
                    log.info('Service lwsmd started succesfully')
                break
            else:
                raise Exception("Failed to start lwsmd")
            return True
        else:
            log.info('Service lwsmd already running')
            return False

    def stop_lwsmd(self):
        log.info('Stopping service [lwsmd]')

        command = [self.__LW_SM_BIN, 'shutdown']
        self.run_command(command)

        command = [self.__LW_SM_BIN, 'list']
        for n in range(1, 60):
            time.sleep(5)
            (rc, stdout, stderr) = self.run_command(command)
            output = stdout.rstrip()

            if 'ERROR_FILE_NOT_FOUND' in output.upper():
                break
        else:
            raise Exception("Failed to stop lwsmd")

    def start_lwsmd_service(self):
        log.info('Starting service "lwsmd"')

        # Check if the service is already running...
        command = ['service-control','--status','lwsmd']
        (rc, stdout, stderr) =  self.run_command(command)
        output = stdout.rstrip()
        if 'stopped' in output.lower():
            # Service is not running, so start it
            command = ['service-control','--start','lwsmd']
            (rc, stdout, stderr) =  self.run_command(command)
            # Make sure it starts
            command = ['service-control','--status','lwsmd']
            for n in range(1, 60):
                time.sleep(5)
                (rc, stdout, stderr) = self.run_command(command)
                output = stdout.rstrip()

                if 'running' in output.lower():
                    log.info('Started the lwsmd service succesfully')
                    break
            else:
                raise Exception("Failed to start lwsmd service")
            return
        else:
            log.info('lwsmd service is running already')
            return



    def is_service_running(self, component_name):
        """
        This functions checks if service is currently running.
        :param component_name:
        """
        command = ['/opt/likewise/bin/lwsm', 'status', component_name]
        (rc, stdout, stderr) = self.run_command(command)
        output = stdout.rstrip()
        if 'running' in output:
            return True
        else:
            return False

    def setSchemaDeleteVal(self, filePath, value):
        if value:
            log.info('Schema deletion is required for this upgrade.')
        else:
            log.info('Schema deletion is NOT required for this upgrade.')
        with open(filePath, "w") as fp:
            fp.write(str(value))

    def getSchemaDeleteVal(self, filePath):
        if os.path.exists(filePath):
            with open(filePath) as fp:
                value = fp.read();
        else:
            log.info('schemaDel.txt file does not exist.')
            value = False
        return value
