# Copyright 2020-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
# coding: utf-8

import os
import logging

from .utils import get_psql_connection, get_db_properties, _execute_sql_file, \
    setup_logging

logger = logging.getLogger(__name__)
setup_logging(logger)

SQL_FOLDER = os.path.abspath(os.path.join(os.path.dirname(__file__), "sql"))
FIRST_SCRIPT = os.path.join(SQL_FOLDER, "TrackingTables.sql")
LAST_SCRIPT = os.path.join(SQL_FOLDER, "EndUpgrade.sql")
REVERT_SCRIPT = os.path.join(SQL_FOLDER, "RevertUpgrade.sql")


EXPAND_OPERATION = "EXPAND"
CONTRACT_OPERATION = "CONTRACT"

EXPAND = "expand.sql"
CONTRACT = "contract.sql"
REVERT = "revert.sql"

FSS_PROTECTED_VERSIONS = {
}


def execute_expand(fssValues):
    ''' Execute the expand phase of the upgrade. If there are scripts that are
    already applied it will skip them.
    '''
    logger.info("Expanding database.")
    # Sets the tracking for the upgrade tables
    _execute_sql_file(FIRST_SCRIPT)

    expandScripts, _ = _calculate_version_delta()

    for (id, script) in expandScripts:
        logger.info("Found ID(%s) needing execution.", id)
        if id in FSS_PROTECTED_VERSIONS and \
            not fssValues.get(FSS_PROTECTED_VERSIONS[id], False):
            logger.info("The FSS for the given script is OFF skipping it.")
            continue
        _execute_sql_file(script, \
            params=_get_upgrade_tracking_params(id, EXPAND_OPERATION))

def execute_contract(fssValues):
    ''' Execute the contract phase of the upgrade. If there are scripts that are
    already applied it will skip them. It will finish the upgrade workflow.
    '''
    logger.info("Contracting database.")
    # Makes the contract idempotent
    _execute_sql_file(FIRST_SCRIPT)
    _, contractScripts = _calculate_version_delta()

    for (id, script) in contractScripts:
        if id in FSS_PROTECTED_VERSIONS and \
            not fssValues.get(FSS_PROTECTED_VERSIONS[id], False):
            logger.info("The FSS for the given script is OFF skipping it.")
            continue
        _execute_sql_file(script, \
            params=_get_upgrade_tracking_params(id, CONTRACT_OPERATION))

    # Need to commit all the upgrade specific tracking IDs to the master tracker
    # for the next upgrades to be able to work.
    _execute_sql_file(LAST_SCRIPT)

def execute_revert():
    logger.info("Reverting expand phase")

    needingRevert = _calculate_revert_delta()
    for (id, script) in needingRevert:
        logger.info("Found ID(%s) needing execution.", id)
        # Params are passed as it needs to drop them from the tracking table
        _execute_sql_file(script, \
            params=_get_upgrade_tracking_params(id, EXPAND_OPERATION))
    #Upgrade the version tracking info.
    _execute_sql_file(LAST_SCRIPT)

def _get_upgrade_tracking_params(id, operation):
    ''' This is to be provided to the scripts to correctly update their work
    '''
    return {"ndu_id": id, "ndu_operation": "'%s'" % operation}

def _calculate_version_delta():
    ''' Calculates missing scripts to move to the new version
    and provides their files
    '''

    appliedAlready = _find_all_applied_sql()
    available = find_sql_dirs()

    needingExpand = []
    needingContract = []

    for (id, a) in available.items():
        if has_expand_phase(a) and id not in appliedAlready.get(EXPAND_OPERATION, []):
            needingExpand.append((id, os.path.join(a, EXPAND)))
        elif id not in appliedAlready.get(EXPAND_OPERATION, []):
            logger.debug("Expand phase for ID(%s) already done", id)
        else:
            logger.debug("NO Expand phase for ID(%s)", id)

        if has_contract_phase(a) and id not in appliedAlready.get(CONTRACT_OPERATION, []):
            needingContract.append((id, os.path.join(a, CONTRACT)))
        elif id not in appliedAlready.get(CONTRACT_OPERATION, []):
            logger.debug("Contract phase for ID(%s) already done", id)
        else:
            logger.debug("NO Contract phase for ID(%s)", id)

    # Ensure we execute the smallers first
    needingExpand.sort(key=lambda tup: tup[0])

    # Ensure that we execute the highers first
    needingContract.sort(key=lambda tup: tup[0], reverse=True)
    return needingExpand, needingContract

def _calculate_revert_delta():
    ''' Calculates what needs to be reverted from the current upgrade
    '''
    _connection, cursor = get_psql_connection()
    if not _upgrade_tracking_exists(cursor):
        logger.warn("There is no upgrade tracking table, returning empty list!")
        return []
    appliedAlready = _find_currently_applied_sql(cursor)
    available = find_sql_dirs()

    needingRevert = []
    for (id, a) in available.items():
        if has_revert_phase(a) and id in appliedAlready.get(EXPAND_OPERATION, []):
            needingRevert.append((id, os.path.join(a, REVERT)))

    # Ensure that we execute the highers first
    needingRevert.sort(key=lambda tup: tup[0], reverse=True)
    return needingRevert

def _find_all_applied_sql():
    ''' List all applied sql version up until now. Including previous upgrades
    and current one.
    :return: All applied version per type
    :rtype: dict{"EXPAND":[int], "CONTRACT":[int]}
    '''
    _connection, cursor = get_psql_connection()

    result = _find_currently_applied_sql(cursor)
    result.update(_find_previous_upgrade_applied_sql(cursor)) #There are no duplicates

    return result

def _upgrade_tracking_exists(cursor):
    '''
    Validates if there is upgrade traking table
    :return: True if there is an upgrade tracking table else False
    :rtype: bool
    '''
    cursor.execute("select 1 from information_schema.tables where table_schema=lower('vLCM') and table_name=lower('UpgradeTracking')")
    return bool(cursor.rowcount)

def _find_currently_applied_sql(cursor):
    # Find from current upgrade table
    cursor.execute("SELECT ID, OPERATION FROM vLCM.UpgradeTracking")
    rows = cursor.fetchall()

    result = {}
    for (id, operation) in rows:
        ids = result.get(operation, [])
        ids.append(int(id))
        result[operation] = ids
    return result

def _find_previous_upgrade_applied_sql(cursor):
    # Find for previous upgrades table
    cursor.execute("SELECT ID, OPERATION FROM vLCM.VersionTracking")
    rows = cursor.fetchall()

    result = {}
    for (id, operation) in rows:
        ids = result.get(operation, [])
        ids.append(int(id))
        result[operation] = ids
    return result

def find_sql_dirs():
    ''' List all sql directories that are known to this bundle.
    :return: All available sql folders.
    :rtype: dict(int:str)
    '''
    folderList = os.listdir(SQL_FOLDER)

    result = {}
    for name in folderList:
        path = os.path.join(SQL_FOLDER, name)
        if os.path.isdir(path):
            result[int(name)] = path
    return result

def has_expand_phase(path):
    ''' Indicates if there is expand.sql present in the table
    '''
    return os.path.isfile(os.path.join(path, EXPAND))

def has_contract_phase(path):
    ''' Indicates if there is contract.sql present in the table
    '''
    return os.path.isfile(os.path.join(path, CONTRACT))

def has_revert_phase(path):
    ''' Indicates if there is revert.sql present in the table
    '''
    return os.path.isfile(os.path.join(path, REVERT))
