#! /usr/bin/env python
#*******************************************************************************
# ALMA - Atacama Large Millimeter Array
# (c) Associated Universities Inc., 2009 
# 
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
# 
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA
#
# "@(#) $Id: AmpCalSurvey24h.py 241594 2017-02-27 21:53:18Z javarias $"

#
# forcing global imports is due to an OSS problem
#
global os
import os
global sys
import sys
global math
import math
global datetime
import datetime

global CCL
import CCL.Global

global Control
import Control

global ControlExceptionsImpl
import ControlExceptionsImpl

global AcsutilPy
import AcsutilPy.FindFile

global Observation
import Observation.AtmCalTarget
import Observation.AmplitudeCalTarget
import Observation.PointingCalTarget
import Observation.FocusCalTarget
import Observation.ObsCalBase
import Observation.Global


class AmpCalSurvey24h(Observation.ObsCalBase.ObsCalBase):
    """Amplitude calibrator survey driven by an LST-ordered schedule file.

    Sources are taken from a schedule file of ``HH:MM name`` lines, starting at
    the current LST. Sources that are not planets or grid sources are skipped
    if they are already recorded in the observing log (``logFile``). The
    script points and focuses on ``_srcPointFocus`` in ``pointFocusBand``.
    It then runs an AtmCal followed by an AmplitudeCal on each selected
    source, for every band in ``bandList``, and appends one line per
    completed source to the observing log.
    """

    # AtmCal reference-position offset: 150 arcsec expressed in radians, used
    # as the first component of a Control.HORIZON SourceOffset.
    azoffset = math.radians(150.0/3600.0)

    options = [
        Observation.ObsCalBase.scriptOption("RepeatCount", int, 1),
        Observation.ObsCalBase.scriptOption("PointingSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("AtmSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("SBRSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("FocusSubscanDuration", float, 5.76),
        Observation.ObsCalBase.scriptOption("dumpDuration", float, 0.192),
        Observation.ObsCalBase.scriptOption("channelAverageDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("integrationDuration", float, 2.88),
        #Observation.ObsCalBase.scriptOption("AtmIntegrationDuration", float, 0.576),
        Observation.ObsCalBase.scriptOption("tpIntegrationDuration", float, 0.016),
        Observation.ObsCalBase.scriptOption("ElLimit", str, "20 deg"),
        # Negative value means to use a time-of-day dependent default value
        Observation.ObsCalBase.scriptOption("NumTargets", int, -1),
        Observation.ObsCalBase.scriptOption("pointFocusBand", int, 7),
        Observation.ObsCalBase.scriptOption("bandList", str, "7,9"),
        Observation.ObsCalBase.scriptOption("scheduleFilenameSuffix", str, ""),
        Observation.ObsCalBase.scriptOption("scheduleFilename", str, ""),
        Observation.ObsCalBase.scriptOption("doTestSourceResolution", bool, False),
        Observation.ObsCalBase.scriptOption("logFile", str, "NoLog"),
        Observation.ObsCalBase.scriptOption("extraSourceListFile", str, "")
    ]

    def __init__(self):
        Observation.ObsCalBase.ObsCalBase.__init__(self)
        # Set elsewhere (findPointFocusSource is not defined in this class).
        self._srcPointFocus = None
        self._doSBRatio = True
        # Toggled by orderedSpecs() so consecutive calls alternate band order.
        self._reverseSpecs = False
        # Source names read from the observing log.
        self.observedSources = []
        # Grid sources are always eligible even if already observed.
        self.gridNames = [src.sourceName for src in self.sourceHelper.getAllGridSources()]

    def getDefaultNumTargets(self):
        """Return the default number of targets for the current UTC hour.

        Between 11:00 and 01:59 UTC (sunrise, daytime and sunset, where the
        run should finish sooner) 6 targets are used; otherwise 14.
        """
        hour = datetime.datetime.utcnow().hour
        if hour >= 11 or hour < 2:
            ret = 6
            self.logInfo("Hour of the day is %d, so using reduced default number of sources: %d" % (hour, ret))
        else:
            ret = 14
            self.logInfo("Hour of the day is %d, so using full default number of sources: %d" % (hour, ret))
        return ret

    def parseOptions(self):
        """Copy script options from ``self.args`` into attributes and validate them.

        Raises:
            Exception: if ``bandList`` contains anything other than integers
                in the range 1..10.
        """
        self.repeatCount             = self.args.RepeatCount
        self.pointingSubscanDuration = self.args.PointingSubscanDuration
        self.atmSubscanDuration      = self.args.AtmSubscanDuration
        self.sbrSubscanDuration      = self.args.SBRSubscanDuration
        self.focusSubscanDuration    = self.args.FocusSubscanDuration
        self.dumpDuration            = self.args.dumpDuration
        self.channelAverageDuration  = self.args.channelAverageDuration
        self.integrationDuration     = self.args.integrationDuration
        #self.atmIntegrationDuration  = self.args.AtmIntegrationDuration
        self.tpIntegrationDuration   = self.args.tpIntegrationDuration
        self.elLimit                 = self.args.ElLimit
        # expert NumTargets parameter or command-line option takes precedence if used
        self.numTargets              = self.getDefaultNumTargets()
        if self.args.NumTargets > 0:
            self.logInfo("User-specified NumTargets: %d" % self.args.NumTargets)
            self.numTargets = self.args.NumTargets
        self.pointFocusBand          = self.args.pointFocusBand
        bandStr                      = self.args.bandList
        self.scheduleFilenameSuffix  = self.args.scheduleFilenameSuffix
        self.scheduleFilename        = self.args.scheduleFilename
        self.doTestSourceResolution  = self.args.doTestSourceResolution
        self.logFile                 = self.args.logFile
        self.extraSourceListFile     = self.args.extraSourceListFile

        self.bandList = []
        for s in bandStr.split(','):
            # A non-numeric entry is reported with the same message as an
            # out-of-range one instead of a bare ValueError from int().
            try:
                n = int(s)
            except ValueError:
                n = None
            if n is None or n < 1 or n > 10:
                raise Exception("Invalid band number in band list: '%s'" % s)
            self.bandList.append(n)
        self.logInfo("Band list: %s" % str(self.bandList))

        # Configure subscanDuration for each band from the optional
        # "band<N>ScanDuration" expert parameters. The first occurrence of a
        # keyword wins.
        # FIXME: this should also support command line options somehow
        expertValues = {}
        for param in self.expertParameters:
            expertValues.setdefault(str(param.Keyword), param.Value)
        self.subscanDurationDict     = dict()
        for band in range(1, 11):
            key = "band%dScanDuration" % band
            if key not in expertValues:
                continue
            self.subscanDurationDict[band] = float(expertValues[key])
            self.logInfo("Band%d : subscanDuration = %f [sec]" % (band, self.subscanDurationDict[band]))

    def generateTunings(self):
        """Build the pointing/focus spectral spec and one calsurvey spec per band."""
        corrType = self._array.getCorrelatorType()
        # Parameters shared by every spectral spec generated here.
        common = dict(corrType = corrType,
                      dualMode = True,
                      dump = self.dumpDuration,
                      channelAverage = self.channelAverageDuration,
                      integration = self.integrationDuration,
                      tpSampleTime = self.tpIntegrationDuration)
        self._pointFocusSpectralSpec = self._tuningHelper.GenerateSpectralSpec(
                band = self.pointFocusBand,
                intent = "interferometry_continuum",
                **common)
        self._pointFocusSpectralSpec.name = "Band %d pointing/focus" % self.pointFocusBand
        self._calSpectralSpecs = []
        for band in self.bandList:
            ss = self._tuningHelper.GenerateSpectralSpec(
                    band = band,
                    intent = "calsurvey",
                    **common)
            ss.name = "Band %d calsurvey" % band
            self._calSpectralSpecs.append(ss)

    def readObservingLog(self):
        """Load already-observed source names from ``logFile`` into ``observedSources``.

        Each non-empty, non-comment line contributes the text before its
        first comma, with surrounding whitespace stripped. A missing or
        unreadable log is not an error.
        """
        if self.logFile.lower() == 'NoLog'.lower():
            self.logInfo("readObservingLog(): log file is '%s', assuming no sources have ever observed." % self.logFile)
            return
        try:
            fp = open(os.path.expanduser(self.logFile), 'r')
        except (IOError, OSError):
            self.logInfo("readObservingLog(): log file '%s' not readable, assuming no sources have ever observed." % self.logFile)
            return
        numLines = 0
        with fp:
            for line in fp:
                if not line or line.startswith(('\n', '#')):
                    continue
                numLines += 1
                # strip() so hand-edited lines without a comma still match.
                name = line.split(',', 2)[0].strip()
                if name not in self.observedSources:
                    self.observedSources.append(name)
        self.logInfo("readObservingLog(): read %d source lines, now have %d sources observed already" % (numLines, len(self.observedSources)))

    def _illegalParameterError(self, message):
        """Return an IllegalParameterErrorExImpl carrying *message* as the user error text."""
        ex = ControlExceptionsImpl.IllegalParameterErrorExImpl()
        ex.setData(Control.EX_USER_ERROR_MSG, message)
        return ex

    def readExtraSourceList(self, fname=None):
        """Add equatorial sources from a whitespace-separated "name ra dec" file.

        RA and Dec are interpreted in degrees. Comment ('#') and blank lines
        are skipped. Malformed lines are logged and ignored. Each parsed
        source is appended to ``sourceHelper.specialEquatorialSources``.

        Args:
            fname: absolute path of the list; defaults to ``extraSourceListFile``.
                An empty name means there is nothing to read.

        Raises:
            IllegalParameterErrorExImpl: if *fname* is not an absolute path.
        """
        import re
        from CCL.FieldSource import EquatorialSource
        if fname is None:
            fname = self.extraSourceListFile
        if fname == "":
            return

        self.logInfo("Reading (extra) source list from '%s'..." % (fname))
        if not os.path.isabs(fname):
            raise self._illegalParameterError(
                "Source list file name should be specified as an absolute path '%s'" % fname)

        self.logInfo("Will read extra sources from '%s'" % (fname))
        pat_sep = re.compile("[\t ]+")
        with open(fname) as f:
            lines = [line.strip() for line in f]
        for line in lines:
            if not line or line.startswith("#"):
                continue
            tokens = pat_sep.split(line)
            try:
                srcName, srcRa, srcDec = tokens
                # Only validating the numeric format here.
                float(srcRa)
                float(srcDec)
            except ValueError:
                self.logWarning("Invalid line in the source list : will ignore this [%s]" % (line))
                continue
            src = EquatorialSource("%s deg" % srcRa,
                                   "%s deg" % srcDec,
                                   sourceName=srcName)
            self.logInfo("Adding %s from the external source list [Ra=%s Dec=%s]" % (srcName, srcRa, srcDec))
            self.sourceHelper.specialEquatorialSources.append(src)

    def readSchedule(self, fname=None):
        """Parse the LST schedule into ``self._schedule`` as ``[lstSeconds, name]`` pairs.

        Lines are ``HH:MM name ...``. Lines with name ``---`` are placeholders
        and are skipped. The first ``numTargets + 1`` entries are then appended
        again, shifted by 24 h, so a start late in the LST day wraps into the
        next day.

        Args:
            fname: explicit schedule path. When omitted, ``scheduleFilename``
                is used if it names an existing file. Otherwise
                ``config/AmpCalSurvey24h_schedule<suffix>.txt`` is located
                with ``AcsutilPy.FindFile.findFile``.

        Raises:
            IllegalParameterErrorExImpl: if the file cannot be found or has
                no usable entries.
        """
        if fname is None:
            if self.scheduleFilename != "" and os.path.isfile(self.scheduleFilename):
                fname = self.scheduleFilename
                filePath = fname
            else:
                fname = "config/AmpCalSurvey24h_schedule" + self.scheduleFilenameSuffix + ".txt"
                filePath = AcsutilPy.FindFile.findFile(fname)[0]
        else:
            # An explicitly supplied name is used as given; previously
            # filePath was left unbound on this path.
            filePath = fname
        self.logInfo("Reading schedule from file '%s'..." % filePath)
        if filePath == '':
            raise self._illegalParameterError("Unable to find schedule file '%s'" % fname)
        self._schedule = []
        with open(filePath, 'r') as fp:
            for line in fp:
                # lstrip() also removes the newline, so blank lines become "".
                line = line.lstrip()
                if not line or line.startswith('#'):
                    continue
                subs = line.split(None, 2)
                if len(subs) < 2:
                    self.logWarning("Skipping invalid line: '%s'" % line)
                    continue
                if subs[1] == '---':
                    continue
                tsSubs = subs[0].split(':')
                ts = 3600*int(tsSubs[0]) + 60*int(tsSubs[1])
                self._schedule.append([ts, subs[1]])
        if not self._schedule:
            raise self._illegalParameterError("Schedule file '%s' contains no source entries" % filePath)
        # Indexing the growing list means a schedule with fewer than
        # numTargets+1 entries is repeated periodically (day after day).
        for i in range(self.numTargets + 1):
            ts, name = self._schedule[i]
            self._schedule.append([ts + (24*60*60), name])
        self.logInfo("Full schedule list (LST secs, name): %s" % str(self._schedule))

    def testScheduleSourceResolution(self):
        """Debug aid: log an error for each scheduled name sourceHelper cannot resolve."""
        for ent in self._schedule:
            name = ent[1]
            self.logInfo("Checking we can get source '%s'..." % name)
            try:
                self.sourceHelper.getSource(name)
            except Exception:
                self.logError("sourceHelper.getSource('%s') threw an exceptions -- SOMETHING IS WRONG WITH THIS SOURCE" % name)

    def populateSourceList(self, nameList=None):
        """Resolve the target sources into ``self._calSources``.

        Args:
            nameList: explicit source names. When omitted, up to ``numTargets``
                names are taken from the schedule, starting at the current LST
                second of day. Previously observed sources are skipped unless
                they are planets or grid sources.
        """
        self._calSources = []
        if nameList is None:
            nameList = []
            lstSecs = self.sourceHelper.getLSTsecondOfDay()
            self.logInfo("Entering schedule at LST second: %d (%02d:%02d:%02d)" % (lstSecs, lstSecs//3600, (lstSecs//60)%60, lstSecs%60))
            # Planets and grid sources are always re-observed; compute once.
            ampCalNames = self.sourceHelper.planetSources + self.gridNames
            for ts, srcName in self._schedule:
                if ts < lstSecs:
                    continue
                if srcName not in ampCalNames and srcName in self.observedSources:
                    self.logInfo("Source '%s' has already been observed: skip it" % srcName)
                    continue
                self.logInfo("Adding source name '%s' to target list" % srcName)
                nameList.append(srcName)
                if len(nameList) >= self.numTargets:
                    break
        self.logInfo("Target name list: %s" % str(nameList))
        self._calSources = self.sourceHelper.getSources(nameList, onlyALMA=False)
        self.logInfo("Source list: %s" % str(self._calSources))

    def orderedSpecs(self):
        """Return the cal spectral specs, alternating forward/reversed order on each call."""
        ret = self._calSpectralSpecs
        if self._reverseSpecs:
            ret = reversed(self._calSpectralSpecs)
        self._reverseSpecs = not self._reverseSpecs
        return ret

    def _isOSSArray(self):
        """Return True when the array name contains "OSS"; missing cal results are then tolerated."""
        return "OSS" in self._array._arrayName

    def _setAtmReferenceOffset(self, atm):
        """Use *atm*'s own source, offset by ``azoffset``, as its reference position.

        Tries ``CCL.SourceOffset.stroke`` first. On any failure (for example
        when that module/function is unavailable) it falls back to
        ``Control.SourceOffset``.
        """
        # Currently we need to set a reference source to use an offset.
        atm._referenceSource = atm._source
        try:
            atm._referenceOffset = CCL.SourceOffset.stroke(self.azoffset, 0, 0, 0, Control.HORIZON)
            self.logInfo("Using 10.6 and later SourceOffset.stroke for AtmCal reference position")
        except Exception:
            self.logInfo("Using 10.4 and earlier SourceOffset for AtmCal reference position")
            atm._referenceOffset = Control.SourceOffset(self.azoffset, 0, 0, 0, Control.HORIZON)

    def doPointing(self):
        """Run a pointing calibration on ``_srcPointFocus`` and apply the result.

        An empty result raises, except on OSS arrays. Any failure is logged,
        the execution is closed, and the exception is re-raised.
        """
        try:
            pointingCal = Observation.PointingCalTarget.PointingCalTarget(self._srcPointFocus, self._pointFocusSpectralSpec)
            pointingCal.setSubscanDuration(self.pointingSubscanDuration)
            pointingCal.setDataOrigin('CHANNEL_AVERAGE_CROSS')
            pointingCal.setDelayCalReduction(True)
            self.logInfo('Executing PointingCal on ' + self._srcPointFocus.sourceName + '...')
            pointingCal.execute(self._obsmode)
            self.logInfo('Completed PointingCal on ' + self._srcPointFocus.sourceName)
            result = pointingCal.checkResult(self._array)
            self.logInfo("Result is: %s" % str(result))
            if len(result) > 0:
                for key in result.keys():
                    self.logInfo("Found solution for %s using polarization(s) %s" %
                            (key, result[key]))
                pointingCal.applyResult(self._obsmode, result)
            elif not self._isOSSArray():
                raise Exception("No pointing results!")
        except BaseException as ex:
            print(ex)
            msg = "Error executing pointing on source %s" % self._srcPointFocus.sourceName
            self.logError(msg)
            self.closeExecution(ex)
            raise

    def doFocus(self):
        """Run a 7-position Z focus calibration on ``_srcPointFocus`` and apply the result.

        An empty result raises, except on OSS arrays. Any failure is logged,
        the execution is closed, and the exception is re-raised.
        """
        try:
            focusCal = Observation.FocusCalTarget.FocusCalTarget(
                    SubscanFieldSource = self._srcPointFocus,
                    Axis = 'Z',
                    SpectralSpec = self._pointFocusSpectralSpec,
                    DataOrigin = 'CHANNEL_AVERAGE_CROSS',
                    SubscanDuration = self.focusSubscanDuration,
                    OneWay = False,
                    NumPositions = 7)
            self.logInfo('Executing FocusCal on ' + self._srcPointFocus.sourceName + '...')
            focusCal.execute(self._obsmode)
            self.logInfo('Completed FocusCal on ' + self._srcPointFocus.sourceName)
            result = focusCal.checkResult(self._array)
            self.logInfo("Result is: %s" % str(result))
            if len(result) > 0:
                for key in result.keys():
                    self.logInfo("Found solution for %s using polarization(s) %s" %
                            (key, result[key]))
                focusCal.applyResult(self._obsmode, result)
            elif not self._isOSSArray():
                raise Exception("No focus results!")
        except BaseException as ex:
            print(ex)
            msg = "Error executing focus on source %s" % self._srcPointFocus.sourceName
            self.logError(msg)
            self.closeExecution(ex)
            raise

    # Hack. This should be elsewhere.
    def setMaximumAttenuatorSetting(self):
        """Store a "Maximum Attenuator Setting" entry for every antenna in the cal results.

        Every IF processor attenuator is set to 31.5 and every IF switch
        attenuator to 15. On OSS arrays the settings are only logged.
        """
        isOSS = self._isOSSArray()
        if not isOSS:
            calResults = self._array.getCalResults()
        antennas = self._array.antennas()
        maxAtten = Control.AttenuatorSetting(ifProcPol0=[31.5,31.5,31.5,31.5],
                ifProcPol1=[31.5,31.5,31.5,31.5],
                ifswPol0USB=15,
                ifswPol1USB=15,
                ifswPol0LSB=15,
                ifswPol1LSB=15)
        antAtten = [Control.NamedAttenuatorSetting(ant, maxAtten) for ant in antennas]
        self.logInfo("Setting attenuator setting 'Maximum Attenuator Setting': %s" % str(antAtten))
        if not isOSS:
            calResults.setAttenuatorSettings("Maximum Attenuator Setting",
                antAtten)

    def doSBRatios(self):
        """Measure sideband ratios (plus an AtmCal) on ``_srcPointFocus`` for specs at or above 275 GHz.

        Does nothing when ``_doSBRatio`` is False. Any failure is logged, the
        execution is closed, and the exception is re-raised.
        """
        if not self._doSBRatio:
            return
        self.setMaximumAttenuatorSetting()
        for ss in self.orderedSpecs():
            # We're moving over to using TelCal default SBRs, but in CYCLE2 the
            # default for bands 7 and above is bad, so we need to do one for
            # that. TelCal can only remember one SBR result anyway, so there's
            # no point at all in measuring the other bands.
            if ss.getMeanFrequency() < 275.0e9:
                continue
            try:
                # Observation.SBRatioCalTarget is not imported at module level,
                # so import it explicitly rather than relying on another module
                # having loaded it. Done inside the try so a failure goes
                # through the same error path.
                from Observation.SBRatioCalTarget import SBRatioCalTarget
                # Duration grows linearly with frequency (0.1 s per GHz).
                subscanDuration = 0.1e-9 * ss.getMeanFrequency()
                sbrCal = SBRatioCalTarget(
                    SubscanFieldSource = self._srcPointFocus,
                    SpectralSpec = ss,
                    DataOrigin = 'CHANNEL_AVERAGE_CROSS',
                    SubscanDuration = subscanDuration,
                    IntegrationTime = subscanDuration)
                self.logInfo('Executing SBRatioCal on ' + self._srcPointFocus.sourceName + '...')
                sbrCal.execute(self._obsmode)
                self.logInfo('Completed SBRatioCal on ' + self._srcPointFocus.sourceName)
            except BaseException as ex:
                print(ex)
                msg = "Error executing SBRatio on source %s" % self._srcPointFocus.sourceName
                self.logError(msg)
                self.closeExecution(ex)
                raise
            # Doing this in here is a bit ugly, but convenient.
            # Idea is to get a real-time display of SBRs from TelCalSpy,
            # plus to measure BB detector zero points
            try:
                atm = Observation.AtmCalTarget.AtmCalTarget(
                    SubscanFieldSource = self._srcPointFocus,
                    SpectralSpec = ss,
                    DataOrigin = 'FULL_RESOLUTION_AUTO',
                    doZero = False,
                    SubscanDuration = self.atmSubscanDuration,
                    IntegrationTime = 1.5,
                    doHotLoad = True)
                atm.setOnlineProcessing(True)
                try:
                    # Probe only: specs with a square-law setup use total power
                    # data with zero-level measurement.
                    ss.SquareLawSetup.integrationDuration
                    atm.setDataOrigin('TOTAL_POWER')
                    atm.setDoZero(True)
                except Exception:
                    pass
                atm.setWVRCalReduction(True)
                # setting this to True wastes time
                atm.setApplyWVR(False)
                self._setAtmReferenceOffset(atm)
                self.logInfo('Executing AtmCal on ' + self._srcPointFocus.sourceName + '...')
                atm.execute(self._obsmode)
                self.logInfo('Completed AtmCal on ' + self._srcPointFocus.sourceName)
            except BaseException as ex:
                print(ex)
                msg = "Error executing AtmCal on source %s" % self._srcPointFocus.sourceName
                self.logError(msg)
                self.closeExecution(ex)
                raise

    def _ampCalSubscanDuration(self, src, ss):
        """Return the AmplitudeCal subscan duration [s] for *src* observed with spec *ss*.

        A ``band<N>ScanDuration`` expert parameter for the spec's band wins.
        Otherwise the duration is 120 s. For names starting with 'J' followed
        by a digit, it is scaled by (f / 300 GHz)^2, where f is the spec's mean
        frequency. The result is capped at 180 s.
        """
        band = int(ss.FrequencySetup.receiverBand.replace("ALMA_RB_", ""))
        if band in self.subscanDurationDict:
            return self.subscanDurationDict[band]
        # TODO: this needs to be configurable somehow.
        # Thermal vs. syncrotron -- what a hack :)
        subscanDuration = 120.0
        name = src.sourceName
        if len(name) > 1 and name[0].lower() == 'j' and name[1].isdigit():
            fs = 1.0e-9 * ss.getMeanFrequency() / 300.0
            subscanDuration *= fs*fs
        # Cap the time, so we don't mess up the schedule too much, and so we don't cause CORBA timeouts
        return min(subscanDuration, 180.0)

    def doCalSource(self, src):
        """Run AtmCal then AmplitudeCal on *src* for each cal spec, then log it as observed.

        Any failure is logged, the execution is closed, and the exception is
        re-raised.
        """
        for ss in self.orderedSpecs():
            try:
                atm = Observation.AtmCalTarget.AtmCalTarget(src, ss, doHotLoad=True)
                atm.setOnlineProcessing(True)
                atm.setDataOrigin('FULL_RESOLUTION_AUTO')
                atm.setDoZero(False)
                atm.setSubscanDuration(self.atmSubscanDuration)
                atm.setIntegrationTime(1.5)
                atm.setWVRCalReduction(True)
                # setting this to True wastes time
                atm.setApplyWVR(False)
                self._setAtmReferenceOffset(atm)
                self.logInfo('Executing AtmCal on ' + src.sourceName + '...')
                atm.execute(self._obsmode)
                self.logInfo('Completed AtmCal on ' + src.sourceName)

                subscanDuration = self._ampCalSubscanDuration(src, ss)
                ampliCal = Observation.AmplitudeCalTarget.AmplitudeCalTarget(src, ss)
                ampliCal.setSubscanDuration(subscanDuration)
                ampliCal.setIntegrationTime(1.0)
                self.logInfo('Executing AmplitudeCal on ' + src.sourceName + '...')
                ampliCal.execute(self._obsmode)
                self.logInfo('Completed AmplitudeCal on ' + src.sourceName)
            except BaseException as ex:
                print(ex)
                msg = "Error executing cal survey scans on source %s" % src.sourceName
                self.logError(msg)
                self.closeExecution(ex)
                raise
        self.writeLogEntry(src.sourceName)

    def writeLogEntry(self, sourceName):
        """Append ``"<sourceName>, <uid>"`` to ``logFile``.

        Nothing is written when logging is disabled ("NoLog") or the array is
        simulated.
        """
        if self.logFile.lower() == 'NoLog'.lower():
            self.logInfo("writeLogEntry(): log file is '%s', won't write log entry" % self.logFile)
            return
        self.logInfo("writeLogEntry(): adding record for '%s'" % sourceName)
        if Observation.Global.simulatedArray():
            return
        # Clear the umask so a newly created log gets the full 0666 mode
        # (presumably so other accounts can append too). Restore it right away,
        # even if open() fails.
        oldUmask = os.umask(0)
        try:
            fd = os.open(os.path.expanduser(self.logFile), os.O_WRONLY | os.O_CREAT | os.O_APPEND, 0o666)
        finally:
            os.umask(oldUmask)
        with os.fdopen(fd, "a") as fp:
            fp.write("%s, %s\n" % (str(sourceName), str(self.uid)))

    def doCalObservations(self):
        """Observe every selected cal source that ``isObservable(src, 600)`` accepts."""
        for src in self._calSources:
            if not self.isObservable(src, 600):
                self.logInfo("Skipping source '%s' as not observable" % src.sourceName)
                continue
            self.doCalSource(src)


obs = AmpCalSurvey24h()
obs.parseOptions()
obs.checkAntennas()
obs.startPrepareForExecution()

# Preparation phase. Any failure here still completes the preparation, then
# closes the execution with the error before propagating it.
try:
    for prepStep in (obs.generateTunings,
                     obs.readExtraSourceList,
                     obs.readObservingLog,
                     obs.readSchedule):
        prepStep()
    # Debugging aid only: verify every scheduled name can be resolved.
    if obs.doTestSourceResolution:
        obs.testScheduleSourceResolution()
    obs.populateSourceList()
    obs.findPointFocusSource()
except BaseException as ex:
    obs.logException("Error in methods run during execution/obsmode startup", ex)
    obs.completePrepareForExecution()
    obs.closeExecution(ex)
    raise ex
obs.completePrepareForExecution()

# Observing phase: point twice, focus, re-point, then the calibration scans.
# Each step handles its own failure (closeExecution + re-raise).
observingSteps = (
    ("Executing first pointing...", obs.doPointing),
    ("Executing second pointing -- make sure results are good!...", obs.doPointing),
    ("Executing focus...", obs.doFocus),
    ("Executing third pointing after focus -- make sure results are good!...", obs.doPointing),
    ("Executing Calibration observations...", obs.doCalObservations),
)
for stepMessage, stepMethod in observingSteps:
    obs.logInfo(stepMessage)
    stepMethod()
obs.closeExecution()
