#!/usr/bin/env python
#
# Copyright 2016-2018, 2021 VMware, Inc.  All rights reserved. -- VMware Confidential
#

import logging
import os
import shutil
import sys
import unittest
from distutils.version import LooseVersion
from importlib import import_module
from unittest.mock import Mock

logging.basicConfig(level=logging.INFO)

# Mocks
sys.modules['l10n'] = Mock()
sys.modules['extensions'] = Mock()
sys.modules['os_utils'] = Mock()
sys.modules['patch_errors'] = Mock()
sys.modules['patch_specs'] = Mock()
sys.modules['reporting'] = Mock()
sys.modules['vcsa_utils'] = Mock()
sys.modules['fss_utils'] = Mock()

# Absolute import so that one can run the unit tests from the endpoint-b2b folder
# with PYTHONPATH=$(dirname `pwd`) python3.5 -m unittest
# import_module is used because of the dash in the module's name.
module = import_module('endpoint-b2b.patch_util')


class PatchUtilTestCase(unittest.TestCase):
   """
   Unit tests for the 'patch_util' module
   """

   def setUp(self):
      # Change the patches folder to test/patches
      module.PATCHES_PATH = 'test.patches'

   def test_get_all_patches(self):
      """
      Test getting all patch classes
      """

      # Invoke the method to test
      patches = module.get_all_patches()

      # Verify the result
      self._verify_single_test_patch(patches)

   def test_get_applicable_patches_smaller_cur_version(self):
      """
      Test getting all applicable patches.
      Current version is smaller than the patch version, so the patch will be returned.
      """

      # Invoke the method to test
      cur_version = LooseVersion("6.5.0.0")
      patches = module.get_applicable_patches(cur_version)

      # Verify the result
      self._verify_single_test_patch(patches)

   def test_get_applicable_patches_greater_cur_version(self):
      """
      Test getting all applicable patches.
      Current version is greater than the patch version, so no patches should be returned.
      """

      # Invoke the method to test
      cur_version = LooseVersion("6.5.2.0")
      patches = module.get_applicable_patches(cur_version)

      # Verify the result
      self.assertIsNotNone(patches)
      self.assertEqual(len(patches), 0)

   def test_get_applicable_patches_equal_cur_version(self):
      """
      Test getting all applicable patches.
      Current version is equal to the patch version, so no patches should be returned.
      """

      # Invoke the method to test
      cur_version = LooseVersion("6.5.1.312")
      patches = module.get_applicable_patches(cur_version)

      # Verify the result
      self.assertIsNotNone(patches)
      self.assertEqual(len(patches), 0)

   def _verify_single_test_patch(self, patches):
      self.assertIsNotNone(patches)
      self.assertEqual(len(patches), 1)
      self.assertEqual(patches[0].get_version(), LooseVersion("6.5.1.312"))
