#! /usr/bin/env python
#*******************************************************************************
# ALMA - Atacama Large Millimeter Array
# (c) Associated Universities Inc., 2009 
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
#
# "@(#) $Id: ObsCalPhaseCheck.py 247921 2017-08-08 15:07:45Z ahirota $"

#
# forcing global imports is due to an OSS problem
#
global CCL
import CCL.Global

global Control
import Control

global ControlExceptionsImpl
import ControlExceptionsImpl

global Observation
import Observation.ObsCalBase
import Observation.AtmCalTarget
import Observation.DelayCalTarget
import Observation.PointingCalTarget
import Observation.FocusCalTarget
import Observation.PhaseCalTarget
import Observation.SBRatioCalTarget


class ObsCalPhaseCheck(Observation.ObsCalBase.ObsCalBase):
    """Calibration script class: pointing, focus, AtmCal, sideband-ratio and
    bandpass-type "delay" scans, all on the single source held in
    ``self._srcPointFocus``.

    The individual ``do*`` methods share one failure policy: any exception is
    printed, logged, passed to ``closeExecution()`` and then re-raised.

    NOTE(review): ``self._srcPointFocus`` and ``self.haveOperationalWVR`` are
    never assigned in this class; presumably the base class sets them (e.g. in
    findPointFocusSource() / checkForOperationalWVR()) -- confirm.
    NOTE(review): ``Observation.SSRTuning`` and ``Observation.BandpassCalTarget``
    are used below but are not imported explicitly by this file; presumably
    another Observation module loads them -- confirm.
    """

    options = [
        Observation.ObsCalBase.scriptOption("pointFocusBand", int, 3),
        Observation.ObsCalBase.scriptOption("bandList", str, "3"),
        Observation.ObsCalBase.scriptOption("doPoint", bool, True),
        Observation.ObsCalBase.scriptOption("doFocus", bool, False),
        Observation.ObsCalBase.scriptOption("doSBR",   bool, True),
        Observation.ObsCalBase.scriptOption("doATM",   bool, True),
        Observation.ObsCalBase.scriptOption("corrMode", str, "FDM"),
        Observation.ObsCalBase.scriptOption("dumpDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("channelAverageDuration", float, 0.576),
        # ACA requires more than about 8.5 sec in FDM mode for 16-ant array
        Observation.ObsCalBase.scriptOption("integrationDuration", float, 9.216),
        Observation.ObsCalBase.scriptOption("tpIntegrationDuration", float, 0.016),
        Observation.ObsCalBase.scriptOption("numDelayCals", int, 1),
        Observation.ObsCalBase.scriptOption("DelaySubscanDuration", float, 120.0),
        Observation.ObsCalBase.scriptOption("PointingSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("AtmSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("FocusSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("polarization", str, "4"),
        Observation.ObsCalBase.scriptOption("ElLimit", str, "20 deg")
    ]

    def __init__(self):
        """Initialise the base class and the per-execution state."""
        Observation.ObsCalBase.ObsCalBase.__init__(self)
        # Source used by every calibration below; None until one is selected.
        self._srcPointFocus = None
        # Toggled by orderedSpecs() on each call.
        self._reverseSpecs = False
        # Last WVR result retrieved by doAtmCals().
        self.WVRResult = None
        self.expectWVRResults = True

    def parseOptions(self):
        """Copy the script options from ``self.args`` onto the instance and
        parse the comma-separated ``bandList`` option into ``self.bandList``.

        Raises:
            ValueError: a ``bandList`` entry is not an integer.
            Exception: a ``bandList`` entry is outside the range 1..10.
        """
        args = self.args
        self.pointFocusBand          = args.pointFocusBand
        bandStr                      = args.bandList
        self.shouldPoint             = args.doPoint
        self.shouldFocus             = args.doFocus
        self.shouldSBR               = args.doSBR
        self.shouldATM               = args.doATM
        self.corrMode                = args.corrMode
        self.dumpDuration            = args.dumpDuration
        self.channelAverageDuration  = args.channelAverageDuration
        self.integrationDuration     = args.integrationDuration
        self.tpIntegrationDuration   = args.tpIntegrationDuration
        self.numDelayCals            = args.numDelayCals
        self.delaySubscanDuration    = args.DelaySubscanDuration
        self.pointingSubscanDuration = args.PointingSubscanDuration
        self.atmSubscanDuration      = args.AtmSubscanDuration
        self.focusSubscanDuration    = args.FocusSubscanDuration
        self.polarization            = args.polarization
        self.elLimit                 = args.ElLimit
        self.bandList = []
        for s in bandStr.split(','):
            try:
                n = int(s)
            except ValueError:
                # Same exception type int() raises, but naming the offending
                # entry instead of the generic "invalid literal" text.
                raise ValueError("Invalid band number in band list: '%s'" % s)
            if n < 1 or n > 10:
                raise Exception("Invalid band number in band list: '%s'" % s)
            self.bandList.append(n)
        self.logInfo("Band list: %s" % str(self.bandList))

    def setTelCalParams(self):
        """Set the TelCal calibration parameters used by this script.

        ``ontheflyWVRcorrection`` follows ``self.haveOperationalWVR``;
        ``spectrum`` is enabled and ``binningFactor`` is set to 2.
        """
        self.logInfo("Setting TelCal parameters ontheflyWVRcorrection=%s spectrum=True" % (str(self.haveOperationalWVR)))
        tcParameters = self.getTelCalParams()
        tcParameters.setCalibParameter('ontheflyWVRcorrection', self.haveOperationalWVR)
        # spectrum is to enable channel-by-channel bandpass result. Also binningFactor can be specified if desired.
        tcParameters.setCalibParameter('spectrum', True)
        # ICT-15218
        tcParameters.setCalibParameter("binningFactor", 2)

    def generateTunings(self):
        """Create the spectral specs used by the calibrations.

        Sets ``self._pointFocusSpectralSpec`` (continuum spec on
        ``pointFocusBand``) and ``self._calSpectralSpecs`` (one "calsurvey"
        spec per entry of ``self.bandList``, in list order).
        """
        corrType = self._array.getCorrelatorType()
        self._pointFocusSpectralSpec = self._tuningHelper.GenerateSpectralSpec(
                band = self.pointFocusBand,
                intent = "interferometry_continuum",
                corrType = corrType,
                dualMode = True,
                # durations as in science pointing case (SSRTuning.overridePointingSpectralSpecs)
                dump = 1.008,
                channelAverage = 1.008,
                integration = 2.016,
                tpSampleTime = self.tpIntegrationDuration)
        self._pointFocusSpectralSpec.name = "Band %d pointing/focus" % self.pointFocusBand
        self._calSpectralSpecs = []
        for band in self.bandList:
            ss = self._tuningHelper.GenerateSpectralSpec(
                            band = band,
                            intent = "calsurvey",
                            corrType = corrType,
                            corrMode = self.corrMode,
                            # 180deg is to be more of a science-like test -- more possible things to go wrong!
                            enable180DegWalsh = True,
                            dualMode = True,
                            pol = self.polarization,
                            dump = self.dumpDuration,
                            channelAverage = self.channelAverageDuration,
                            integration = self.integrationDuration,
                            tpSampleTime = self.tpIntegrationDuration)
            ss.name = "Band %d delay" % band
            self._calSpectralSpecs.append(ss)

    def orderedSpecs(self):
        """Return the calibration spectral specs, alternating the order on
        every call (forward, reversed, forward, ...).

        The forward result is ``self._calSpectralSpecs`` itself (not a copy);
        the reversed result is a new list.
        """
        ret = self._calSpectralSpecs
        if self._reverseSpecs:
            ret = list(reversed(self._calSpectralSpecs))
        self._reverseSpecs = not self._reverseSpecs
        return ret

    def _failExecution(self, ex, what):
        """Print *ex*, log that *what* failed on the current source and close
        the execution with *ex*. The caller is expected to re-raise.

        The source name is looked up defensively: if no source has been
        selected yet, a placeholder is logged instead of raising a second
        exception that would hide *ex* and skip closeExecution().
        """
        print(ex)
        sourceName = getattr(self._srcPointFocus, "sourceName", "<no source>")
        self.logError("Error executing %s on source %s" % (what, sourceName))
        self.closeExecution(ex)

    def _checkAndApplyResult(self, calTarget, what):
        """Fetch the result of an executed pointing/focus target, log it and
        apply it.

        Raises:
            Exception: the result is empty and the array name does not contain
                "OSS" (presumably the simulated array -- confirm).
        """
        result = calTarget.checkResult(self._array)
        self.logInfo("Result is: %s" % str(result))
        if len(result) > 0:
            for key in result.keys():
                self.logInfo("Found solution for %s using polarization(s) %s" %
                        (key, result[key]))
            calTarget.applyResult(self._obsmode, result)
        elif "OSS" not in self._array._arrayName:
            raise Exception("No %s results!" % what)

    def doPointing(self):
        """Execute a pointing calibration and apply its result.

        Does nothing when the ``doPoint`` option is off. On any failure the
        execution is closed with the exception, which is then re-raised.
        """
        if not self.shouldPoint:
            return
        try:
            pointingCal = Observation.PointingCalTarget.PointingCalTarget(self._srcPointFocus, self._pointFocusSpectralSpec)
            pointingCal.setSubscanDuration(self.pointingSubscanDuration)
            pointingCal.setDataOrigin('CHANNEL_AVERAGE_CROSS')
            # For Baseline run avoid extra delay results
            pointingCal.setDelayCalReduction(False)
            self.logInfo('Executing PointingCal on ' + self._srcPointFocus.sourceName + '...')
            pointingCal.execute(self._obsmode)
            self.logInfo('Completed PointingCal on ' + self._srcPointFocus.sourceName)
            self._checkAndApplyResult(pointingCal, "pointing")
        except BaseException as ex:
            self._failExecution(ex, "pointing")
            raise ex

    def doFocus(self):
        """Execute a two-way, 7-position Z-axis focus calibration and apply
        its result.

        Does nothing when the ``doFocus`` option is off. On any failure the
        execution is closed with the exception, which is then re-raised.
        """
        if not self.shouldFocus:
            return
        try:
            focusCal = Observation.FocusCalTarget.FocusCalTarget(
                    SubscanFieldSource = self._srcPointFocus,
                    Axis = 'Z',
                    SpectralSpec = self._pointFocusSpectralSpec,
                    DataOrigin = 'CHANNEL_AVERAGE_CROSS',
                    SubscanDuration = self.focusSubscanDuration,
                    OneWay = False,
                    NumPositions = 7)
            self.logInfo('Executing FocusCal on ' + self._srcPointFocus.sourceName + '...')
            focusCal.execute(self._obsmode)
            self.logInfo('Completed FocusCal on ' + self._srcPointFocus.sourceName)
            self._checkAndApplyResult(focusCal, "focus")
        except BaseException as ex:
            self._failExecution(ex, "focus")
            raise ex

    def doAtmCals(self):
        """Execute an AtmCal (hot load and zero subscan) using a spectral spec
        derived from the first calibration spec.

        Does nothing when the ``doATM`` option is off. When an operational WVR
        is available the WVR result is retrieved into ``self.WVRResult`` as a
        sanity check only; it is not applied.

        Raises:
            Exception: the WVR result is None and the array name does not
                contain "OSS". Any failure closes the execution first.
        """
        if not self.shouldATM:
            return
        try:
            # Generated inside the guarded region so that a failure here also
            # closes the execution.
            ss = Observation.SSRTuning.generateAtmSpectralSpec(self._calSpectralSpecs[0])
            atm = Observation.AtmCalTarget.AtmCalTarget(self._srcPointFocus, ss, doHotLoad=True)
            atm.setOnlineProcessing(True)
            atm.setDataOrigin('CHANNEL_AVERAGE_AUTO')
            # because we use dual mode we should do a zero subscan to allow proper comparison
            atm.setDoZero(True)
            atm.setSubscanDuration(self.atmSubscanDuration)
            atm.setIntegrationTime(1.5)
            atm.setWVRCalReduction(self.haveOperationalWVR)
            atm.setApplyWVR(False)
            self.logInfo('Executing AtmCal on ' + self._srcPointFocus.sourceName + '...')
            atm.execute(self._obsmode)
            self.logInfo('Completed AtmCal on ' + self._srcPointFocus.sourceName)
            if not self.haveOperationalWVR:
                return
            # For now keep this as a sanity check, but we no longer apply the result
            self.WVRResult = atm.checkWVRResult(self._array)
            self.logInfo("Retrieved WVR result: %s" % str(self.WVRResult))
            if self.WVRResult is None and "OSS" not in self._array._arrayName:
                raise Exception("WVR Result is None, aborting execution")
        except BaseException as ex:
            self._failExecution(ex, "AtmCal")
            raise ex

    def doSBRCal(self):
        """Execute a sideband-ratio calibration on the first calibration spec.

        Does nothing when the ``doSBR`` option is off. The subscan duration
        and the integration time are both ``0.1e-9 * ss.getMeanFrequency()``.
        On any failure the execution is closed and the exception re-raised.
        """
        if not self.shouldSBR:
            return
        try:
            # Selected inside the guarded region so that a failure here also
            # closes the execution.
            ss = self._calSpectralSpecs[0]
            subscanDuration = 0.1e-9 * ss.getMeanFrequency()
            sbrCal = Observation.SBRatioCalTarget.SBRatioCalTarget(
                SubscanFieldSource = self._srcPointFocus,
                SpectralSpec = ss,
                DataOrigin = 'CHANNEL_AVERAGE_CROSS',
                SubscanDuration = subscanDuration,
                IntegrationTime = subscanDuration)
            self.logInfo('Executing SBRatioCal on ' + self._srcPointFocus.sourceName + '...')
            sbrCal.execute(self._obsmode)
            self.logInfo('Completed SBRatioCal on ' + self._srcPointFocus.sourceName)
        except BaseException as ex:
            self._failExecution(ex, "SBRatio")
            raise ex

    def checkCorrelator(self):
        """Query the array's correlator type; no further action is taken.

        The TPS-specific handling that used to follow is disabled (ICT-23071).
        """
        # ICT-23071
        # The query is kept (its result is intentionally unused) so that the
        # call made during preparation is unchanged.
        self._array.getCorrelatorType()
        # if corrType == 'TPS':
        #     self.logInfo("TPS does not support FDM, so will use TDM")

    def doDelayCals(self):
        """Execute ``numDelayCals`` rounds of bandpass-type scans, one per
        calibration spectral spec, with delay, amplitude and phase reductions
        enabled.

        On any failure the execution is closed and the exception re-raised.
        """
        src = self._srcPointFocus
        try:
            for _ in range(self.numDelayCals):
                for ss in self._calSpectralSpecs:
                    calTarget = Observation.BandpassCalTarget.BandpassCalTarget(src, ss)
                    calTarget.setOnlineProcessing(True)
                    calTarget.setDelayCalReduction(True)
                    calTarget.setAmpliCalReduction(True)
                    calTarget.setPhaseCalReduction(True)
                    calTarget.setSubscanDuration(self.delaySubscanDuration)
                    calTarget.setIntegrationTime(1.0)
                    self.logInfo('Executing DelayCal on ' + src.sourceName + '...')
                    calTarget.execute(self._obsmode)
                    self.logInfo('Completed DelayCal on ' + src.sourceName)
        except BaseException as ex:
            self._failExecution(ex, "cal survey scans")
            raise ex

    def setEstimatedSourceFluxDensities(self):
        """Ask the calibrator catalog to estimate the source flux for all
        calibration spectral specs and store it as source properties.
        """
        # ICT-8125: Put estimated flux density into Source.xml.
        src = self._srcPointFocus
        cc = self.sourceHelper._calibratorCatalog
        cc.estimateSourceFlux(src, self._calSpectralSpecs,
                              setSourceProperties=True)


# ---------------------------------------------------------------------------
# Script body: create the observation object, prepare the execution, then run
# the calibrations in sequence.
# ---------------------------------------------------------------------------
obs = ObsCalPhaseCheck()
obs.parseOptions()
obs.checkAntennas()

# Everything between start/completePrepareForExecution is guarded: on any
# failure the preparation is still completed and the execution closed with the
# exception before it propagates.
obs.startPrepareForExecution()
try:
    obs.generateTunings()
    obs.checkCorrelator()
    obs.findPointFocusSource()
    obs.setEstimatedSourceFluxDensities()
    obs.checkForOperationalWVR()
    obs.setTelCalParams()
    obs.setTelCalRefAntennaList()
except BaseException as ex:
    obs.logException("Error in methods run during execution/obsmode startup", ex)
    obs.completePrepareForExecution()
    obs.closeExecution(ex)
    raise ex
else:
    obs.completePrepareForExecution()

# Calibration sequence. Only one pointing is executed; a second pointing, a
# focus and a third post-focus pointing are currently disabled.
obs.logInfo("Executing first pointing...")
obs.doPointing()

obs.logInfo("Executing AtmCal...")
obs.doAtmCals()

# This is after AtmCal to avoid it getting applied in the atmosphere reduction
obs.doSBRCal()

obs.logInfo("Executing DelayCals...")
obs.doDelayCals()

obs.closeExecution()
