# Copyright (c) 2017-2019 VMware, Inc.  All rights reserved.
# -- VMware Confidential

""" VCDB VMODL In-Place (B2B) Upgrade Patcher

This module defines the function patchVmodlInVcdb which is called in the patch
hook defined in __init__.py. The patching algorithm bundles patches in order
to fit in the time and space constraints.

The reading from and writing to VCDB are implemented with psycopg2.
"""

import argparse
import cProfile
import logging
import math
import multiVmomi
import operator
import re
import sequencer
import sys
import time
import util

# This module should be in our site-packages dir.
# See bora/scons/package/vcenter/vcdb-vmodl-patch-zip.sc.
#
# On a vanilla vCSA deployed with nimbus-vcvadeploy it is already
# available in the PYTHONPATH so the streamline upgrade works fine.
import psycopg2  # pylint: disable=import-error

from sequencer import BranchType


def getTables(conn):
    """Return the tables of the 'vc' schema together with their string columns

    The catalog query keeps the relations of kind 'r' that live in the 'vc'
    namespace and aggregates, per table, the names of the columns typed
    varchar, char or text.

    Args:
        conn: Open DB connection providing cursor() and commit().

    Returns:
        A dict built from the (table name, aggregated column names) rows:
        { "table1": ["stringCol1", "stringCol2", ...], ... }
    """

    selectAllStringTables = """
        select c.relname, array_agg(a.attname::text)
          from pg_class c
          join pg_attribute a on c.oid = a.attrelid
          join pg_namespace n on c.relnamespace = n.oid
          where relkind ='r'
              and n.nspname = 'vc'
              and atttypid in ('varchar'::regtype,
                               'char'::regtype,
                               'text'::regtype)
          group by c.relname
          order by c.relname;
    """

    with conn.cursor() as cursor:
        cursor.execute(selectAllStringTables)
        # The commit is issued right after the execute; the rows are read
        # from the cursor only afterwards.
        conn.commit()
        rows = cursor.fetchall()

    return dict(rows)


def createSelectStatement(tableName, columns, sourceBranchType):
    """Build the SELECT returning only the rows that may need a VMODL upgrade

    A row is selected when at least one of the given columns contains a
    versionId attribute: any versionId when the source is BranchType.main,
    only versionId="x" otherwise (product branch).

    Args:
        tableName: Name of the table to read from.
        columns: Names of the string columns to fetch and to filter on.
        sourceBranchType: BranchType of the source of the upgrade.

    Returns:
        A (statement, values) tuple; the values list is always empty.
    """

    # The percent signs of the LIKE patterns are doubled - presumably
    # because the statement goes through the DB driver %-style parameter
    # substitution together with the (empty) values list; TODO confirm.
    if sourceBranchType == BranchType.main:
        likeTemplate = "{0} LIKE '%%versionId=\"%%'"
    else:  # sourceBranchType == BranchType.productBranch
        likeTemplate = "{0} LIKE '%%versionId=\"x\"%%'"

    filters = [likeTemplate.format(column) for column in columns]

    selectedColumns = ", ".join(columns)
    whereClause = " OR ".join(filters)
    statement = "SELECT {0} FROM {1} WHERE {2};".format(
        selectedColumns, tableName, whereClause)

    return statement, []


class UpdateStats:
    """Counters of the text examined and of the VMODLs translated

    All the counters are integers starting from zero:
      * charsProcessed - length of the values recognized as VMODL
      * charsIgnored - length of the string values not recognized as VMODL
      * vmodlsTotal - number of the translated VMODLs
      * vmodlsModified - number of the VMODLs changed by the translation
    """

    def __init__(self):
        self.charsProcessed = self.charsIgnored = 0
        self.vmodlsTotal = self.vmodlsModified = 0

    def AccumulateUpdateStats(self, updateStats):
        """Add every counter of updateStats to the same counter of self"""
        for counter in ("charsProcessed", "charsIgnored",
                        "vmodlsTotal", "vmodlsModified"):
            summed = getattr(self, counter) + getattr(updateStats, counter)
            setattr(self, counter, summed)


def generateUpdates(record, columns, patchSeq, logger, updateStats):
    """Compute the column updates that patch the VMODLs of a single row

    Every string value of the record matching the versionId="<word>" pattern
    is passed through patchSeq.TranslateText(). The columns whose translated
    text differs from the stored one end up in the result. Non-string values
    are skipped; the other strings are only counted as ignored.

    Args:
        record: The values of one row, positioned like columns.
        columns: The column names for the positions in record.
        patchSeq: Object providing TranslateText(text).
        logger: Logger used to report the offending XML on failure.
        updateStats: Object whose charsProcessed, charsIgnored, vmodlsTotal
            and vmodlsModified counters are incremented in place.

    Returns:
        A dict { "column1": "newValue1", ... } holding only the changed
        columns; empty when the row needs no update.

    Raises:
        Anything raised while translating - it is logged as critical along
        with the columns and the record, then re-raised unchanged.
    """

    # NOTE(review): match() is anchored at the start and '.' stops at a
    # newline, so only the first line of the text is inspected - confirm
    # that the VMODLs are stored with versionId on their first line.
    looksLikeVmodl = re.compile(r'.*versionId="\w+".*').match

    changes = {}

    for position, text in enumerate(record):
        if not isinstance(text, str):
            continue

        if not looksLikeVmodl(text):
            updateStats.charsIgnored += len(text)
            continue

        updateStats.charsProcessed += len(text)
        try:
            translated = patchSeq.TranslateText(text)
            updateStats.vmodlsTotal += 1
            if translated != text:
                changes[columns[position]] = translated
                updateStats.vmodlsModified += 1
        except BaseException:
            # Same reach as a bare except: whatever is raised gets logged
            # with its context before being propagated.
            logger.critical("[BUG] Translation failed on XML: %s", text)
            logger.critical("columns: %s", columns)
            logger.critical("record: %s", record)
            raise

    return changes


def createUpdateStatement(tableName, updates, cursorName):
    """Build the positioned UPDATE for a row and the values to bind to it

    Args:
        tableName: Name of the table to update.
        updates: Dict { "column": "newValue", ... } for the current row.
        cursorName: Name of the cursor positioned on the row to update.

    Returns:
        A (statement, values) tuple: the statement has one "%s" placeholder
        per updated column and the values list follows the same order.
    """

    assignments = ["{} = %s".format(column) for column in updates]
    values = list(updates.values())

    statement = "UPDATE {0} SET {1} WHERE CURRENT OF {2};\n".format(
        tableName, ", ".join(assignments), cursorName)

    return statement, values


class TableStats(UpdateStats):
    """UpdateStats of one table plus its updated rows and traversal time"""

    def __init__(self, tableName):
        """Start from zeroed counters for the table called tableName"""
        super().__init__()
        self.rowsUpdated = 0
        self.tableTime = TimeDelta()
        self.tableName = tableName

    def AccumulateTableStats(self, tableStats):
        """Add the counters, rows and time of tableStats to self

        The tableName of self is left untouched.
        """
        self.AccumulateUpdateStats(tableStats)

        self.rowsUpdated += tableStats.rowsUpdated

        elapsed = self.tableTime
        elapsed.value += tableStats.tableTime.value


class TotalStats(TableStats):
    """Stats summed over all the tables plus the time spent per category

    It is a TableStats named "Total" extended with one TimeDelta for each of
    the select, read, transform, create update statement, update and commit
    durations.
    """

    def __init__(self):
        super().__init__("Total")
        for duration in ("selectDuration",
                         "readDuration",
                         "transformDuration",
                         "createUpdateStatementDuration",
                         "updateDuration",
                         "commitDuration"):
            setattr(self, duration, TimeDelta())


class TimeDelta:
    """Mutable holder of an accumulated duration in seconds"""

    def __init__(self):
        self.value = 0.0

    def AsInt(self):
        """Return the duration truncated toward zero to an int"""
        wholeSeconds = math.trunc(self.value)
        return int(wholeSeconds)


class TimeDeltaAccumulator:
    """Context manager adding the time spent in its block to a TimeDelta

    The time is measured with time.monotonic() and is added to
    timeDelta.value on exit, also when the block raises. Exceptions are not
    suppressed.
    """

    def __init__(self, timeDelta):
        self.timeDelta = timeDelta
        self.start = 0.0

    def __enter__(self):
        self.start = time.monotonic()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.monotonic() - self.start
        self.timeDelta.value += elapsed


class DbPatcher:
    """Patches the VMODLs stored in the string columns of the DB tables

    Each table is read row by row through a named cursor. The rows that need
    patching are written back with positioned updates
    (UPDATE ... WHERE CURRENT OF <cursor>) and the changes of a table are
    committed once the table has been traversed. Timings and counters are
    collected per table and in total and are logged at the end of PatchDb.
    """

    def __init__(self, conn, patchSeq, logger):
        """Remember the collaborators and start with empty statistics

        Args:
            conn: Open DB connection used for both reading and writing.
            patchSeq: Patch sequencer; provides GetSourceBranchType() and is
                handed to generateUpdates() to translate the VMODLs.
            logger: Logger receiving the progress, statistics and debug
                messages.
        """
        self.conn = conn
        self.patchSeq = patchSeq
        self.logger = logger
        self.totalStats = TotalStats()
        self.perTableStats = []  # One TableStats per traversed table.
        # time.monotonic() reading taken at the last progress log (or at the
        # start of the table being patched).
        self.lastProgressUpdate = 0.0
        self.sourceBranchType = patchSeq.GetSourceBranchType()

    def TableNeedsUpgrade(self, tableName):
        """Return whether the table has to be traversed

        Every table is traversed when the source branch type is main;
        otherwise the tables named "vpx_event_*" are left out.
        """
        if self.sourceBranchType == BranchType.main:
            return True
        # else self.sourceBranchType == BranchType.productBranch
        return not tableName.startswith("vpx_event_")

    def PatchDb(self, tables):
        """Patch the VMODLs in all the given tables and log the statistics

        Args:
            tables: Dict mapping a table name to the list of its string
                columns, i.e. what getTables() returns.

        Raises:
            AssertionError: when VMODLs were modified but no row was updated
                or the other way around.
        """
        for tableName, columns in tables.items():
            if self.TableNeedsUpgrade(tableName):
                self.MeasurePatchTable(tableName, columns)
            else:
                self.logger.info("Table %s skipped - no upgrade needed",
                                 tableName)
        logStats(self.totalStats, self.perTableStats, self.logger)
        # Sanity check: modified VMODLs and updated rows must come together.
        assert (self.totalStats.vmodlsModified != 0) == \
               (self.totalStats.rowsUpdated != 0)  # wrong patching or update

    def SelectVMODLs(self, tableName, columns, reader):
        """Execute on reader the SELECT of the rows that may need patching

        The execution time is added to totalStats.selectDuration.

        Args:
            tableName: Name of the table to read.
            columns: The string columns of the table.
            reader: The cursor to execute the statement on.
        """
        statement, values = createSelectStatement(tableName, columns,
                                                  self.sourceBranchType)

        # The result of the cursor.mogrify will be discarded if the logging
        # level is not high enough.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug("Executing SQL Select Statement: %s",
                              reader.mogrify(statement, values))

        with TimeDeltaAccumulator(self.totalStats.selectDuration):
            # Side effect: BEGIN TRANSACTION
            reader.execute(statement, values)

    def UpdateVMODL(self, tableName, updates, cursorName, tableStats):
        """Write the updates to the row the named cursor is positioned on

        The statement is executed on a short-lived cursor of its own. The
        time to build and to execute it is added to the matching totalStats
        durations. tableStats.rowsUpdated is incremented and a progress line
        is logged when more than 60 seconds have passed since the last one.

        Args:
            tableName: Name of the table being patched.
            updates: Dict { "column": "newValue", ... } for the current row.
            cursorName: Name of the cursor reading the table.
            tableStats: The TableStats of the table being patched.
        """
        self.logger.debug("Creating SQL Update Statement")

        with TimeDeltaAccumulator(
                self.totalStats.createUpdateStatementDuration):
            statement, values = createUpdateStatement(tableName, updates,
                                                      cursorName)

        with self.conn.cursor() as writer:
            # The result of the cursor.mogrify will be discarded
            # if the logging level is not high enough.
            if self.logger.isEnabledFor(logging.DEBUG):
                self.logger.debug("Executing SQL Update Statement: %s",
                                  writer.mogrify(statement, values))

            with TimeDeltaAccumulator(self.totalStats.updateDuration):
                writer.execute(statement, values)

        tableStats.rowsUpdated += 1
        # Report the progress at most once per minute.
        if time.monotonic() - self.lastProgressUpdate > 60:
            logTableStats("Progress ", tableStats, self.logger)
            self.lastProgressUpdate = time.monotonic()

    def MeasurePatchTable(self, tableName, columns):
        """Patch one table, time it and fold its stats into the totals

        The TableStats of the table is appended to perTableStats before the
        patching starts; it is accumulated into totalStats only after the
        patching has completed.
        """
        self.logger.info("Traversing %s Start", tableName)
        tableStats = TableStats(tableName)
        self.perTableStats.append(tableStats)

        with TimeDeltaAccumulator(tableStats.tableTime):
            self.PatchTable(tableName, columns, tableStats)

        logTableStats("Traversing Done", tableStats, self.logger)
        self.totalStats.AccumulateTableStats(tableStats)

    def PatchTable(self, tableName, columns, tableStats):
        """Patch the VMODLs in a single table

        The candidate rows are read through a cursor named after the table,
        translated with generateUpdates() and, when anything changed,
        written back with UpdateVMODL(). The transaction is committed after
        the last row.

        Args:
            tableName: Name of the table to patch.
            columns: The string columns of the table.
            tableStats: The TableStats to update while patching.
        """

        self.lastProgressUpdate = time.monotonic()
        logger = self.logger
        conn = self.conn
        totalStats = self.totalStats

        # The cursor needs a name because the update statements refer to it
        # in their WHERE CURRENT OF clause.
        cursorName = "c_{}".format(tableName)

        with conn.cursor(cursorName) as reader:
            self.SelectVMODLs(tableName, columns, reader)

            # Presumably required so that the cursor is positioned exactly
            # on the row being patched when WHERE CURRENT OF is executed -
            # TODO confirm.
            reader.itersize = 1  # Fetch one row from DB at a time.

            # The read duration is the time spent waiting for the iterator
            # to hand over the next record.
            readStart = time.monotonic()
            for record in reader:
                totalStats.readDuration.value += time.monotonic() - readStart

                logger.debug("Executing VMODL Transformation Routine")
                # TRANSFORM
                with TimeDeltaAccumulator(totalStats.transformDuration):
                    updates = generateUpdates(record, columns, self.patchSeq,
                                              logger, tableStats)
                # End TRANSFORM

                if updates:
                    self.UpdateVMODL(tableName, updates,
                                     cursorName, tableStats)

                readStart = time.monotonic()

        logger.debug("Executing SQL Commit Statement")

        with TimeDeltaAccumulator(totalStats.commitDuration):
            conn.commit()  # END TRANSACTION

        # if updateCount > 0: autocommit, VACUUM (VERBOSE, ANALYZE) tableName


def toMega(x):
    """Return x expressed in millions, truncated toward zero to an int"""
    millions = x / 1000000
    return int(math.trunc(millions))


def logTableStats(prefix, tableStats, logger):
    """Log at info level a one-line summary of the stats of a table

    The line holds the prefix, the table name, the elapsed time, the
    modified/total VMODLs, the updated rows and the processed and ignored
    characters in (truncated) millions.
    """
    template = ("{} {} -"
                " time:{:.3f}s;"
                " vmodls:{:,}/{:,};"
                " rows:{:,};"
                " processed: {:,} Mutf8;"
                " ignored: {:,} Mutf8")

    fields = (
        prefix,
        tableStats.tableName,
        tableStats.tableTime.value,
        tableStats.vmodlsModified,
        tableStats.vmodlsTotal,
        tableStats.rowsUpdated,
        toMega(tableStats.charsProcessed),
        toMega(tableStats.charsIgnored),
    )

    logger.info(template.format(*fields))


def logParsedTableStats(tableStats, logger):
    """Log the stats of a table, one "| name: integer: readable" per line

    The name of every line embeds the table name. All the lines are
    formatted first and only then logged at info level.
    """
    name = tableStats.tableName

    # Lines whose integer is repeated with thousands separators.
    singleValued = (
        ("| Table_{0}_Time: {1:d}: {1:,}s",
         tableStats.tableTime.AsInt()),
        ("| Table_{0}_VmodlsTotal: {1:d}: {1:,}",
         tableStats.vmodlsTotal),
        ("| Table_{0}_VmodlsModified: {1:d}: {1:,}",
         tableStats.vmodlsModified),
        ("| Table_{0}_RowsUpdated: {1:d}: {1:,}",
         tableStats.rowsUpdated),
    )
    # Lines whose integer is followed by its value in millions.
    charCounts = (
        ("| Table_{}_utf8CharsProcessed: {:d}: {:,}MB",
         tableStats.charsProcessed),
        ("| Table_{}_utf8CharsIgnored: {:d}: {:,}MB",
         tableStats.charsIgnored),
    )

    parsedOutput = [template.format(name, number)
                    for template, number in singleValued]
    parsedOutput.extend(template.format(name, chars, toMega(chars))
                        for template, chars in charCounts)

    for line in parsedOutput:
        logger.info(line)


def logStats(totalStats, perTableStats, logger):
    """Log the total statistics and the ones of the top consuming tables

    The tables are ranked by descending time, then updated rows, processed
    and ignored characters.

    Args:
        totalStats: The TotalStats accumulated over all the tables.
        perTableStats: List with the TableStats of every traversed table.
        logger: Logger receiving the output at info level.
    """

    def rankingKey(stats):
        return (stats.tableTime.value, stats.rowsUpdated,
                stats.charsProcessed, stats.charsIgnored)

    def sawText(stats):
        return stats.charsProcessed != 0 or stats.charsIgnored != 0

    ranked = sorted(perTableStats, key=rankingKey, reverse=True)

    # NOTE(review): the ranking is led by the time, so a table without any
    # text that ranks first silences both per-table reports below - confirm
    # this is intended.
    if ranked and sawText(ranked[0]):
        for tableStats in filter(sawText, ranked):
            logTableStats("Top consumers", tableStats, logger)

    # Log Stats relevant to Bug 2087250.
    # Be careful when changing the logged output. The following logging is
    # parsed by script (vc-B2B-cost-to-json.bash) storing B2B result data.
    # Details: "| name: integer_value: anything" is parsed as:
    #  * key - "name"
    #  * value - integer_value
    #  * ": anything" is ignored
    logger.info("TOTAL Time in seconds spent per category:")

    durations = (
        ("| select: {:d}", totalStats.selectDuration),
        ("| read: {:d}", totalStats.readDuration),
        ("| transform: {:d}", totalStats.transformDuration),
        ("| create_update_statement: {:d}",
         totalStats.createUpdateStatementDuration),
        ("| update: {:d}", totalStats.updateDuration),
        ("| commit: {:d}", totalStats.commitDuration),
    )
    counters = (
        ("| Vmodls: {0:d}: {0:,}", totalStats.vmodlsTotal),
        ("| VmodlsModified: {0:d}: {0:,}", totalStats.vmodlsModified),
        ("| RowsUpdated: {0:d}: {0:,}", totalStats.rowsUpdated),
    )
    charCounts = (
        ("| utf8CharsProcessed: {:d}: {:,}Mutf8", totalStats.charsProcessed),
        ("| utf8CharsIgnored: {:d}: {:,}Mutf8", totalStats.charsIgnored),
    )

    parsedOutput = [template.format(duration.AsInt())
                    for template, duration in durations]
    parsedOutput += [template.format(count)
                     for template, count in counters]
    parsedOutput += [template.format(chars, toMega(chars))
                     for template, chars in charCounts]

    for line in parsedOutput:
        logger.info(line)

    # Report at most the 10 top ranked tables, stopping at the first one
    # that saw no text at all.
    for tableStats in ranked[:10]:
        if not sawText(tableStats):
            break
        logParsedTableStats(tableStats, logger)


def patchVmodlInVcdb(importSrc, importDst, logger, _):
    """Translate all the VMODLs from src to dst version using pyVmomi

    When the patch sequencer reports that an upgrade is needed, VCDB is
    opened, the tables with string columns are collected and patched with
    DbPatcher, and the notices reported by the DB are logged. The DB
    connection is closed in every case, also when the patching fails.

    Args:
        importSrc: Importer of the source pyVmomi, handed to the sequencer.
        importDst: Importer of the destination pyVmomi, handed to the
            sequencer.
        logger: Logger receiving the markers, progress and statistics.
        _: Ignored.

    Returns:
        0 in all the cases; failures are reported by raising.
    """

    # Start marker for time measurements. See calculatePatchCost.sh.
    logger.info("VCDB VMODL Extension Patch Hook started UPGRADE")

    patchSeq = sequencer.PatchSequencer(
        importSrc, importDst,
        checkDstPyVmomiFn=lambda x: None,
        logger=logger)

    if patchSeq.IsUpgradeNeeded():
        conn = psycopg2.connect(user="postgres", dbname="VCDB")
        try:
            # The 'with' block only ends the transaction (commit on success,
            # rollback on error); it doesn't close the connection.
            with conn:
                tables = getTables(conn)
                logger.info("Tables count: %d", len(tables))

                DbPatcher(conn, patchSeq, logger).PatchDb(tables)

                logger.info("DB reported:\n%s", "\n".join(conn.notices))
        finally:
            # Close on the failure path too, otherwise the connection stays
            # open when the patching raises.
            conn.close()

    # End marker for time measurements. See calculatePatchCost.sh.
    logger.info("VCDB VMODL Extension Patch Hook finished UPGRADE")

    return 0


def patchAndMeasure(importSrc, importDst, logger, _):
    """Run patchVmodlInVcdb surrounded by measurements

    The disk usage is logged before and after the patching, and the profile
    collected while patching is logged at the end. Those will be used later
    by calculatePatchCost.sh to find the time and space used by the upgrade.
    The profile log can be used for optimizations.

    Returns:
        The result of patchVmodlInVcdb.
    """

    util.LogDiskUsage(logger)

    profiler = cProfile.Profile()

    # Only the patching itself is profiled.
    profiler.enable()
    exitCode = patchVmodlInVcdb(importSrc, importDst, logger, _)
    profiler.disable()

    util.LogDiskUsage(logger)
    util.LogProfile(logger, profiler)

    return exitCode


def ParseArgs():
    """Parse the command line arguments of the standalone patcher

    They must be enough to run the patching logic without the B2B framework.

    Returns:
        The argparse namespace with srcPyVmomi, dstPyVmomi, stageDir and
        verify_only.
    """

    parser = argparse.ArgumentParser(description="Patch VMODL in VCDB.")

    pathOptions = (
        ("--srcPyVmomi", "Path to the source pyVmomi."),
        ("--dstPyVmomi", "Path to the dest pyVmomi."),
        ("--stageDir", "Path to the patch stage dir."),
    )
    for flag, helpText in pathOptions:
        parser.add_argument(flag, help=helpText)

    parser.add_argument("--verify-only", action="store_true", default=False,
                        help="Do a verification instead of patching.")

    return parser.parse_args()


def CreateLogger():
    """Create the debug logger for the patch logic

    Its verbosity can be increased to see more detailed messages.
    The logger is configured similarly to the upgrade logger:

        See bora/install/upgrade/vcdb/test_support/libs/sdk/logging_utils.py.

    Returns:
        The root logger, after logging.basicConfig() was asked for the INFO
        level.
    """

    # Render the timestamps with time.gmtime; the converter is set on the
    # Formatter class, hence it applies to all the formatters.
    logging.Formatter.converter = time.gmtime

    logLineFormat = ("%(asctime)s,%(msecs)dZ vcdb_vmodl:Patch "
                     "%(levelname)s vcdb_vmodl %(message)s")

    rootConfig = {
        "format": logLineFormat,
        "datefmt": "%Y-%m-%d %H:%M:%S",
        "level": logging.INFO,
    }
    logging.basicConfig(**rootConfig)

    return logging.getLogger()


def Main():
    """Entry point for patching only the VMODL in VCDB

    This function assumes the mocked in patchDB.sh patching context.

    Returns:
        The result of patchAndMeasure, to be used as the exit status.
    """

    args = ParseArgs()
    logger = CreateLogger()

    srcImporter = multiVmomi.MakePyVmomiImporter(args.srcPyVmomi)
    dstImporter = multiVmomi.MakePyVmomiImporter(args.dstPyVmomi)

    return patchAndMeasure(srcImporter, dstImporter, logger, args.stageDir)


# Run the patching when executed as a script; the result of Main() becomes
# the exit status of the process.
if __name__ == "__main__":
    sys.exit(Main())
