#! /usr/bin/env python
#*******************************************************************************
# ALMA - Atacama Large Millimeter Array
# (c) Associated Universities Inc., 2009 
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
#
# "@(#) $Id: ObsCalIFDelays.py 247921 2017-08-08 15:07:45Z ahirota $"

#
# forcing global imports is due to an OSS problem
#
global copy
import copy

global CCL
import CCL.Global

global Control
import Control

global ControlExceptionsImpl
import ControlExceptionsImpl

global Observation
import Observation.DelayCalTarget
import Observation.SSRTuning
import Observation.ObsCalBase


class ObsCalDiffGainCheckout(Observation.ObsCalBase.ObsCalBase):
    """Differential-gain calibration checkout observing script.

    Builds a low-band and a high-band SpectralSpec (``bandList``), optionally
    runs pointing and ATM calibrations, and then executes DiffGainCalTarget
    sequences on one or more bright grid sources, ``repeats`` times.
    """

    # Command-line/script options: (name, type, default).
    options = [
        Observation.ObsCalBase.scriptOption("bandList", str, "3,7"),
        Observation.ObsCalBase.scriptOption("dumpDuration", float, 0.192),
        Observation.ObsCalBase.scriptOption("channelAverageDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("integrationDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("tpIntegrationDuration", float, 0.016),
        Observation.ObsCalBase.scriptOption("ElLimit", str, "20 deg"),
        Observation.ObsCalBase.scriptOption("harmonicB2B", bool, False),
        Observation.ObsCalBase.scriptOption("dgcRefIntTime", float, 60.),
        Observation.ObsCalBase.scriptOption("dgcSciIntTime", float, 72.),
        Observation.ObsCalBase.scriptOption("dgcIntCycle", float, 18.),
        Observation.ObsCalBase.scriptOption("coalesceSubscans", bool, True),
        Observation.ObsCalBase.scriptOption("scanSeq", bool, True),
        Observation.ObsCalBase.scriptOption("repeats", int, 1),
        Observation.ObsCalBase.scriptOption("corrMode", str, "FDM"),
        Observation.ObsCalBase.scriptOption("freqHigh", float, 0.),
        Observation.ObsCalBase.scriptOption("freqLow", float, 0.),
        Observation.ObsCalBase.scriptOption("doPointing", bool, False),
        Observation.ObsCalBase.scriptOption("doATM", bool, False),
        Observation.ObsCalBase.scriptOption("nSources", int, 1),
        Observation.ObsCalBase.scriptOption("enable90DegWalsh", bool, 0),
        Observation.ObsCalBase.scriptOption("bbFreqsHigh", str, ""),
        Observation.ObsCalBase.scriptOption("bbFreqsLow", str, ""),
        Observation.ObsCalBase.scriptOption("sbPrefHigh", str, ""),
        # Observation.ObsCalBase.scriptOption("sbPrefLOW", str, ""),
    ]

    @staticmethod
    def _parseFloatList(text):
        """Parse a comma-separated string into a list of floats, skipping blanks."""
        return [float(token) for token in text.split(",") if token.strip() != ""]

    def parseOptions(self):
        """Copy script options onto the instance and normalise them.

        - ``bandList`` becomes a list of ints (blank tokens are ignored).
        - For FDM on the "BL" correlator, dump, integration and
          channel-average durations are raised to at least 3 x 0.192.
        - ``args.bbFreqsHigh`` / ``args.bbFreqsLow`` are replaced in place by
          lists of floats (empty list when the option is "").
        """
        self.bandList = [int(token) for token in self.args.bandList.split(",")
                         if token.strip() != ""]
        self.dumpDuration            = self.args.dumpDuration
        self.channelAverageDuration  = self.args.channelAverageDuration
        self.integrationDuration     = self.args.integrationDuration
        self.tpIntegrationDuration   = self.args.tpIntegrationDuration
        self.elLimit                 = self.args.ElLimit
        corrType = self._array.getCorrelatorType()
        if self.args.corrMode == "FDM" and str(corrType) == "BL":
            # Minimum durations enforced for FDM on the BL correlator.
            minDuration = 0.192 * 3
            self.dumpDuration = max(self.dumpDuration, minDuration)
            self.integrationDuration = max(self.integrationDuration, minDuration)
            self.channelAverageDuration = max(self.channelAverageDuration,
                                              minDuration)
        # # 10.08 / 16 = 0.63
        # if self.corrMode == "FDM" and self.integrationDuration < 1.152:
        #     self.integrationDuration = 1.152

        # BB center frequencies for the higher band setup
        self.args.bbFreqsHigh = self._parseFloatList(self.args.bbFreqsHigh)
        # BB center frequencies for the lower band setup
        self.args.bbFreqsLow = self._parseFloatList(self.args.bbFreqsLow)

    def applySpectralAveraging(self, ss):
        """Set ``spectralAveragingFactor`` on every spectral window of ``ss``.

        Uses 16 for the BL correlator configuration and 32 for the ACA
        correlator or ACA spectrometer configuration.

        Raises:
            Exception: if ``ss`` carries none of these correlator configurations.
        """
        if ss.BLCorrelatorConfiguration:
            for bbc in ss.BLCorrelatorConfiguration.BLBaseBandConfig:
                for spw in bbc.BLSpectralWindow:
                    # 32 is not accepted by CYCLE6 software
                    spw.spectralAveragingFactor = 16
            return
        # Both ACA flavours share the same baseband/spectral-window layout.
        acaConfig = (ss.ACACorrelatorConfiguration or
                     ss.ACASpectrometerCorrelatorConfiguration)
        if not acaConfig:
            raise Exception("SpectralSpec has no BL, ACA or ACA spectrometer "
                            "correlator configuration; cannot apply "
                            "spectral averaging")
        for bbc in acaConfig.ACABaseBandConfig:
            for spw in bbc.ACASpectralWindow:
                spw.spectralAveragingFactor = 32

    def generateTunings(self):
        """Create the low- and high-band SpectralSpecs.

        Results are stored in ``self._spectralSpecLow`` and
        ``self._spectralSpecHigh``. The high-band-only options (``bbFreqsHigh``,
        ``sbPrefHigh``) are applied only to the high-band setup; the low band
        uses ``bbFreqsLow`` or, with ``harmonicB2B``, a setup derived from the
        high-band spec.

        Raises:
            ValueError: if ``bandList`` does not contain exactly two bands.
            ControlExceptionsImpl.IllegalParameterErrorExImpl: if 90-deg Walsh
                is enabled, BB frequencies are given for a band 9/10 high
                setup and ``sbPrefHigh`` is empty (execution is closed first).
        """
        from Observation.SSRTuning import generateHarmonicPhaseCalSpectralSpec
        corrType = self._array.getCorrelatorType()
        if len(self.bandList) != 2:
            raise ValueError("bandList must contain exactly two bands "
                             "(low,high), got %s" % (self.bandList,))
        bandLow, bandHigh = self.bandList
        commonKwargs = dict(
            intent="interferometry_continuum",
            # intent="calsurvey",
            corrType=corrType,
            corrMode=self.args.corrMode,
            dualMode=True,
            pol="2",
            dump=self.dumpDuration,
            channelAverage=self.channelAverageDuration,
            integration=self.integrationDuration,
            tpSampleTime=self.tpIntegrationDuration,
            # Do it science-like
            enable180DegWalsh=True,
            enable90DegWalsh=self.args.enable90DegWalsh,
        )

        # Higher-band
        kwargsHigh = dict(commonKwargs)
        if len(self.args.bbFreqsHigh) > 0:
            kwargsHigh["bbFreqs"] = self.args.bbFreqsHigh

            if self.args.enable90DegWalsh and bandHigh in [9, 10] \
                and self.args.sbPrefHigh == "":
                msg = "When 90-deg Walsh is enabled and BB center frequencies"
                msg += " are specified, 'sbPrefHigh' should be specified"
                self.logError(msg)

                ex = ControlExceptionsImpl.IllegalParameterErrorExImpl()
                ex.setData(Control.EX_USER_ERROR_MSG, msg)
                self.closeExecution(ex)
                raise ex

        if self.args.sbPrefHigh != "":
            kwargsHigh["SBPref"] = self.args.sbPrefHigh

        ssHigh = self._tuningHelper.GenerateSpectralSpec(
            frequency=self.args.freqHigh if self.args.freqHigh > 0 else None,
            band=bandHigh, **kwargsHigh
        )
        if self.args.corrMode == "FDM":
            self.applySpectralAveraging(ssHigh)
        ssHigh.name = "Band %d setup" % (bandHigh)

        # Lower-band: start again from the common settings so the high-band
        # 'bbFreqs'/'SBPref' do not leak into this setup.
        kwargsLow = dict(commonKwargs)
        if len(self.args.bbFreqsLow) > 0:
            kwargsLow["bbFreqs"] = self.args.bbFreqsLow

        if self.args.harmonicB2B:
            ssLow = generateHarmonicPhaseCalSpectralSpec(
                ssHigh, self, self._obsmode, userBandPreference=bandLow
            )
        else:
            ssLow = self._tuningHelper.GenerateSpectralSpec(
                frequency=self.args.freqLow if self.args.freqLow > 0 else None,
                band=bandLow, **kwargsLow
            )
        if self.args.corrMode == "FDM":
            self.applySpectralAveraging(ssLow)
        ssLow.name = "Band %d setup" % (bandLow)

        self._spectralSpecHigh = ssHigh
        self._spectralSpecLow = ssLow

        self.logInfo("ssHigh = %s" % (ssHigh.toDOM().toprettyxml()))

    def doATMCal(self, src, ss):
        """Run an ATM calibration on ``src`` using an ATM spec derived from ``ss``."""
        # Not among the module-level imports; import it here like doPointing does.
        import Observation.AtmCalTarget
        from Observation.SSRTuning import generateAtmSpectralSpec
        ssATM = generateAtmSpectralSpec(ss)
        # self.logInfo(ssATM.toDOM().toprettyxml())
        atm = Observation.AtmCalTarget.AtmCalTarget(src, ssATM, doHotLoad=True)
        atm.setOnlineProcessing(True)
        atm.setDataOrigin('FULL_RESOLUTION_AUTO')
        atm.setDoZero(False)
        atm.setSubscanDuration(5.76)
        atm.setIntegrationTime(1.5)
        atm.setWVRCalReduction(True)
        # Applying the results takes a while with lots of
        # antennas, so until we use online WVR, don't bother
        atm.setApplyWVR(False)

        # Automatically adjust reference position for this target.
        atm.tweakReferenceOffset()
        self.logInfo('Executing AtmCal on ' + src.sourceName + '...')
        atm.execute(self._obsmode)
        self.logInfo('Completed AtmCal on ' + src.sourceName)

    def doDiffGainCals(self):
        """Run ``doDiffGainCal`` on every calibration source, ``repeats`` times.

        On any failure the traceback is logged, execution is closed with the
        exception, and the exception is re-raised.
        """
        # Pre-bind so the error handler never hits an unbound name when the
        # failure happens before the first source is picked.
        src = None
        try:
            for _ in range(self.args.repeats):
                for src in self._calSources:
                    self.doDiffGainCal(src)
        except BaseException as ex:
            import traceback
            self.logError(traceback.format_exc())
            srcName = src.sourceName if src is not None else "<none>"
            msg = "Error executing DiffGainCal on source %s" % (srcName)
            self.logError(msg)
            self.closeExecution(ex)
            raise ex

    def doDiffGainCal(self, src):
        """Run one differential-gain calibration sequence on ``src``.

        Optionally runs ATM calibrations for both setups first. It then derives
        a subscan duration from the reference integration time, the
        internal cycle time and the [2, 6] clamp below. Next it builds a
        DiffGainCalTarget (low-band spec with the reference integration time,
        high-band spec with the science one) and executes it, either directly
        or through a ScanList when ``scanSeq`` is set.
        """
        from Observation.DiffGainCalTarget import DiffGainCalTarget
        from Observation.DelayCalTarget import DelayCalTarget
        import Observation.ScanList
        import math
        import numpy as np
        try:
            # math.gcd exists since Python 3.5; fractions.gcd was removed in 3.9.
            from math import gcd
        except ImportError:
            from fractions import gcd

        if self.args.doATM:
            self.doATMCal(src, self._spectralSpecLow)
            self.doATMCal(src, self._spectralSpecHigh)

        dgcSciIntTime = self.args.dgcSciIntTime
        dgcRefIntTime = self.args.dgcRefIntTime
        dgcIntCycle = self.args.dgcIntCycle

        # Number of cycles; the reference setup is visited one extra time.
        numCycles = math.ceil(float(dgcSciIntTime / dgcIntCycle))
        numReference = numCycles + 1

        # Only used here for its subscan-duration rounding rule.
        target = DelayCalTarget(src, self._spectralSpecLow)
        subd = target.roundSubscanDuration(dgcRefIntTime / numReference)
        cycleSubd = target.roundSubscanDuration(dgcIntCycle)

        # Greatest common divisor of the per-visit reference time and the
        # cycle time, computed on integers at 1e-7 resolution.
        scale = 1.0e7
        subd = gcd(int(round(subd * scale)), int(round(cycleSubd * scale))) / scale
        if subd < 2:
            subd = target.roundSubscanDuration(2.)
        # TEMPORAL - test non-coalescence case
        if subd > 6:
            subd = target.roundSubscanDuration(6.)

        # Round the science time to a whole number of subscans per cycle.
        sciIntTime = np.round(dgcSciIntTime / numCycles / subd) * subd * numCycles
        self.logInfo("sciIntTime=%s" % (sciIntTime))
        self.logInfo("dgcSciIntTime=%s" % (dgcSciIntTime))
        self.logInfo("numRefernce = %d" % (numReference))
        self.logInfo("subd = %.2f" % (subd))
        dgcSciIntTime = sciIntTime
        target = DiffGainCalTarget(src,
                                   self._spectralSpecLow,
                                   [self._spectralSpecHigh],
                                   SubscanDuration=subd,
                                   SubscanDurationList=[subd],
                                   IntegrationTime=dgcRefIntTime,
                                   IntegrationTimeList=[dgcSciIntTime],
                                   SetupDurationList=[dgcIntCycle],
                                   InternalCycleTime=dgcIntCycle,
                                   doScienceFirst=False)
        # ICT-19608
        from PyDataModelEnumeration import PyCalDataOrigin
        target.setDataOrigin(PyCalDataOrigin.CHANNEL_AVERAGE_CROSS)
        target.coalesceSubscans = self.args.coalesceSubscans

        scanList = Observation.ScanList.ScanList() if self.args.scanSeq else None
        self.logInfo('Executing DiffGainCal on ' + src.sourceName + '...')
        target.execute(self._obsmode, scanList=scanList)
        if scanList is not None:
            scanList.execute(self._obsmode)
        self.logInfo('Completed DiffGainCal on ' + src.sourceName)

    def doPointing(self):
        """Run a pointing calibration on the pointing/focus source (if enabled).

        Applies the result when one is found. An empty result raises, except
        on arrays whose name contains "OSS". Any failure is logged, execution
        is closed with it, and it is re-raised.
        """
        import Observation.PointingCalTarget
        if not self.args.doPointing:
            return
        self.pointingSubscanDuration = 5.76
        try:
            pointingCal = Observation.PointingCalTarget.PointingCalTarget(self._srcPointFocus, self._spectralSpecLow)
            pointingCal.setSubscanDuration(self.pointingSubscanDuration)
            pointingCal.setDataOrigin('CHANNEL_AVERAGE_CROSS')
            # if not self.pipelineFriendly:
            #     pointingCal.setDelayCalReduction(True)
            self.logInfo('Executing PointingCal on ' + self._srcPointFocus.sourceName + '...')
            pointingCal.execute(self._obsmode)
            self.logInfo('Completed PointingCal on ' + self._srcPointFocus.sourceName)
            result = pointingCal.checkResult(self._array)
            self.logInfo("Result is: %s" % str(result))
            if len(result) > 0:
                for key in list(result.keys()):
                    self.logInfo("Found solution for %s using polarization(s) %s" %
                            (key, result[key]))
                pointingCal.applyResult(self._obsmode, result)
            elif "OSS" not in self._array._arrayName:
                raise Exception("No pointing results!")
        except BaseException as ex:
            import traceback
            self.logError(traceback.format_exc())
            msg = "Error executing pointing on source %s" % self._srcPointFocus.sourceName
            self.logError(msg)
            self.closeExecution(ex)
            raise ex

    def findPointFocusSource(self, minEl=35.0, maxEl=85.0):
        """Select the pointing/focus source and the calibration sources.

        Queries the calibrator catalogue for bright grid sources within
        ``minEl``..``maxEl``. The first result becomes ``self._srcPointFocus``
        and the first ``nSources`` results become ``self._calSources``.

        Raises:
            Exception: if the catalogue query returns no source.
        """
        self.logInfo("Querying for Ponting/Focus source using getBrightGridSource()...")
        cc = self.sourceHelper.getCalibratorCatalog()
        cc.setupArrayInformation()
        # Scales with the amount of work to do, capped at 1500.
        duration = min(300 * max(self.args.nSources, 1) * self.args.repeats, 1500)
        srcList = cc.getBrightGridSources(freq=92.0e9, minEl=minEl, maxEl=maxEl,
                                          returnFieldSource=True,
                                          checkShadowing=True,
                                          duration=duration,
                                          nRequest=-1)
        if not srcList:
            raise Exception("No bright grid source found (minEl=%s, maxEl=%s)"
                            % (minEl, maxEl))
        src = srcList[0]
        self.logInfo("Ponting/Focus source will be: %s" % src.toxml())
        self._srcPointFocus = src
        self._calSources = srcList[:self.args.nSources]
        for iSource, calSrc in enumerate(self._calSources):
            self.logInfo("[Source%d] %s" % (iSource, calSrc.sourceName))

    # def doFocus(self):
    #     self.focusSubscanDuration = 5.76
    #     try:
    #         focusCal = Observation.FocusCalTarget.FocusCalTarget(
    #                 SubscanFieldSource = self._srcPointFocus,
    #                 Axis = 'Z',
    #                 SpectralSpec = self._spectralSpecLow,
    #                 DataOrigin = 'CHANNEL_AVERAGE_CROSS',
    #                 SubscanDuration = self.focusSubscanDuration,
    #                 OneWay = False,
    #                 NumPositions = 7)
    #         self.logInfo('Executing FocusCal on ' + self._srcPointFocus.sourceName + '...')
    #         focusCal.execute(self._obsmode)
    #         self.logInfo('Completed FocusCal on ' + self._srcPointFocus.sourceName)
    #         result = focusCal.checkResult(self._array)
    #         self.logInfo("Result is: %s" % str(result))
    #         if len(result) > 0:
    #             for key in list(result.keys()):
    #                 self.logInfo("Found solution for %s using polarization(s) %s" %
    #                         (key, result[key]))
    #             focusCal.applyResult(self._obsmode, result)
    #         else:
    #             if not "OSS" in self._array._arrayName:
    #                 raise Exception("No focus results!")
    #     except BaseException as ex:
    #         import traceback
    #         self.logError(traceback.format_exc())
    #         msg = "Error executing focus on source %s" % self._srcPointFocus.sourceName
    #         self.logError(msg)
    #         self.closeExecution(ex)
    #         raise ex


# Script driver: parse options, check antennas, select sources while the
# execution is being prepared, tune both bands, then point and run the
# differential-gain calibrations.
obs = ObsCalDiffGainCheckout()
obs.parseOptions()
obs.checkAntennas()
obs.startPrepareForExecution()
try:
    obs.findPointFocusSource()
except BaseException as ex:
    msg = "Error in methods run during execution/obsmode startup"
    obs.logException(msg, ex)
    obs.completePrepareForExecution()
    obs.closeExecution(ex)
    raise ex
obs.completePrepareForExecution()
obs.generateTunings()
obs.logInfo("Executing pointing and DiffGainCals...")
# Pointing is run twice back to back (each call is a no-op unless doPointing
# is set). NOTE(review): presumably the second pass refines the first
# correction -- confirm this is intentional.
obs.doPointing()
obs.doPointing()
# obs.doFocus()
# obs.doPointing()
obs.doDiffGainCals()
obs.closeExecution()
