# Copyright (c) 2020 VMware, Inc.  All rights reserved. -- VMware Confidential

'''
This file contains authentication related functions. This file is invoked
when trying to authenticate the client for WCP Data Provider running on given
vcHost,
required to make the request for getting license status.
'''
import requests
import os
import sys
import logging

logger = logging.getLogger(__name__)
rootDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(rootDir, '../../../', 'libs'))
sys.path.append(os.path.join(rootDir, '../../../', 'libs', 'feature-state'))
sys.path.extend(os.environ['VMWARE_VAPI_PYTHONPATH'].split(':'))
from cis.vecs import VecsEntry

from pyVim import sso
from com.vmware.cis_client import Session

from vmware.vapi.lib.connect import get_requests_connector
from vmware.vapi.security.sso import create_saml_security_context
from vmware.vapi.security.session import create_session_security_context
from vmware.vapi.stdlib.client.factories import StubConfigurationFactory
MY_PAYLOAD_DIR = os.path.dirname(__file__)
CERT_PEM_FILEPATH = os.path.join(MY_PAYLOAD_DIR, "certificate.pem")
PKEY_PEM_FILEPATH = os.path.join(MY_PAYLOAD_DIR, "private_key.pem")
DOMAIN_FILE = "/etc/vmware/install-defaults/vmdir.domain-name"

class TokenGenerationError(Exception):
    pass

class WCPAuthenticator(object):
    def __init__(self, vc_host, vc_port):
        self.vc_host = vc_host
        self.vc_port = vc_port
        self._session = None

    def get_sts_authenticator(self, domain_file=DOMAIN_FILE):
        # Get the domain of the setup, by reading its value from file
        # /etc/vmware/install-defaults/vmdir.domain-name
        try:
            with open(domain_file, 'r') as file:
                domain = file.read()
        except Exception as e:
            msg = "Failed to read the domain value of VC from file. " \
                  "Err {}".format(str(e))
            logger.info(msg)
            raise Exception(msg)

        logger.debug("connecting to domain : %s on host: %s "
                     "port : %s for autoupgrade prechecks", domain,
                     self.vc_host, self.vc_port)
        return sso.SsoAuthenticator(
            "https://%s:%s/sts/STSService/%s" % (
                self.vc_host, self.vc_port, domain))

    def create_vapi_stub_config(self, sec_ctx):
        """
        Creates VAPI connection stub using provided security context.
        """
        session = requests.Session()
        session.verify = False
        # Connect to wcpsvc directly (bypassing vapi-endpoint).
        connector = get_requests_connector(
            session, url="https://%s:%s/wcp" % (self.vc_host, self.vc_port),
            timeout=15)
        connector.set_security_context(sec_ctx)
        return StubConfigurationFactory.new_std_configuration(connector)

    def getToken(self):
        '''
        Using vecs to generate certificate.pem and private_key.pem which are
        needed to get the token.
        '''
        sts_auth = self.get_sts_authenticator()
        store = "vpxd-extension"
        alias = "vpxd-extension"
        vecs = VecsEntry(store)
        vecs.get_cert(alias, CERT_PEM_FILEPATH)
        vecs.get_key(alias, PKEY_PEM_FILEPATH)
        return sts_auth.get_hok_saml_assertion(CERT_PEM_FILEPATH,
                                               PKEY_PEM_FILEPATH,
                                               delegatable=True)


    def get_session_token(self):
        '''
        Function to get SAML token; create and return session
        '''
        try:
            if self._session is None:
                token = self.getToken()
                sec_ctx = create_saml_security_context(token, PKEY_PEM_FILEPATH)
                session_svc = Session(self.create_vapi_stub_config(sec_ctx))
                self._session = session_svc.create()
            return self._session
        except Exception as e:
            logger.error(str(e))
            raise TokenGenerationError("Failed to generate SAML token and \
                                        create session")


    def destroy_session(self):
        '''
        Function to destroy and cleanup the current session
        '''
        if self._session is None:
            raise Exception(
                "Tried to destroy session but there's no active session.")
        sec_ctx = create_session_security_context(self._session)
        session_svc = Session(self.create_vapi_stub_config(sec_ctx))
        session_svc.delete()
        self._session = None


    def get_session_security_context(self):
        '''
        Function to retrieve the session security context and get token
        to authenticate the session
        '''
        return create_session_security_context(self.get_session_token())
