# Copyright 2016-2020 VMware, Inc.
# All rights reserved. -- VMware Confidential
'''
VCDB b2b entry point
'''
import sys
import os
import json
import platform
import logging

from patch_specs import DiscoveryResult, PatchInfo, RequirementsResult, Requirements # pylint: disable=E0401
from patch_errors import PermanentError, UserError

from extensions import extend, Hook # pylint: disable=E0401
from l10n import msgMetadata as _T, localizedString as _ # pylint: disable=E0401
import vcsa_utils # pylint: disable=E0401
from vcsa_utils import isDisruptiveUpgrade # pylint: disable=E0401
from reporting import getProgressReporter # pylint: disable=E0401

from transport import getCommandExitCode # pylint: disable=E0401
from transport.local import LocalOperationsManager # pylint: disable=E0401
from fss_utils import getTargetFSS # pylint: disable=E0401

from vcdb.vcdb_db_utils import (get_db_user_linux, get_db_pass_symkey, patch_db, pre_check_db)

rootDir = os.path.dirname(os.path.realpath(__file__))
nduDir = os.path.join(rootDir, "ndu")
nduSQLDir = os.path.join(rootDir, "ndu/vc_sql_procs")

psql = '/opt/vmware/vpostgres/current/bin/psql'

LOGGER = logging.getLogger(__name__)

def _cleanup_process():
    cmd = [psql, '-U',  'postgres', '-d', 'VCDB', '-f', os.path.join(nduDir, "vcdb_contract.sql")]
    exitCode = getCommandExitCode(LocalOperationsManager(), cmd)
    if exitCode:
        cause = _(_T('vcdb.ndu.contract.generic.error',
                     'Failed to remove left over state from the database.'))
        raise UserError(cause=cause)

def _update_db():
    #Mandatory update of the DB for VPXD
    cmd = [psql, '-U',  'postgres', '-d', 'VCDB', '-f', os.path.join(nduSQLDir, "vcdb_db_upgrade.sql")]
    exitCode = getCommandExitCode(LocalOperationsManager(), cmd)
    if exitCode:
        cause = _(_T('vcdb.ndu.update.generic.error',
                     'Failed to update DB.'))
        raise UserError(cause=cause)
    LOGGER.info("FFS is on and Patch hook for update of DB is working")


def _run_expand():
    """Executes the vcdb expand scripts in the database"""
    cmd = [psql, '-U',  'postgres', '-d', 'VCDB', '-f', os.path.join(nduDir, "vcdb_expand.sql")]
    exitCode = getCommandExitCode(LocalOperationsManager(), cmd)
    if exitCode:
        cause = _(_T('vcdb.ndu.expand.generic.error',
                     'Failed to extend the database state.'))
        raise UserError(cause=cause)


@extend(Hook.Discovery)
def discover(ctx): # pylint: disable=W0613
   '''DiscoveryResult discover(PatchContext ctx) throw UserError'''
   discovery_result = vcsa_utils.getComponentDiscoveryResult(
      "dbconfig",
      displayName=_(_T("VCDB.displayName", "VMware vCenter Server Database")))
   discovery_result.dependentComponents.append('vcdb_vmodl_patcher')
   return discovery_result

def _do_vcdb_patching():
   LOGGER.info("Retrieving DB user...")
   vc_user = get_db_user_linux()
   LOGGER.info("DB user retrieved: %s", vc_user)

   LOGGER.info("Retrieving DB password...")
   vc_pass = get_db_pass_symkey()
   LOGGER.info("DB password retrieved: ****")

   LOGGER.info("Execute DB patch...")
   patch_db(vc_user, vc_pass)
   LOGGER.info("DB is patched.")

@extend(Hook.Requirements)
def collectRequirements(ctx):
    '''RequirementsResult collectRequirements(PatchContext sharedCtx)'''
    requirements = Requirements()
    LOGGER.info("Calculating DB disk requirements...")
    ret, storage_set, storage_core = pre_check_db()
    if ret != 0:
      LOGGER.error("VCDB B2B failed.")
      error = _(_T('vcdb.error.b2b.precheck.diskreq', "Database error while calculating disk requirements"))
      sugg_action = _(_T('vcdb.action.b2b.precheck.diskreq',
                         "See PatchRunner.log and postgres.log files."))
      raise PermanentError(cause=error, resolution=sugg_action)
    requirements = Requirements(requiredDiskSpace={'/storage/core': storage_core, '/storage/seat': storage_set})
    return RequirementsResult(requirements)

@extend(Hook.Patch)
def patch(ctx):
    '''void patch(PatchContext ctx) throw UserError'''
    if isDisruptiveUpgrade(ctx):
        progressReporter = getProgressReporter()
        progressReporter.updateProgress(0, _(_T("vcdb.patch.begin",
                                                'Extend the VCDB schema to support the newer version')))
        # VCDB expand scripts are executed during patch for discruptive upgrade
        # in order to avoid situation where services are writing to tables
        # at the same time the vcdb expand scripts are modifying them.
        _run_expand()
        progressReporter.updateProgress(40, _(_T("vcdb.patch.expand", 'Expand of DB completed')))

        _update_db()
        progressReporter.updateProgress(80, _(_T("vcdb.patch.db", 'Update of DB completed')))

        _cleanup_process()
        progressReporter.updateProgress(100, _(_T("vcdb.patch.complete",
                                                  'VCDB schema is extended to support the newer version')))
    else:
        LOGGER.info("Patch hook is not being executed, work is done in other hooks")


@extend(Hook.Expand)
def expand(ctx):
    '''void expand(Expand ctx) throw UserError'''
    if isDisruptiveUpgrade(ctx):
        LOGGER.info("Expand hook is not being executed, work is done in other hooks")
    else:
        _run_expand()


@extend(Hook.SwitchOver)
def switchover(ctx):
    '''void switchover(Switchover ctx) throw UserError'''
    _update_db()


@extend(Hook.Contract)
def contract(ctx):
    '''void contract(Contract ctx) throw UserError'''

    progressReporter = getProgressReporter()
    progressReporter.updateProgress(0, _(_T("vcdb.contract.begin", 'Contract phase started')))

    if (not isDisruptiveUpgrade(ctx)):
       _cleanup_process()
    else:
       LOGGER.info("Contract hook is not being executed, work is done in other hooks")
    progressReporter.updateProgress(100, _(_T("vcdb.contract.complete", 'Contract phase completed')))



@extend(Hook.Revert)
def revert(ctx):
    ''' void revert(Revert ctx) throw UserError'''

    progressReporter = getProgressReporter()
    progressReporter.updateProgress(0, _(_T("vcdb.revert.begin", 'Revert phase started')))

    cmd = [psql, '-U',  'postgres', '-d', 'VCDB', '-f', os.path.join(nduDir, "vcdb_revert.sql")]
    exitCode = getCommandExitCode(LocalOperationsManager(), cmd)
    if exitCode:
        cause = _(_T('vcdb.ndu.revert.generic.error',
                     'Failed to revert database to stable state.'))
        raise UserError(cause=cause)
    progressReporter.updateProgress(100, _(_T("vcdb.revert.complete", 'Revert phase completed')))
