#!/usr/bin/env python

#
# Copyright 2019 VMware, Inc.  All rights reserved. -- VMware Confidential
# This script contains utility methods which are used by other scripts.
#

import logging
import os
import sys
import warnings
import datetime
import xml.etree.ElementTree as ET

sys.path.append(os.environ['VMWARE_PYTHON_PATH'])
from cis.defaults import get_cis_tmp_dir
from cis.vecs import Service, cli_path, vmafd_machine_id, SsoGroup, VecsEntry
from cis.vpxdevent_lib import get_default_connect_urls
from cis.utils import run_command, FileBuffer, log, log_error
from pyVim import sso
from pyVmomi import vim, VmomiSupport, SoapStubAdapter

sys.path.append('/usr/lib/vmware-vmafd/lib64')

from tempfile import NamedTemporaryFile

# Suppress warnings coming from vmafd and identity imports.
with warnings.catch_warnings():
    warnings.simplefilter('ignore', RuntimeWarning)

sys.path.append(os.path.dirname(__file__))

logger = logging.getLogger(__name__)

# ----
# Creates a new vim.Extension object for the vSphere Client extension.
# [out]: A new vim.Extension spec with populated data.
# [throws]: IOError, ParseError, Exception
# ----
def create_vsphere_client_extension(config_dir):
    logInfo("Creating vSphere Client extension")
    extension_xml_file = os.path.join(config_dir, 'extension.xml')
    extensionXml = ET.parse(extension_xml_file)
    extRoot = extensionXml.getroot().find('extension')

    name = extRoot.find('name').text
    desc = vim.Description(label=(name + '.label'), summary=(name + '.summary'))
    company = extRoot.find('company').text
    ver = extRoot.find('version').text

    # Fill tasks type information.
    tasks = [vim.Extension.TaskTypeInfo(taskID=task.get('id')) for task in
             extRoot.findall('./tasks/task')]

    # Fill faults type information.
    faults = [vim.Extension.FaultTypeInfo(faultID=fault.get('id')) for fault in
              extRoot.findall('./faults/fault')]

    # Fill privileges type information
    privileges = [vim.Extension.PrivilegeInfo(
        privID=privilege.get('group') + "." + privilege.get('id'),
        privGroupName=privilege.get('group')
    ) for privilege in extRoot.findall('./privileges/privilege')]

    # Fill resources information.
    resources = []
    l10n_dir = os.path.join(config_dir, 'l10n')
    subdirs = next(os.walk(l10n_dir))[1]
    for locale in subdirs:
        locale_dir = os.path.join(l10n_dir, locale)
        vmsg_file_names = next(os.walk(locale_dir))[2]
        for vmsg_file_name in vmsg_file_names:
            entries = []
            vmsg_file = os.path.join(locale_dir, vmsg_file_name)
            resource_file_buffer = FileBuffer()
            resource_file_buffer.readFile(vmsg_file)
            for line in resource_file_buffer.getBufferContents():
                if '=' in line:
                    prop = line.split('=')
                    entries.append(
                        vim.KeyValue(key=prop[0].strip(), value=prop[1].strip()))

            resources.append(
                vim.Extension.ResourceInfo(
                    data=entries,
                    locale=locale,
                    module=vmsg_file_name[:-5]
                ))

    return vim.Extension(description=desc, key=name, company=company, version=ver,
                         taskList=tasks, faultList=faults, resourceList=resources,
                         privilegeList=privileges, lastHeartbeatTime=datetime.datetime.now())


# ----
# Registers the vSphere Client extension with the vCenter ExtensionManager.
# Cannot throw exception.
# ----
def register_vsphere_extension(config_dir):
    try:
        logInfo("Registering vSphere Client extension")
        cert_tmp_file = NamedTemporaryFile(dir=get_cis_tmp_dir(), delete=True)
        cert_tmp_file.close()
        crt_key_tmp_file = NamedTemporaryFile(dir=get_cis_tmp_dir(), delete=True)
        crt_key_tmp_file.close()

        vecs_entry = VecsEntry('vpxd-extension')
        cert = cert_tmp_file.name
        vecs_entry.get_cert('vpxd-extension', cert)
        cert_key = crt_key_tmp_file.name
        vecs_entry.get_key('vpxd-extension', cert_key)

        sts_url, vc_endpoints = get_default_connect_urls()
        url = vc_endpoints[0]
        logInfo("Endpoint URL: %s" % str(url))

        authenticator = sso.SsoAuthenticator(sts_url=sts_url)
        token = authenticator.get_hok_saml_assertion(cert, cert_key, token_duration=86400)
        vmodlVersion = VmomiSupport.newestVersions.Get('vim')

        conn = SoapStubAdapter(
            url=url,
            version=vmodlVersion,
            samlToken=token
        )

        serviceInstance = vim.ServiceInstance('ServiceInstance', conn)
        content = serviceInstance.content
        em = content.extensionManager
        sm = content.sessionManager

        if not sm.currentSession:
            with conn.requestModifier(lambda request: sso.add_saml_context(request,
                                      token, cert_key)):
                try:
                    conn.samlToken = token
                    sm.LoginByToken()
                except Exception as err:
                    logError("VC login failure. Exception is %s" % str(err))
                finally:
                    conn.samlToken = None

        ext_spec = create_vsphere_client_extension(config_dir)
        try:
            ext = em.FindExtension(ext_spec.key)
            if ext:
                em.UpdateExtension(extension=ext_spec)
                logInfo("Updated extension %s" % str(ext_spec.key))
            else:
                em.RegisterExtension(extension=ext_spec)
                logInfo("Registered extension %s" % str(ext_spec.key))
        except Exception as ex:
            logError("Failed to Update/Register extension %s: %s" % (ext_spec.key, str(ex)))
        finally:
            sm.Logout()
    except Exception as ex:
        logError("Failed to register extension 'com.vmware.vsphere.client'. Exception is: %s" % str(ex))

def logInfo(message):
    logger.info(message)
    log(message)

def logError(message):
    logger.error(message)
    log_error(message)
