# Copyright 2020 - 2021 VMware, Inc.
# All rights reserved. -- VMware Confidential

import os
import sys
import logging
import re
import json
import xml.etree.ElementTree as ET

vmware_python_path = os.getenv('VMWARE_PYTHON_PATH')
if vmware_python_path and os.path.exists(vmware_python_path):
    sys.path.append(vmware_python_path)

from cis.utils import run_command

LW_REG_BIN = '/opt/likewise/bin/lwregshell'
LW_SM_BIN = '/opt/likewise/bin/lwsm'
VMIDENTITY_REG_KEY = '[HKEY_THIS_MACHINE\\Software\\VMware\\Identity\\Configuration]'
PATCH_LEVEL_REG_VALUE_NAME = 'PatchLevel'

RHTTPPROXY_CONFIG_FILE = '/etc/vmware-rhttpproxy/config.xml'

logger = logging.getLogger(__name__)


def is_patch_needed(patch_level):
    patch_to_apply = _to_int(patch_level)
    if not patch_to_apply:
        raise ValueError("Invalid patch level argument specified: %s"
                         % patch_level)

    # Get the current known patch level. If the value is not found in the
    # registry, this is essentially equivalent to patch level "0"
    current_level = _to_int(get_reg_value(VMIDENTITY_REG_KEY,
                                          PATCH_LEVEL_REG_VALUE_NAME), 0)

    msg = "Current patch level is '%d'. Desired patch level is '%d'. "\
          % (current_level, patch_to_apply)
    if current_level < patch_to_apply:
        logger.info(msg + "Patch is needed")
        return True
    else:
        logger.info(msg + "Patch is NOT needed")
        return False


def update_patch_level(patch_level):
    patch_to_apply = _to_int(patch_level)
    if not patch_to_apply:
        raise ValueError("Invalid patch level argument specified: %s"
                         % patch_level)

    # Attempt to write the patch level into the registry
    if set_reg_value(VMIDENTITY_REG_KEY,
                     PATCH_LEVEL_REG_VALUE_NAME,
                     str(patch_to_apply)):
        logger.info("Updated patch level to '%d'" % patch_to_apply)
        return True
    else:
        logger.info("Failed to update patch level to '%d'" % patch_to_apply)
        return False


def get_reg_value(key, value_name):
    command = [LW_REG_BIN, 'list_values', key]

    (rc, stdout, stderr) = run_command(command)

    if rc != 0:
        _log_command_failure(command, stdout, stderr)
        return None

    return _get_reg_value_from_lines(stdout.splitlines(), value_name)


def set_reg_value(key, value_name, value, reg_type='REG_SZ'):
    if _has_reg_value(key, value_name):
        command = [LW_REG_BIN, 'set_value', key, value_name, value]

        (rc, stdout, stderr) = run_command(command)

        if rc != 0:
            _log_command_failure(command, stdout, stderr)
            return False

    else:
        command = [LW_REG_BIN, 'add_value', key, value_name, reg_type, value]

        (rc, stdout, stderr) = run_command(command)

        if rc != 0:
            _log_command_failure(command, stdout, stderr)
            return False

    command = [LW_SM_BIN, 'refresh']
    (rc, stdout, stderr) = run_command(command)
    if rc != 0:
        _log_command_failure(command, stdout, stderr)
        return False

    return True


def _get_reg_value_from_lines(output, value_name):
    for line in output:
        line = line.replace("+", "")
        parts = line.split()
        # Expected columns: value_name reg_type value
        if len(parts) >= 3 and parts[0].find(value_name) != -1:
            # Get the rest of the line starting at parts[2] (reg value)
            value_index = line.find(parts[2])
            value_part = line[value_index:]
            matches = re.findall(r'"(.*)"', value_part)
            if matches:
                return matches[0]
    return None


def _has_reg_value(key, value_name):
    command = [LW_REG_BIN, 'list_values', key]

    (rc, stdout, stderr) = run_command(command)

    if rc != 0:
        _log_command_failure(command, stdout, stderr)
        return False

    if _get_reg_value_from_lines(stdout.splitlines(), value_name):
        return True

    return False


def _log_command_failure(command, stdout, stderr):
    logger.error("Failed to execute command '%s'" % command)
    logger.error(stdout)
    logger.error(stderr)


def _to_int(val, default=None):
    if val is not None:
        try:
            return int(val)
        except ValueError:
            logger.warning("Could not convert '%s' to integer" % val)
    return default


def get_external_idp_configured():
    idp_configured = False
    idp_uri = "http://localhost:%s/rest/vcenter/identity/providers"\
              % _get_http_port()
    command = ["curl", idp_uri]

    (rc, stdout, stderr) = run_command(command)

    if rc != 0:
        logger.error("Failed to invoke REST API at %s" % idp_uri)
        _log_command_failure(command, stdout, stderr)
        return False

    # If at least one identity provider is configured
    # as the default, indicate a successful result
    cmd_out = stdout.rstrip()
    if cmd_out:
        logger.info("providers command output is - %s" % cmd_out)
        try:
            idp_json = json.loads(cmd_out)
        except json.decoder.JSONDecodeError:
            logger.error("Decoding the JSON data has encountered an issue.")
            return False
        if 'majorErrorCode' in idp_json:
            logger.error("Failed to query providers REST API: %s" % cmd_out)
            return False
        for idp in idp_json['value']:
            if 'is_default' in idp and idp['is_default']:
                idp_configured = True
                break
    return idp_configured


def _get_http_port():
    port = '80'    # default http port; return as string to avoid conversion
    try:
        tree = ET.parse(RHTTPPROXY_CONFIG_FILE)
        root = tree.getroot()
        for entry in root:
            if entry.tag == 'proxy':
                for child in entry:
                    if child.tag == 'httpPort':
                        port = child.text
                        logger.info("Found configured HTTP port %s" % port)
                        break
                break
    except Exception as e:
        logger.error("Failed to parse rhttpproxy config file '%s' for "
                     "http port number: %s" % (RHTTPPROXY_CONFIG_FILE, str(e)))

    return port
