# Copyright (c) 2021-2022 VMware, Inc. All rights reserved.
# VMware Confidential

import sys
import os
import os_utils
import logging
import shutil
import sys
import fss_utils

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import sps_utils
from os_utils import executeCommand
from extensions import extend, Hook
import vcsa_utils
logger = logging.getLogger(__name__)

# TODO: Extract all the common code between this file and __init__.py of external upgrade to a common utility script.

def getVersion(versionFile):
   if not os.path.isfile(versionFile):
      logger.info("No version file")
      return 0
   propertyDict = sps_utils.getPropertyDict(versionFile)
   if propertyDict is None:
      logger.info("No version available as per version file")
      return 0
   return int(propertyDict['SPS_PATCH_VERSION'])

SRC_VERSION_FILE_PATH = os.path.join(sps_utils.SPS_SERVER_CONFIG_PATH,
                                     'version.properties')
SRC_VERSION = getVersion(SRC_VERSION_FILE_PATH)
TARGET_VERSION_FILE_PATH = os.path.join(os.path.realpath(os.path.dirname(__file__)), 'version.properties')
TARGET_VERSION = getVersion(TARGET_VERSION_FILE_PATH)
COMMON_JAR_PATH = os.environ['VMWARE_COMMON_JARS']
MARKER_FILE_NDU_SOURCE = os.path.join(sps_utils.SPS_CONF_PATH, 'ndu_marker.template')
MARKER_FILE_NDU_DEST = os.path.join(sps_utils.SPS_CONF_PATH, 'ndu_marker')

@extend(Hook.Discovery)
def discover(ctx):
   if (TARGET_VERSION < SRC_VERSION):
      errorMsg = "SPS downgrade detected. Source version : %s, Target version : %s" %(SRC_VERSION, TARGET_VERSION)
      logger.error(errorMsg)
      raise Exception(errorMsg)
   logger.info("SPS's current version is %s and target version is %s" %(SRC_VERSION, TARGET_VERSION))
   logger.info("Creating a marker file for identification of NDU upgrade")
   # Create a marker file to identify the NDU upgrade at destination VC.
   open(MARKER_FILE_NDU_SOURCE, 'a').close()
   replicationConfig = {
      # Do not replicate conf files.
      sps_utils.SPS_CONF_PATH: None,
      MARKER_FILE_NDU_SOURCE: MARKER_FILE_NDU_DEST,
      SRC_VERSION_FILE_PATH: os.path.join(sps_utils.SPS_SERVER_CONFIG_PATH, 'version.properties.old')
   }
   return vcsa_utils.getComponentDiscoveryResult("sps", replicationConfig = replicationConfig)

@extend(Hook.Patch)
def patch(ctx):
   if fss_utils.getTargetFSS("NDU_Limited_Downtime") and not vcsa_utils.isDisruptiveUpgrade(ctx):
      logger.info("This is a case of NDU upgrade. Upgrade will take place during service startup.")
      return None
   if (SRC_VERSION == TARGET_VERSION):
      logger.info("SPS source and target version are same : %s. Hence, nothing to patch!" % (SRC_VERSION))
      return None
   logger.info("Starting SPS patch from %s to %s" %(SRC_VERSION, TARGET_VERSION))

   # Pickup the Java path
   javaCommand = sps_utils.getJavaCommand()
   logger.info("Picked up Java command as per VMWARE_JAVA_HOME as %s", javaCommand)

   # Main patch program
   spsPatch = "com.vmware.sps.PatchUpgradeMain"

   # Add required common libraries from common jars directory
   patchClassPath = sps_utils.addCommonLibs(COMMON_JAR_PATH)
   # Run the command
   logger.info("Patch command to be run with classpath : %s",(patchClassPath))
   spsPatchCommand = [javaCommand, '-Xms256m', '-Xmx512m',
                       '-Dlog4j.configurationFile=%s' % sps_utils.LOG4J2_PROP_PATH,
                       '-Djava.security.properties=%s'
                       % sps_utils.VMWARE_JAVA_SEC_PROP_PATH,
                       '-classpath', patchClassPath, spsPatch,
                       '-source_version', (str)(SRC_VERSION),
                       '-target_version', (str)(TARGET_VERSION)]
   stdout, stderr, returnCode,  = executeCommand(spsPatchCommand)
   logger.info("Return code : %s" %(returnCode))
   logger.info("Stdout : %s" %(stdout))
   logger.info("Stderr : %s" %(stderr))
   if (returnCode != 0):
      raise Exception("SPS patch failed")

   # Update the version by replacing the source version file with the target
   shutil.copyfile(TARGET_VERSION_FILE_PATH, SRC_VERSION_FILE_PATH)
   logger.info("SPS is successfully patched")