#!/usr/bin/env python
#
# Copyright 2016-2018 VMware, Inc.  All rights reserved. -- VMware Confidential
#
"""
This module contains functions for getting available patches.
Copied from NGC with slight modifications:
//depot/vsphere-client-modules/
   vmkernel-main/assembler/patch/vsphere_client/patch_util.py
"""

import glob
import inspect
import logging
import os
from operator import methodcaller

from .patches.base import PatchBase

logger = logging.getLogger(__name__)

PATCHES_PATH = 'patches'

def get_applicable_patches(cur_version):
   """
   Get all patches with version higher than the specified version.

   :param cur_version: Version to compare with.
   :type cur_version: LooseVersion

   :return: PatchBase objects sorted by version. Always return non-null result.
   :rtype: Array of PatchBase objects
   """

   logger.info('Getting applicable patches for version %s', cur_version)
   result = [p for p in get_all_patches() if p.get_version() > cur_version]
   logger.info('Found %d applicable patches', len(result))
   return result


def get_all_patches():
   """
   Get the list of all patches available in the './patches' folder.

   :return: PatchBase objects sorted by version. Always return non-null result
   :rtype: Array of PatchBase objects
   """

   result = []
   logger.info('Getting all available patches...')

   # Get all 'patch_*' modules in the 'patches' package.
   module_names = _get_patch_module_names()
   if not module_names:
      logger.info('No patch modules found')
      return result

   patches_module = __import__(PATCHES_PATH, globals(), locals(), fromlist=module_names, level=1)
   for module_name in module_names:
      # Get the module object.
      module = getattr(patches_module, module_name)

      # Find a subclass of PatchBase in the given module.
      clazz = _find_patch_class(module)
      logger.info('Patch class found in module %s: %s', module_name, clazz)
      if not clazz:
         logger.info('Patch class not found in module %s', module_name)
         continue

      # Instantiate the class and add it to the result.
      patch = clazz()
      result.append(patch)

   # Sort the result by patch version
   result = sorted(result, key=methodcaller('get_version'))
   logger.info('All patches retrieved successfully. Found %d patches.', len(result))
   return result


def _get_patch_module_names():
   """
   :return: Get all ./patches/patch_*.py files and return their names (without the extension).
   :rtype: List of strings
   """

   module_names = []

   logger.info('Getting patch module names...')

   current_dir = os.path.dirname(os.path.realpath(__file__))
   patches_dir = os.path.join(current_dir, PATCHES_PATH.replace('.', '/'))
   patch_files = glob.glob(os.path.join(patches_dir, 'patch_*.py'))
   for patch_file in patch_files:
      patch_file = os.path.basename(patch_file)
      patch_file = os.path.splitext(patch_file)[0]
      module_names.append(patch_file)
      logger.info('Module found: %s', patch_file)

   return module_names


def _find_patch_class(module):
   """
   Find a subclass of PatchBase in the given 'module'.

   :param module: Module to inspect
   :type module: Python module

   :return: Class that is a subclass of PatchBase within the specified module.
            Will return 'None' if there is no such class.
   :rtype: class
   """

   members = inspect.getmembers(module)
   for name, member in members:
      if not inspect.isclass(member):
         continue

      if member == PatchBase:
         continue

      if issubclass(member, PatchBase):
         return member

   return None
