#!/usr/bin/env python
#
# Copyright 2019 VMware, Inc.  All rights reserved. -- VMware Confidential
#
"""
This module contains helper functions for connecting to vAPI.
"""
import logging
import os
import requests
import urllib
import time
import warnings
import sys
import importlib

from pyVim import sso
from pyVmomi import lookup
from cis.cisreglib import LookupServiceClient, VmafdClient
from cis.defaults import get_cis_tmp_dir

with warnings.catch_warnings():
    warnings.simplefilter('ignore', RuntimeWarning)
    from identity.vmkeystore import VmKeyStore

logger = logging.getLogger(__name__)

vapiPythonPath = os.environ['VMWARE_VAPI_PYTHONPATH'].split(':')
vapiPythonPath.reverse()
for entry in vapiPythonPath:
    if entry not in sys.path:
        sys.path.insert(0, entry)

try:
    if sys.modules.get('vmware'):
        importlib.reload(sys.modules['vmware'])
    import vmware.vapi
    from vmware.vapi.stdlib.client.factories import StubConfigurationFactory
    from vmware.vapi.security.session import create_session_security_context
    from com.vmware.cis_client import Session
    from vmware.vapi.lib.connect import get_requests_connector
    from vmware.vapi.security.sso import create_saml_security_context
except Exception as ex:
    logger.info("Fail to import vAPI package : %s " % str(ex))

# Fetches and stores the machine certificate and key to a file
# Copied from VsanWcpUtil.py. Modified method name and variable names.
def get_machine_certificate(force_refresh=False):
  certFile = os.path.join(get_cis_tmp_dir(), 'machine.cert')
  keyFile = os.path.join(get_cis_tmp_dir(), 'machine.key')
  if force_refresh \
      or not os.path.exists(certFile) \
      or not os.path.exists(keyFile):
      try:
          with VmKeyStore('VKS') as vks:
              vks.load("machine")
              cert = vks.get_certificate('machine')
              key = vks.get_key('machine')
              for fileName, value in [(certFile, cert), (keyFile, key)]:
                  with open(fileName, 'w') as f:
                      f.write(value)
      except Exception as ex:
          logger.info("Failed to get VC machine certification %s" % str(ex))
  return certFile, keyFile

# Establishes connection to vapi using machine certificate
# Copied from VsanWcpUtil.py. Modified method name and variable names.
# Implemented the additional retry logic, which retries every 20 secs for 5 times, if connection could
# not be established
def connect_vapi():
  def connect_vapi_internal(force_fetch_machine_cert=False):
      session = requests.session()
      session.verify = False
      ls_client = LookupServiceClient(VmafdClient().get_ls_location())
      vapi_url = 'https://localhost/api'
      vapi_endpoints = ls_client.get_service_endpoints(
          svc_typeid="cs.vapi", ep_protocol="vapi.json.https.public",
          ep_type=None, local_nodeid=VmafdClient().get_ldu())
      if (vapi_endpoints[0].url):
          vapi_url = vapi_endpoints[0].url
      logger.info("Connecting to vAPI at URL: %s" % vapi_url)
      connector = get_requests_connector(session=session,url=vapi_url)

      sts_url = ls_client.get_sts_endpoint_data()[0]

      authenticator = sso.SsoAuthenticator(sts_url=sts_url)
      cert_file, key_file = get_machine_certificate(force_refresh=force_fetch_machine_cert)
      token = authenticator.get_hok_saml_assertion(cert_file, key_file, delegatable=True)

      ctx = create_saml_security_context(token, key_file)
      connector.set_security_context(ctx)
      _stub_cfg = StubConfigurationFactory.new_std_configuration(connector)

      session_svc = Session(_stub_cfg)
      ctx = create_session_security_context(session_svc.create())
      connector.set_security_context(ctx)
      stub_cfg = StubConfigurationFactory.new_std_configuration(connector)

      return stub_cfg
  force_refresh_cert = False
  # retry logic which retries the connection 5 times with 20 seconds timeout in between.
  # this is required because vmware-vapi-endpoint service needs additional time to initialize
  # even though service status shows running.
  # here we retry incase of token expiry or HTTPError when service is still starting.
  # from what I could see in the logs, the connection is successful within 2 retries, still keeping the
  # retries to 4 for now.
  errors = []
  for i in range(1, 5):
      try:
          stub = connect_vapi_internal(force_refresh_cert)
          return stub
      except sso.SoapException as ssoEx:
          errors.append(ssoEx)
          if 'FailedAuthentication' in ssoEx._fault_code:
             logger.info("Reconnect vAPI to refresh token")
             force_refresh_cert = True
             logger.info("Failed to connect vAPI. Retrying again with force refresh certificate after 20 seconds.")
             time.sleep(20)
          else:
              logger.info("Failed to connect vAPI : %s" % ssoEx)
      except requests.exceptions.HTTPError as httpEx:
          logger.info("Failed to connect vAPI. Retrying again after 20 seconds.")
          errors.append(httpEx)
          time.sleep(20)
      except Exception as ex:
          logger.info("Failed to connect vAPI - %s." % str(ex))
          errors.append(ex)
  logger.warning("Failed to connect vAPI after 4 attempts.")
  for i in range(1, 5):
      logger.warning("--- Connect vAPI Attempt %s error ---", i, exc_info=errors[i-1])
