#!/usr/bin/env python
#
# Copyright 2017-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
#
"""
This module contains functions for getting available patches.
"""

import glob
import inspect
import logging
import os
from operator import methodcaller

from .patches.base import PatchBase

logger = logging.getLogger(__name__)


def get_applicable_patches(patch_context):
    """
    Get all patches that should be applied during the patch phase.

    :param patch_context: Context given by the patch framework
    :type patch_context: patch_specs.PatchContext

    :return: PatchBase objects sorted by version. Always return non-null result.
    :rtype: Array of PatchBase objects
    """

    logger.info('Getting applicable patches...')
    result = [p for p in get_all_patches() if p.should_patch(patch_context)]
    logger.info('Found %d applicable patches' % len(result))
    return result


def get_applicable_patches_for_expand(patch_context):
    """
    Get all patches that should be applied during the expand phase.

    :param patch_context: Context given by the patch framework
    :type patch_context: patch_specs.PatchContext

    :return: PatchBase objects sorted by version. Always return non-null result.
    :rtype: Array of PatchBase objects
    """

    logger.info('Getting applicable patches for expand...')
    result = [patch for patch in get_all_patches() if patch.should_expand(patch_context)]
    logger.info('Found %d applicable patches for expand' % len(result))
    return result


def get_executed_patches(patch_context):
    """
    Get executed patches during current upgrade.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext

    :return: PatchBase objects sorted by version. Always return non-null result.
    :rtype: Array of PatchBase objects
    """

    logger.info('Getting the executed patches...')
    result = [p for p in get_all_patches() if p.is_patched(patch_context)]
    logger.info('Found %d executed patches' % len(result))
    return result


def get_expanded_patches(patch_context):
    """
    Get expanded patches during current upgrade.

    :param patch_context: Context given by the patch framework
    :type patch_context: PatchContext

    :return: PatchBase objects sorted by version. Always return non-null result.
    :rtype: Array of PatchBase objects
    """

    logger.info('Getting the expanded patches...')
    result = [patch for patch in get_all_patches() if patch.is_expanded(patch_context)]
    logger.info('Found %d expanded patches' % len(result))
    return result


def get_all_patches():
    """
    Get the list of all patches available in the './patches' folder.

    :return: PatchBase objects sorted by version. Always return non-null result
    :rtype: Array of PatchBase objects
    """

    result = []
    logger.info('Getting all available patches...')

    # Get all 'patch_XX' modules in the 'patches' package.
    module_names = _get_patch_module_names()
    if not module_names:
        logger.info('No patch modules found')
        return result

    patches_module = __import__('patches', globals(), locals(),
                                fromlist=module_names, level=1)
    for module_name in module_names:
        # Get the module object.
        module = getattr(patches_module, module_name)

        # Find a subclass of PatchBase in the given module.
        clazz = _find_patch_class(module)
        if not clazz:
            logger.info('Patch class not found in module %s' % module_name)
            continue

        # Instantiate the class and add it to the result.
        patch = clazz()
        result.append(patch)

    # Sort the result by patch version
    result = sorted(result, key=methodcaller('get_version'))
    logger.info(
        'All patches retrieved successfully. Found %d patches.' % len(result))
    return result


def _get_patch_module_names():
    """
    :return: Get all ./patches/patch_XX.py files and return their names (without the extensions).
    :rtype: List of strings
    """

    module_names = []

    logger.info('Getting patch module names...')

    path = os.path.dirname(os.path.realpath(__file__))
    patch_files = glob.glob(os.path.join(path, 'patches', 'patch_*.py'))
    for patch_file in patch_files:
        patch_file = os.path.basename(patch_file)
        patch_file = os.path.splitext(patch_file)[0]
        module_names.append(patch_file)
        logger.info('Module found: %s' % patch_file)

    return module_names


def _find_patch_class(module):
    """
    Find a subclass of PatchBase in the given 'module'.

    :param module: Module to inspect
    :type module: Python module

    :return: Class that is a subclass of PatchBase within the specified module.
             Will return 'None' if there is no such class.
    :rtype: class
    """

    members = inspect.getmembers(module)
    for name, member in members:
        if not inspect.isclass(member):
            continue

        if issubclass(member, PatchBase):
            return member

    return None
