# Copyright (c) 2021 VMware, Inc. All rights reserved.
# VMware Confidential

import ldap3 as ldap

import logging
import json
from typing import Tuple
from . import utils

LOCAL_HOST = "localhost"
DOMAIN_CONTROLLERS_DN = "ou=Domain Controllers,{}"
ADMIN_UPN = "Administrator@{}"
ADMIN_DN = "cn=Administrator,cn=Users,{}"
ATTR_ROOT_DOMAIN_NAMING_CONTEXT = "rootDomainNamingContext"

class LdapConnection:
    ldap_domain = None
    ldap_domain_dn = None
    ldap_admin_dn = None
    ldap_account_password = None
    ldap_machine_account = None
    ldap_admin_upn = None
    domain_controllers_dn = None
    use_machine_account = False

    def __init__(self, ldap_server=LOCAL_HOST,
                 password=None, use_machine_account=False) -> None:
        self.logger = logging.getLogger(self.__class__.__name__)
        self._get_domain_info(ldap_server)
        self.ldap_account_password = password
        self.use_machine_account = use_machine_account

    def open_anonymous_ldap_connection(self, node) -> ldap.Connection:
        """
        Opens anonymous connection to the ldap server
        we can only read the DSE root information of the server
        We use default synchronous strategy
        :param node: hostname of the ldap server
        :return:
        """

        self.logger.debug("Opening Anonymous Connection "
                          "to {}".format(node))
        ldap_connection = None
        try:
            server = ldap.Server(self.get_uri_from_hostname(node),
                                 get_info=ldap.ALL)
            ldap_connection = ldap.Connection(server)
            ldap_connection.bind()
            return ldap_connection
        except ldap.core.exceptions.LDAPException as e:
            self.logger.error("Error Opening Anonymous "
                              "Connection to {}".format(node))
            self.logger.debug("Exception: {}".format(str(e)))
        return ldap_connection

    def open_ldap_connection(self, node) -> ldap.Connection:
        """
        Open ldap connection to the ldap server provided with
        ldap admin user credentials.
        :param node: hostname of the ldap server
        :return:
        """
        self.logger.debug("Opening Connection to {}".format(node))
        ldap_connection = None
        try:
            server = ldap.Server(self.get_uri_from_hostname(node),
                                 get_info=ldap.ALL)
            ldap_user = self.ldap_admin_dn
            if self.use_machine_account:
                ldap_user = self.ldap_machine_account

            ldap_connection = ldap.Connection(server,
                                              user=ldap_user,
                                              password=self.
                                              ldap_account_password)
            ldap_connection.bind()
            return ldap_connection
        except ldap.core.exceptions.LDAPException as e:
            self.logger.error("Error Opening Connection to {}".format(node))
            self.logger.debug("Exception: {}".format(str(e)))
        return ldap_connection

    def close_ldap_connection(self, ldap_connection) -> None:
        """
        Close the ldap bind connection
        :param ldap_connection:ldap connection to be closed
        :return:
        """
        self.logger.debug("Closing LDAP connection")
        ldap_connection.unbind()
        if ldap_connection.result["result"] != 0:
            self.logger.error("Error closing connection. Error Msg: %s",
                              ldap_connection.result["message"])

    def ldap_search(self, ldap_connection, base_dn, ldap_filter,
                    ldap_scope, ldap_attributes) -> bool:
        """
        This method takes the ldap connection, base dn, filter , scope
        and list of ldap attributes to be returned.
        :param ldap_connection  connection to ldap server
        :param base_dn dn where the search starts
        :param ldap_filter to filter the search the results
        :param ldap_scope scope of the search
        :param ldap_attributes list of ldap attributes to be returned
        """
        self.logger.debug("LDAP Search with \n Base DN:%s \n"
                          "Filter: %s\nScope: %s\nAttributes: %s",
                          base_dn, ldap_filter, str(ldap_scope),
                          str(" ".join(ldap_attributes)))
        if ldap_connection is None:
            self.logger.error("No ldap connection")
        else:
            ldap_connection.bind()
            result = ldap_connection.search(base_dn, ldap_filter,
                                            ldap_scope,
                                            attributes=ldap_attributes)
            if result:
                return True
        self.logger.error("LDAP Search failed.Error Message: %s",
                          ldap_connection.result["message"])
        return False

    def ldap_modify_single_value_attr(self, ldap_connection, dn, ATTR, new_value, operation) -> bool:
        """
        This method takes the ldap connection, entry data like dn, attributes,
        new value, operation to perform and  modify's it to the ldap database.
        This method modifies single attribute at a time.
        :param ldap_connection:
        :param dn: dn of the entry to be modified
        :param new_value: new_value to be changed
        :param ATTR: Attribute for which value is to be modified
        :return:
        """
        self.logger.debug("LDAP MODIFY with \n DN :%s", dn)
        if ldap_connection is None:
            self.logger.error("No ldap connection")
        else:
            ldap_connection.bind()
            result = ldap_connection.modify(
                dn, {ATTR: [(operation, [new_value])]})
            self.logger.debug("ldap_connection.result Message: %s",
                              ldap_connection.result["message"])
            if result:
                self.logger.debug("Modified DN:%s", dn)
                return True
            else:
                self.logger.error("Error Modifying User DN:{}.Error Message:{}"
                                  .format(dn,
                                          ldap_connection.result["message"]))
        return False

    def ldap_add(self, ldap_connection, dn, object_class, attributes) -> bool:
        """
        This method takes the ldap connection, entry data like dn, attributes,
         objectclass and add's it to the ldap database
         :param ldap_connection:
        :param dn: dn of the entry to be added
        :param object_class: list of objectclasses
        :param attributes:  dictionary of key value pairs where a value can be
        a single or list
        :return:
        """
        self.logger.debug("LDAP Add with \n DN: %s\nObjectClass:%s\n"
                          "Attributes:%s", dn, str(" ".join(object_class)),
                          json.dumps(attributes))
        if ldap_connection is None:
            self.logger.error("No ldap connection")
        else:
            ldap_connection.bind()
            if ldap_connection.add(dn=dn, object_class=object_class,
                                   attributes=attributes):
                self.logger.info("Added DN:%s", dn)
                return True
            else:
                self.logger.error("Error Adding User DN:{}.Error Message:{}"
                                  .format(dn,
                                          ldap_connection.result["message"]))
        return False

    def ldap_delete(self, ldap_connection, dn) -> bool:
        """
        This method  takes the ldap connection and entry dn and deletes
        that entry from ldap server
        :param ldap_connection: ldap connection
        :param dn: entry dn to be delete
        :return:
        """
        self.logger.debug("LDAP delete with DN: %s", dn)
        if ldap_connection is None:
            self.logger.error("No ldap connection")
        else:
            ldap_connection.bind()
            if ldap_connection.delete(dn):
                self.logger.info("Deleted entry DN: %s", dn)
                return True
            else:
                self.logger.error("Error deleting entry DN: {}".format(dn))
                return False

    @staticmethod
    def get_attribute(ldap_entry, attribute) -> str:
        """
        get the attribute value for a given ldap entry
        This function can only be used for single value attributes.
        :param ldap_entry: ldap entry
        :param attribute: attribute to be returned
        :return: value which is string
        """
        val = ldap_entry['attributes'][attribute]
        if isinstance(val, str):
            return val
        return val[0]

    @staticmethod
    def get_uri_from_hostname(hostname) -> str:
        """
        Converts the hostname to LdapUri format
        :param hostname:
        :return:
        """
        return "ldap://{}".format(hostname)

    def _get_domain_info(self, ldap_server) -> None:
        """
        This method populates the ldap server information like domain name,
        domain upn, server dn, server dn
        :param ldap_server: ldap server or VMDir server
        :return:
        """
        self.logger.debug("Populating LDAP Server information")
        ldap_connection = self.open_anonymous_ldap_connection(ldap_server)
        ldap_domain_dn, ldap_domain = \
            self._get_server_info(ldap_connection)
        self.ldap_domain_dn = ldap_domain_dn
        self.ldap_domain = ldap_domain
        self.domain_controllers_dn = \
            DOMAIN_CONTROLLERS_DN.format(self.ldap_domain_dn)
        self.ldap_admin_dn = ADMIN_DN.format(self.ldap_domain_dn)
        self.ldap_machine_account = utils.PatchUtilsLin().getMachineAccountDN()
        self.ldap_admin_upn = ADMIN_UPN.format(ldap_domain)
        self.close_ldap_connection(ldap_connection)

    def _get_server_info(self, ldap_connection) -> Tuple[str, str]:
        """
        This method retrieves ldap server information from DSE root
        by binding to the server anonymously
        :param ldap_connection: ldap connection to the server
        :return:
        """
        self.logger.debug("Getting Server information from DSE root")
        if self.ldap_search(ldap_connection, "", "(objectClass=*)",
                            ldap.BASE,
                            [ATTR_ROOT_DOMAIN_NAMING_CONTEXT]):
            domain_dn = self.get_attribute(ldap_connection.response[0],
                                           ATTR_ROOT_DOMAIN_NAMING_CONTEXT)
            domain_name = '.'.join(list(map(lambda val: val.split('=')[1],
                                            domain_dn.split(','))))
            return domain_dn, domain_name
