#!/usr/bin/env python
# Copyright 2018 VMware, Inc.  All rights reserved. -- VMware Confidential

# Inventory object types and their entity type ids for which the tag
# association operation is forwarded to the vpxd vAPI endpoint.
forwarded_types_str = set(["VirtualMachine", "HostSystem"])
forwarded_motypes = set(["vm", "host"])


class TagAssociation(object):
    """
    The kv_key column of the cis_kv_keyvalue table holds the data related
    to tag associations. It is a string of the form:
    <association-type> <object-id> <tag-id> where
    association-type is one of "tag_association" or "deleted_tag_association"
    object-id is of form - urn:vmomi:HostSystem:<object-moid>
    tag-id is of form - InventoryServiceTag:<tag-uuid>:GLOBAL.
    """
    @classmethod
    def get_association(cls, assn_field):
        assn_tokens, msg = cls._get_assn_tokens(assn_field)
        if assn_tokens is not None:
            is_deleted = (assn_tokens[0] == "deleted_tag_association")
            tag, msg = cls._get_tag(assn_tokens[2])
            if tag is not None:
                return (cls(assn_field, is_deleted, assn_tokens[1], tag),
                        "")
        return (None, msg)

    def __init__(self, assn_field, is_deleted, object_id, tag):
        self._row = assn_field
        self._is_deleted = is_deleted
        self._object_id = object_id
        self._entity_type = None
        self._monum = None
        self._tag = tag

    @staticmethod
    def _get_assn_tokens(assn_field):
        if assn_field is not None and len(assn_field.strip()) > 0:
            tokens = assn_field.split(" ")
            if len(tokens) == 3:
                return (tokens, "")
        return (None, "Unrecognized row")

    @staticmethod
    def _get_tag(tag_field):
        if (":" in tag_field and
            tag_field.startswith("InventoryServiceTag")):
            tokens = tag_field.split(":")
            if len(tokens) == 3 and tokens[2] == "GLOBAL":
                return (tokens[1], "")
        return (None, "Unrecognized tag format")

    def is_vpxd_tag_assn(self, valid_entities_map):
        """
        Returns true if this tag association is owned by vpxd.
        """
        motype = None
        if self._object_id.startswith("urn:vmomi"):
            entity_tokens = self._object_id.split(":")
            if len(entity_tokens) == 4 or len(entity_tokens) == 5:
                self._entity_type = entity_tokens[2]
                tokens = entity_tokens[3].split("-")
                if len(tokens) == 2:
                    # We don't want to cast monum to an int here because
                    # tag association rows contain associations for entities
                    # with a non-int monum. Like "c7" below:
                    # urn:vmomi:ClusterComputeResource:domain-c7
                    # Since we filter out any associations that do not belong
                    # to hosts or vms and hosts and vms are guaranteed to have
                    # integer monums, we can defer the casting to int to until
                    # after we have filtered out the associations.
                    motype, self._monum = tokens

        return (not self._is_deleted and
                self._entity_type in forwarded_types_str and
                motype in forwarded_motypes and
                int(self._monum) in valid_entities_map[motype])


class VcdbUtil(object):
    _sql = {
        "seq_get_sql": "SELECT NEXTVAL('VPX_TAG_DEF_SEQ');",
        "tag_def_get_sql": "SELECT ID FROM VPX_TAG_DEF WHERE "
                           "TAG_UUID='{}' AND TYPE_ID={};",
        "tag_set_sql": "INSERT INTO VPX_TAG_DEF VALUES ({}, '{}', {});",
        "xref_set_sql": "INSERT INTO VPX_ENTITY_TAG_XREF VALUES ({}, {});",
        "entity_host_get_sql": "SELECT ID FROM VPX_ENTITY WHERE TYPE_ID=1;",
        "entity_vm_get_sql": "SELECT ID FROM VPX_ENTITY WHERE TYPE_ID=0;",
        "tagassn_get_sql": "SELECT KV_KEY FROM CIS_KV_KEYVALUE WHERE "
            "KV_PROVIDER='tagging:TagAssociations:default-scope';",
        "tagassn_del_sql": "DELETE FROM CIS_KV_KEYVALUE WHERE "
            "KV_PROVIDER='tagging:TagAssociations:default-scope' "
            "AND KV_KEY = '{}';"
    }

    def __init__(self):
        import psycopg2
        self._conn = psycopg2.connect("dbname=VCDB user=postgres")
        self._cur = self._conn.cursor()

    def _get_valid_entities(self):
        """
        Retuns a map with moid type name as keys and set of moids of that type
        as values.
        """
        valid_entities_map = {}
        self._cur.execute(self._sql["entity_host_get_sql"])
        # Results are returned as [(10,), (16,), (83,), (85,), (86,)].
        valid_entities_map["host"] = set([a[0] for a in self._cur.fetchall()])
        self._cur.execute(self._sql["entity_vm_get_sql"])
        valid_entities_map["vm"] = set([a[0] for a in self._cur.fetchall()])
        return valid_entities_map

    def _get_associations(self):
        """
        Queries the CIS_KV_KEYVALUE table for the tag association data.
        """
        self._cur.execute(self._sql["tagassn_get_sql"])
        return [a[0] for a in self._cur.fetchall()]

    def _get_seq(self, tag, type_id):
        """
        Checks if the tag uuid, tag type pair is already present in vpx_tag_def
        table. If yes, returns the id of the row. Otherwise, generates the next
        sequence for the table and returns it.
        """
        self._cur.execute(self._sql["tag_def_get_sql"].format(
            tag, type_id))
        seqs = self._cur.fetchall()
        if len(seqs) == 0:
            self._cur.execute(self._sql["seq_get_sql"])
            return (True, self._cur.fetchone()[0])
        else:
            return (False, seqs[0][0])

    def _persist_associations(self, rows):
        for ta_field in rows:
            type_id = 0 if ta_field._entity_type == "VirtualMachine" else 1
            insert, seq = self._get_seq(ta_field._tag, type_id)
            if insert:
                self._cur.execute(self._sql["tag_set_sql"].format(int(seq),
                                                                  ta_field._tag,
                                                                  type_id))
            self._cur.execute(self._sql["xref_set_sql"].
                              format(int(ta_field._monum), int(seq)))

            # Delete the row from the CIS_KV_KEYVALUE table.
            self._cur.execute(self._sql["tagassn_del_sql"].
                              format(ta_field._row))
        self._conn.commit()

    def __del__(self):
        self._conn.close()
        self._cur.close()
