#! /usr/bin/env python
#*******************************************************************************
# ALMA - Atacama Large Millimeter Array
# (c) Associated Universities Inc., 2009 
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
#
# "@(#) $Id: ObsCalBaselines.py 247921 2017-08-08 15:07:45Z ahirota $"

#
# forcing global imports is due to an OSS problem
#
global CCL
import CCL.Global

global Control
import Control

global ControlExceptionsImpl
import ControlExceptionsImpl

global Observation
import Observation.ObsCalBase
import Observation.AtmCalTarget
import Observation.DelayCalTarget
import Observation.PointingCalTarget
import Observation.FocusCalTarget

global PyDataModelEnumeration
import PyDataModelEnumeration.PyCalibrationSet


class ObsCalBaselines(Observation.ObsCalBase.ObsCalBase):
    """Baseline (antenna-position) calibration observing script.

    Provides the steps run by the module-level driver: pointing and focus
    on a point/focus source, an AtmCal on that source, and then DelayCal
    scans (calibration set ANTENNA_POSITIONS) on a randomised list of
    all-sky sources, once per band in the ``bandList`` option.

    Every step that fails logs an error, closes the execution and
    re-raises the original exception with its traceback intact.
    """

    options = [
        Observation.ObsCalBase.scriptOption("RepeatCount", int, 1),
        Observation.ObsCalBase.scriptOption("PointingSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("AtmSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("SBRSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("FocusSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("dumpDuration", float, 0.192),
        Observation.ObsCalBase.scriptOption("channelAverageDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("integrationDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("tpIntegrationDuration", float, 0.016),
        Observation.ObsCalBase.scriptOption("ElLimit", str, "20 deg"),
        Observation.ObsCalBase.scriptOption("NumTargets", int, 9999999),
        Observation.ObsCalBase.scriptOption("pointFocusBand", int, 3),
        Observation.ObsCalBase.scriptOption("bandList", str, "3"),
        Observation.ObsCalBase.scriptOption("delayModelType", str, ""),
        Observation.ObsCalBase.scriptOption("doPoint", bool, True),
        Observation.ObsCalBase.scriptOption("doFocus", bool, True),
        Observation.ObsCalBase.scriptOption("polarization", str, "4")
    ]

    def __init__(self):
        Observation.ObsCalBase.ObsCalBase.__init__(self)
        # Source used for pointing, focus and AtmCal. Presumably assigned by
        # findPointFocusSource() (not visible here) -- TODO confirm.
        self._srcPointFocus = None
        # Flipped by orderedSpecs() on every call so that successive sources
        # see the spectral specs in alternating order.
        self._reverseSpecs = False
        # WVR result from the AtmCal; kept only as a sanity check.
        self.WVRResult = None

    def parseOptions(self):
        """Copy the parsed script options from ``self.args`` onto the instance.

        ``bandList`` is a comma-separated string of band numbers
        (e.g. "3,6,7"). Each entry must parse as an integer in 1..10.

        Raises:
            Exception: if a band-list entry is not an integer or is out of range.
        """
        self.repeatCount             = self.args.RepeatCount
        self.pointingSubscanDuration = self.args.PointingSubscanDuration
        self.atmSubscanDuration      = self.args.AtmSubscanDuration
        self.sbrSubscanDuration      = self.args.SBRSubscanDuration
        self.focusSubscanDuration    = self.args.FocusSubscanDuration
        self.dumpDuration            = self.args.dumpDuration
        self.channelAverageDuration  = self.args.channelAverageDuration
        self.integrationDuration     = self.args.integrationDuration
        self.tpIntegrationDuration   = self.args.tpIntegrationDuration
        self.elLimit                 = self.args.ElLimit
        self.numTargets              = self.args.NumTargets
        self.pointFocusBand          = self.args.pointFocusBand
        bandStr                      = self.args.bandList
        self.delayModelType          = self.args.delayModelType
        self.shouldPoint             = self.args.doPoint
        self.shouldFocus             = self.args.doFocus
        self.polarization            = self.args.polarization
        self.bandList = []
        for token in bandStr.split(','):
            try:
                n = int(token)
            except ValueError:
                # Report malformed entries (e.g. "3,,6" or "3,x") the same way
                # as out-of-range ones, rather than a bare int() parse error.
                raise Exception("Invalid band number in band list: '%s'" % token)
            if n < 1 or n > 10:
                raise Exception("Invalid band number in band list: '%s'" % token)
            self.bandList.append(n)
        self.logInfo("Band list: %s" % str(self.bandList))

    def setTelCalParams(self):
        """Turn off TelCal's on-the-fly WVR correction for this execution."""
        self.logInfo("Setting TelCal parameters ontheflyWVRcorrection=False")
        tcParameters = self.getTelCalParams()
        tcParameters.setCalibParameter('ontheflyWVRcorrection', False)

    def setDelayModelType(self):
        """Set the DelayServer atmospheric delay model if one was requested.

        If the ``delayModelType`` option is empty or None, the current model
        is left unchanged. In either case the resulting model is logged.
        Assumes ``self._obsmode`` is already set up.
        """
        delayServer = self._obsmode.getDelayServer()
        # setAtmosphericDelayModel() accepts a string and validates it itself.
        if self.delayModelType:
            self.logInfo("setDelayModel(): setting DelayServer modelType to %s" % self.delayModelType)
            delayServer.setAtmosphericDelayModel(modelType=self.delayModelType)
        else:
            self.logInfo("setDelayModel(): leaving DelayServer modelType at previous value")
        currentModel = delayServer.getAtmosphericDelayModel()
        self.logInfo("setDelayModel(): DelayServer now reports delay model: %s" % str(currentModel))

    def generateTunings(self):
        """Build the pointing/focus spectral spec and one delay spec per band."""
        corrType = self._array.getCorrelatorType()
        self._pointFocusSpectralSpec = self._tuningHelper.GenerateSpectralSpec(
                band = self.pointFocusBand,
                intent = "interferometry_continuum",
                corrType = corrType,
                dualMode = True,
                dump = self.dumpDuration,
                channelAverage = self.channelAverageDuration,
                integration = self.integrationDuration,
                tpSampleTime = self.tpIntegrationDuration)
        self._pointFocusSpectralSpec.name = "Band %d pointing/focus" % self.pointFocusBand
        self._calSpectralSpecs = []
        for band in self.bandList:
            ss = self._tuningHelper.GenerateSpectralSpec(
                            band = band,
                            intent = "calsurvey",
                            #intent = "interferometry_continuum",
                            corrType = corrType,
                            dualMode = True,
                            pol = self.polarization,
                            dump = self.dumpDuration,
                            channelAverage = self.channelAverageDuration,
                            integration = self.integrationDuration,
                            tpSampleTime = self.tpIntegrationDuration)
            ss.name = "Band %d delay" % band
            self._calSpectralSpecs.append(ss)

    def populateSourceList(self):
        """Fetch up to ``numTargets`` all-sky sources in randomised order."""
        self._calSources = self.sourceHelper.getAllSkySources(randomise=True, maxNumSources=self.numTargets)

    def orderedSpecs(self):
        """Return the delay spectral specs, reversing the order on alternate calls."""
        ret = self._calSpectralSpecs
        if self._reverseSpecs:
            ret = list(reversed(self._calSpectralSpecs))
        self._reverseSpecs = not self._reverseSpecs
        return ret

    def _pointFocusSourceName(self):
        """Name of the point/focus source, safe to call when none is set.

        Used in error handlers, so that a missing source does not raise a
        second AttributeError that hides the original failure.
        """
        return getattr(self._srcPointFocus, 'sourceName', str(self._srcPointFocus))

    def _abortExecution(self, msg, ex):
        """Print ``ex``, log ``msg`` and close the execution.

        The caller must re-raise afterwards with a bare ``raise`` so that the
        original traceback is kept.
        """
        print(ex)
        self.logError(msg)
        self.closeExecution(ex)

    def doPointing(self):
        """Run a PointingCal on the point/focus source and apply the result.

        Skipped when the doPoint option is false. Having no pointing result
        is fatal unless the array name contains "OSS".
        """
        if not self.shouldPoint:
            return
        try:
            pointingCal = Observation.PointingCalTarget.PointingCalTarget(self._srcPointFocus, self._pointFocusSpectralSpec)
            pointingCal.setSubscanDuration(self.pointingSubscanDuration)
            pointingCal.setDataOrigin('CHANNEL_AVERAGE_CROSS')
            # For Baseline run avoid extra delay results
            pointingCal.setDelayCalReduction(False)
            self.logInfo('Executing PointingCal on ' + self._srcPointFocus.sourceName + '...')
            pointingCal.execute(self._obsmode)
            self.logInfo('Completed PointingCal on ' + self._srcPointFocus.sourceName)
            result = pointingCal.checkResult(self._array)
            self.logInfo("Result is: %s" % str(result))
            if len(result) > 0:
                for key in result.keys():
                    self.logInfo("Found solution for %s using polarization(s) %s" %
                            (key, result[key]))
                pointingCal.applyResult(self._obsmode, result)
            elif "OSS" not in self._array._arrayName:
                raise Exception("No pointing results!")
        except BaseException as ex:
            self._abortExecution(
                "Error executing pointing on source %s" % self._pointFocusSourceName(), ex)
            # Bare raise keeps the original traceback (``raise ex`` drops it on Python 2).
            raise

    def doFocus(self):
        """Run a Z-axis FocusCal on the point/focus source and apply the result.

        Skipped when the doFocus option is false. Having no focus result is
        fatal unless the array name contains "OSS".
        """
        if not self.shouldFocus:
            return
        try:
            focusCal = Observation.FocusCalTarget.FocusCalTarget(
                    SubscanFieldSource = self._srcPointFocus,
                    Axis = 'Z',
                    SpectralSpec = self._pointFocusSpectralSpec,
                    DataOrigin = 'CHANNEL_AVERAGE_CROSS',
                    SubscanDuration = self.focusSubscanDuration,
                    OneWay = False,
                    NumPositions = 7)
            self.logInfo('Executing FocusCal on ' + self._srcPointFocus.sourceName + '...')
            focusCal.execute(self._obsmode)
            self.logInfo('Completed FocusCal on ' + self._srcPointFocus.sourceName)
            result = focusCal.checkResult(self._array)
            self.logInfo("Result is: %s" % str(result))
            if len(result) > 0:
                for key in result.keys():
                    self.logInfo("Found solution for %s using polarization(s) %s" %
                            (key, result[key]))
                focusCal.applyResult(self._obsmode, result)
            elif "OSS" not in self._array._arrayName:
                raise Exception("No focus results!")
        except BaseException as ex:
            self._abortExecution(
                "Error executing focus on source %s" % self._pointFocusSourceName(), ex)
            raise

    def doAtmCals(self):
        """Run an AtmCal (with hot load) on the point/focus source.

        Uses the first delay spectral spec. When ``haveOperationalWVR`` is
        set (presumably by checkForOperationalWVR(), not visible here), the
        WVR result is fetched as a sanity check but not applied. A missing
        WVR result is fatal unless the array name contains "OSS".
        """
        ss = self._calSpectralSpecs[0]
        try:
            atm = Observation.AtmCalTarget.AtmCalTarget(self._srcPointFocus, ss, doHotLoad=True)
            atm.setOnlineProcessing(True)
            atm.setDataOrigin('FULL_RESOLUTION_AUTO')
            atm.setDoZero(False)
            atm.setSubscanDuration(self.atmSubscanDuration)
            atm.setIntegrationTime(1.5)
            atm.setWVRCalReduction(self.haveOperationalWVR)
            atm.setApplyWVR(False)
            self.logInfo('Executing AtmCal on ' + self._srcPointFocus.sourceName + '...')
            atm.execute(self._obsmode)
            self.logInfo('Completed AtmCal on ' + self._srcPointFocus.sourceName)
            if not self.haveOperationalWVR:
                return
            # For now keep this as a sanity check, but we no longer apply the result
            self.WVRResult = atm.checkWVRResult(self._array)
            self.logInfo("Retrieved WVR result: %s" % str(self.WVRResult))
            if "OSS" not in self._array._arrayName and self.WVRResult is None:
                raise Exception("WVR Result is None, aborting execution")
        except BaseException as ex:
            self._abortExecution(
                "Error executing AtmCal on source %s" % self._pointFocusSourceName(), ex)
            raise

    def doCalSource(self, src):
        """Run one DelayCal (ANTENNA_POSITIONS) on ``src`` for each delay spec.

        The base subscan duration is 45 s. For sources whose names look like
        "J<digit>..." or "3C...", it is scaled by (f / 300 GHz)**2, where f
        is the spec's mean frequency.
        """
        for ss in self.orderedSpecs():
            try:
                name = src.sourceName
                # TODO: this needs to be configurable somehow.
                # Thermal vs. synchrotron -- what a hack :)
                # Slicing (not indexing) so names shorter than two characters
                # just keep the base duration instead of raising IndexError.
                subscanDuration = 45.0
                #subscanDuration = 90.0
                isJName = name[:1].lower() == 'j' and name[1:2].isdigit()
                is3C = name[:1] == '3' and name[1:2].lower() == 'c'
                if isJName or is3C:
                    # Assumes getMeanFrequency() returns Hz (1e-9 -> GHz) -- TODO confirm.
                    f = 1.0e-9 * ss.getMeanFrequency()
                    fs = f / 300.0
                    subscanDuration *= fs * fs
                delayCal = Observation.DelayCalTarget.DelayCalTarget(src, ss)
                delayCal.setCalibrationSet(PyDataModelEnumeration.PyCalibrationSet.ANTENNA_POSITIONS)
                delayCal.setPhaseCalReduction(True)
                delayCal.setSubscanDuration(subscanDuration)
                delayCal.setIntegrationTime(1.0)
                self.logInfo('Executing DelayCal on ' + src.sourceName + '...')
                delayCal.execute(self._obsmode)
                self.logInfo('Completed DelayCal on ' + src.sourceName)
            except BaseException as ex:
                self._abortExecution(
                    "Error executing cal survey scans on source %s" % src.sourceName, ex)
                raise

    def doCalObservations(self):
        """Observe every observable source, ``repeatCount`` times over.

        Sources for which isObservable(src, 600) is false are skipped. Any
        exception from isObservable() is treated as fatal.
        """
        for _ in range(self.repeatCount):
            for src in self._calSources:
                try:
                    isObs = self.isObservable(src, 600)
                except BaseException as ex:
                    self.logException('Exception thrown by isObservable() when checking source %s, considering this fatal!' % src.sourceName, ex)
                    self.closeExecution(ex)
                    raise
                if not isObs:
                    self.logInfo("Skipping source '%s' as not observable" % src.sourceName)
                    continue
                self.doCalSource(src)


# Script driver: runs the full baseline-calibration sequence at module level.
obs = ObsCalBaselines()
obs.parseOptions()
obs.checkAntennas()
obs.startPrepareForExecution()
try:
    obs.generateTunings()
    obs.populateSourceList()
    obs.findPointFocusSource()
    obs.checkForOperationalWVR()
    obs.setTelCalParams()
except BaseException as ex:
    obs.logException("Error in methods run during execution/obsmode startup", ex)
    obs.completePrepareForExecution()
    obs.closeExecution(ex)
    # Bare raise keeps the original traceback (``raise ex`` drops it on Python 2).
    raise
obs.completePrepareForExecution()
obs.logInfo("Executing first pointing...")
obs.doPointing()
obs.logInfo("Executing second pointing -- make sure results are good!...")
obs.doPointing()
obs.logInfo("Executing focus...")
obs.doFocus()
obs.logInfo("Executing third pointing after focus -- make sure results are good!...")
obs.doPointing()
obs.logInfo("Executing AtmCal...")
obs.doAtmCals()
# Delay-model selection is currently disabled; uncomment to re-enable.
#obs.logInfo("Setting/checking DelayServer model type...")
#obs.setDelayModelType()
obs.logInfo("Executing Calibration observations...")
obs.doCalObservations()
obs.closeExecution()

