#!/usr/bin/python
#
# Copyright (C) 2011 Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file
# An HTCondor DAG generator to recolor frame data

"""
This program makes a dag to recolor frames
"""

__author__ = 'Chad Hanna <chad.hanna@ligo.org>'


##############################################################################
# import standard modules
import sys, os, copy, math
from optparse import OptionParser
import subprocess, socket, tempfile

##############################################################################
# import the modules we need to build the pipeline
from glue import pipeline
from ligo import segments
import glue.ligolw.utils as ligolw_utils
import glue.ligolw.utils.segments as ligolw_segments
from gstlal import datasource
from gstlal import dagparts
from lal import series as lalseries
from lal.utils import CacheEntry

#
# Classes for generating reference psds
#

class gstlal_reference_psd_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_reference_psd executable
	"""
	def __init__(self, group, user, executable=dagparts.which('gstlal_reference_psd'), tag_base='gstlal_reference_psd'):
		"""
		Configure a 'vanilla' universe job running executable.

		group and user, when not None, are written as the
		accounting_group and accounting_group_user submit commands.
		tag_base names the submit file (<tag_base>.sub) and prefixes
		the stdout/stderr file names under logs/.
		"""
		universe = 'vanilla'
		self.__prog__ = 'gstlal_reference_psd'
		self.__executable = executable
		self.__universe = universe
		pipeline.CondorDAGJob.__init__(self, universe, executable)
		self.add_condor_cmd('getenv', 'True')
		self.add_condor_cmd('requirements', 'Memory > 1999') #FIXME is this enough?
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# accounting information is optional; skip whichever part is missing
		for command, value in (('accounting_group', group), ('accounting_group_user', user)):
			if value is not None:
				self.add_condor_cmd(command, value)
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_median_psd_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_median_of_psds executable
	"""
	def __init__(self, group, user, executable=dagparts.which('gstlal_median_of_psds'), tag_base='gstlal_median_of_psds'):
		"""
		Configure a 'vanilla' universe job running executable.

		group and user, when not None, are written as the
		accounting_group and accounting_group_user submit commands.
		tag_base names the submit file (<tag_base>.sub) and prefixes
		the stdout/stderr file names under logs/.
		"""
		universe = 'vanilla'
		self.__prog__ = 'gstlal_median_of_psds'
		self.__executable = executable
		self.__universe = universe
		pipeline.CondorDAGJob.__init__(self, universe, executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# accounting information is optional; skip whichever part is missing
		for command, value in (('accounting_group', group), ('accounting_group_user', user)):
			if value is not None:
				self.add_condor_cmd(command, value)
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_smooth_reference_psd_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_psd_polyfit executable
	"""
	def __init__(self, group, user, executable=dagparts.which('gstlal_psd_polyfit'), tag_base='gstlal_psd_polyfit'):
		"""
		Configure a 'vanilla' universe job running executable.

		group and user, when not None, are written as the
		accounting_group and accounting_group_user submit commands.
		tag_base names the submit file (<tag_base>.sub) and prefixes
		the stdout/stderr file names under logs/.
		"""
		universe = 'vanilla'
		self.__prog__ = 'gstlal_psd_polyfit'
		self.__executable = executable
		self.__universe = universe
		pipeline.CondorDAGJob.__init__(self, universe, executable)
		self.add_condor_cmd('getenv', 'True')
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# accounting information is optional; skip whichever part is missing
		for command, value in (('accounting_group', group), ('accounting_group_user', user)):
			if value is not None:
				self.add_condor_cmd(command, value)
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_reference_psd_node(pipeline.CondorDAGNode):
	"""
	DAG node running a gstlal_reference_psd job on one stretch of frame
	data for a single instrument
	"""
	def __init__(self, job, dag, frame_cache, gps_start_time, gps_end_time, instrument, channel, injections=None, p_node=[]):
		"""
		Set the node's options, register its output PSD file in
		dag.output_cache, attach every node in p_node as a parent and
		add the node to dag.

		The output file name is built from the current working
		directory, instrument and the two GPS times, and is kept in
		self.output_name.  injections is only passed on when truthy.
		"""
		pipeline.CondorDAGNode.__init__(self, job)
		for option, value in (
			("frame-cache", frame_cache),
			("gps-start-time", gps_start_time),
			("gps-end-time", gps_end_time),
			("data-source", "frames"),
		):
			self.add_var_opt(option, value)
		self.add_var_arg("--channel-name=%s=%s" % (instrument, channel))
		if injections:
			self.add_var_opt("injections", injections)
		psd_path = '%s/%s-%d-%d-reference_psd.xml.gz' % (os.getcwd(), instrument, gps_start_time, gps_end_time)
		self.output_name = psd_path
		self.add_var_opt("write-psd", psd_path)
		span = segments.segment(gps_start_time, gps_end_time)
		psd_url = "file://localhost/%s" % (psd_path,)
		dag.output_cache.append(CacheEntry(instrument, "-", span, psd_url))
		for parent in p_node:
			self.add_parent(parent)
		dag.add_node(self)


class gstlal_smooth_reference_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_smooth_reference_psd node
	"""
	def __init__(self, job, dag, instrument, input_psd, p_node=[], low_fit_freq = 10):
		"""
		Set the node up to read input_psd and write a smoothed copy,
		attach every node in p_node as a parent and add the node to dag.

		The output file name is input_psd with 'reference_psd' replaced
		by 'smoothed_reference_psd' in the file name; it is kept in
		self.output_name.  low_fit_freq is the value passed as
		--low-fit-freq (default 10, the value that used to be hard-coded).
		instrument is accepted for interface compatibility but is not
		used here.  p_node is only iterated, never modified.
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		# Rename the file only.  Running the replacement over the whole
		# path would also rewrite any parent directory whose name contains
		# 'reference_psd', pointing the output at a different directory.
		directory, filename = os.path.split(input_psd)
		output_name = self.output_name = os.path.join(directory, filename.replace('reference_psd', 'smoothed_reference_psd'))
		self.add_var_arg(input_psd)
		self.add_var_opt("output", output_name)
		self.add_var_opt("low-fit-freq", low_fit_freq)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


class gstlal_median_psd_node(pipeline.CondorDAGNode):
	"""
	A gstlal_median_psd node
	"""
	def __init__(self, job, dag, input_psds, output, p_node=[]):
		"""
		Set the node up to combine the files in input_psds into output,
		attach every node in p_node as a parent and add the node to dag.

		output is passed as --output-name and kept in self.output_name;
		each entry of input_psds is added as a file argument.  p_node is
		only iterated, never modified.
		"""
		pipeline.CondorDAGNode.__init__(self,job)
		# the caller chooses the output name, so no working-directory
		# lookup is needed here
		output_name = self.output_name = output
		self.add_var_opt("output-name", output_name)
		for psd in input_psds:
			self.add_file_arg(psd)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


#
# classes for generating recolored frames
#

class gstlal_fake_frames_job(pipeline.CondorDAGJob):
	"""
	Condor job description for the gstlal_fake_frames executable
	"""
	def __init__(self, group, user, executable=dagparts.which('gstlal_fake_frames'), tag_base='gstlal_fake_frames'):
		"""
		Configure a 'vanilla' universe job running executable.

		group and user, when not None, are written as the
		accounting_group and accounting_group_user submit commands.
		tag_base names the submit file (<tag_base>.sub) and prefixes
		the stdout/stderr file names under logs/.
		"""
		universe = 'vanilla'
		self.__prog__ = 'gstlal_fake_frames'
		self.__executable = executable
		self.__universe = universe
		pipeline.CondorDAGJob.__init__(self, universe, executable)
		self.add_condor_cmd('getenv', 'True')
		self.add_condor_cmd('requirements', 'Memory > 1999') #FIXME is this enough?
		self.tag_base = tag_base
		self.add_condor_cmd('environment', "KMP_LIBRARY=serial;MKL_SERIAL=yes")
		self.set_sub_file(tag_base + '.sub')
		# accounting information is optional; skip whichever part is missing
		for command, value in (('accounting_group', group), ('accounting_group_user', user)):
			if value is not None:
				self.add_condor_cmd(command, value)
		log_stem = 'logs/' + tag_base + '-$(macroid)-$(process)'
		self.set_stdout_file(log_stem + '.out')
		self.set_stderr_file(log_stem + '.err')


class gstlal_fake_frames_node(pipeline.CondorDAGNode):
	"""
	A gstlal_fake_frames node
	"""
	def __init__(self, job, dag, frame_cache, gps_start_time, gps_end_time, channel, reference_psd, color_psd, sample_rate, injections=None, output_channel_name = None, duration = 4096, output_path = None, frame_type = None, shift = None, whiten_track_psd = False, frames_per_file = 1, p_node=[], instrument = None):
		"""
		Set the node's options for one stretch of frame data, attach
		every node in p_node as a parent and add the node to dag.

		reference_psd is passed as --whiten-reference-psd and, when
		whiten_track_psd is true, also as --whiten-track-psd.
		injections and output_path are only passed on when not None,
		shift only when truthy.  p_node is only iterated, never
		modified.

		instrument names the detector used in --channel-name.  When it
		is left as None the module-level variable ``instrument`` is
		used instead, which is what this class silently relied on
		before the keyword existed; ValueError is raised if neither is
		available.
		"""

		pipeline.CondorDAGNode.__init__(self,job)
		if instrument is None:
			# Backward compatibility: existing callers do not pass the
			# instrument and depend on a module-level loop variable.
			try:
				instrument = globals()["instrument"]
			except KeyError:
				raise ValueError("instrument must be given to build the --channel-name argument")
		self.add_var_opt("frame-cache", frame_cache)
		self.add_var_opt("gps-start-time",gps_start_time)
		self.add_var_opt("gps-end-time",gps_end_time)
		self.add_var_opt("data-source", "frames")
		self.add_var_arg("--channel-name=%s=%s" % (instrument, channel))
		self.add_var_opt("whiten-reference-psd",reference_psd)
		self.add_var_opt("color-psd", color_psd)
		self.add_var_opt("sample-rate", sample_rate)
		if injections is not None:
			self.add_var_opt("injections", injections)
		self.add_var_opt("output-channel-name", output_channel_name)
		self.add_var_opt("frame-duration", duration)
		if output_path is not None:
			self.add_var_opt("output-path", output_path)
		self.add_var_opt("frame-type", frame_type)
		if whiten_track_psd:
			self.add_var_opt("whiten-track-psd",reference_psd)
		if shift:
			self.add_var_opt("shift", shift)
		self.add_var_opt("frames-per-file", frames_per_file)
		for p in p_node:
			self.add_parent(p)
		dag.add_node(self)


def choosesegs(seglists, min_segment_length):
	"""
	Remove, in place, every segment of seglists whose length is not
	strictly greater than min_segment_length.

	seglists maps instrument to an iterable of segments.  Each value is
	replaced by a new segments.segmentlist holding copies of the
	segments that were kept.  Returns None.
	"""
	# .items() instead of the Python-2-only .iteritems().  Only existing
	# keys are reassigned, so the mapping never changes size while it is
	# being iterated.
	for instrument, seglist in seglists.items():
		seglists[instrument] = segments.segmentlist([segments.segment(seg) for seg in seglist if abs(seg) > min_segment_length])


def parse_command_line():
	"""
	Parse sys.argv and return
	(options, inchannels, outchannels, outpaths, frametypes, filenames).

	inchannels, outchannels, frametypes and outpaths are the instrument
	keyed dictionaries built from the repeated IFO=VALUE options
	--channel-name, --output-channel-name, --frame-type and
	--output-path; filenames are the positional arguments.

	Raises ValueError when a required option is missing, or when
	--frame-type, --channel-name and --output-channel-name do not name
	the same set of instruments.
	"""
	parser = OptionParser(description = __doc__)

	parser.add_option("--frame-cache", metavar = "filename", help = "Set the name of the LAL cache listing the LIGO-Virgo .gwf frame files (optional)")
	parser.add_option("--injections", metavar = "filename", help = "Set the name of the LIGO light-weight XML file from which to load injections (optional).")
	parser.add_option("--channel-name", metavar = "name", action = "append", help = "Set the name of the channels to process.  Can be given multiple times as --channel-name=IFO=CHANNEL-NAME")
	parser.add_option("--frame-segments-file", metavar = "filename", help = "Set the name of the LIGO light-weight XML file from which to load frame segments. Required")
	parser.add_option("--frame-segments-name", metavar = "name", help = "Set the name of the segments to extract from the segment tables. Required")

	parser.add_option("--min-segment-length", metavar = "SECONDS", help = "Set the minimum segment length to process (required)", type="float")
	parser.add_option("--shift", metavar = "NANOSECONDS", help = "Number of nanoseconds to delay (negative) or advance (positive) the time stream", type = "int")
	parser.add_option("--sample-rate", metavar = "HZ", default = 16384, type = "int", help = "Sample rate at which to generate the data, should be less than or equal to the sample rate of the measured psds provided, default = 16384 Hz")
	parser.add_option("--whiten-type", metavar="name", help = "Whiten whatever data is coming out of datasource either from the data or from a fixed reference psd if a file is given.  Options are psdperseg|medianofpsdperseg|FILE")
	parser.add_option("--whiten-track-psd", action = "store_true", help = "Calculate PSD from input data and track with time.")
	parser.add_option("--color-psd", metavar = "FILE", help = "Set the name of psd xml file to color the data with")
	parser.add_option("--output-path", metavar = "name", action = "append", help = "Set the instrument dependent output path for frames, defaults to current working directory. eg H1=/path/to/H1/frames. Can be given more than once.")
	parser.add_option("--output-channel-name", metavar = "name", action="append", help = "The name of the channel in the output frames, e.g., --output-channel-name=IFO=CHANNEL-NAME. The default is the same as the channel name. Can be given more than once. Required ")
	parser.add_option("--frame-type", metavar = "name", action = "append", help = "Set the instrument dependent frame type, H1=TYPE. Can be given more than once and is required for each instrument processed.")
	parser.add_option("--frame-duration", metavar = "SECONDS", default = 16, type = "int", help = "Set the duration of the output frames.  The duration of the frame file will be multiplied by --frames-per-file.  Default: 16s")
	parser.add_option("--frames-per-file", metavar = "INT", default = 256, type = "int", help = "Set the number of frames per file.  Default: 256")
	parser.add_option("--accounting-group", metavar = "str", help = "Set the accounting group name, e.g., ligo.dev.o3.cw.directedbinary.production")
	parser.add_option("--accounting-group-user", metavar = "str", help = "Set the accounting group user, e.g., chad.hanna")
	parser.add_option("--verbose", action = "store_true", help = "Be verbose")

	options, filenames = parser.parse_args()

	# --channel-name and --output-channel-name are checked here as well:
	# both are handed to the channel-dictionary helper below and both must
	# name the same instruments as --frame-type, so leaving either out can
	# never succeed.  Report that up front with the other missing options.
	required = ("min_segment_length", "frame_type", "frame_segments_file", "frame_segments_name", "channel_name", "output_channel_name")
	fail = "".join("must provide option %s\n" % (option) for option in required if getattr(options, option) is None)
	if fail:
		raise ValueError(fail)

	inchannels = datasource.channel_dict_from_channel_list(options.channel_name)
	outchannels = datasource.channel_dict_from_channel_list(options.output_channel_name)
	frametypes = datasource.channel_dict_from_channel_list(options.frame_type)
	# --output-path is optional; an "append" option that was never given
	# is left as None by optparse, so hand the helper an empty list then.
	outpaths = datasource.channel_dict_from_channel_list(options.output_path or [])

	if not (set(frametypes) == set(inchannels) == set(outchannels)):
		raise ValueError('--frame-type, --channel-name and --output-channel-name must contain same instruments')

	return options, inchannels, outchannels, outpaths, frametypes, filenames


#
# Build the DAG: optional per-segment PSD measurement / smoothing / median,
# followed by one gstlal_fake_frames node per instrument and segment
#

options, inchannels, outchannels, outpaths, frametypes, filenames = parse_command_line()

# The job classes above point their stdout/stderr files into logs/.  A
# directory that is already there is fine; any other failure is reported
# instead of being silently swallowed.
try:
	os.mkdir("logs")
except OSError:
	if not os.path.isdir("logs"):
		raise

dag = dagparts.CondorDAG("gstlal_fake_frames_pipe")

segments_xmldoc = ligolw_utils.load_filename(options.frame_segments_file, verbose = options.verbose, contenthandler = ligolw_segments.LIGOLWContentHandler)
seglists = ligolw_segments.segmenttable_get_by_name(segments_xmldoc, options.frame_segments_name).coalesce()
choosesegs(seglists, options.min_segment_length)

psdJob = gstlal_reference_psd_job(options.accounting_group, options.accounting_group_user)
smoothJob = gstlal_smooth_reference_psd_job(options.accounting_group, options.accounting_group_user)
medianJob = gstlal_median_psd_job(options.accounting_group, options.accounting_group_user)
colorJob = gstlal_fake_frames_job(options.accounting_group, options.accounting_group_user)

smoothnode = {}
mediannode = {}
# parents of the fake-frames nodes, per instrument
p_node = dict((ifo, []) for ifo in seglists)

if options.whiten_type in ("psdperseg", "medianofpsdperseg"):
	# psd[instrument] ends up as either a {segment: filename} dictionary
	# (psdperseg) or a single filename (medianofpsdperseg)
	psd = {}
	for instrument, seglist in seglists.items():
		smoothnode[instrument] = {}
		psd[instrument] = {}
		for seg in seglist:
			#FIXME if there are segments without frame caches this will barf
			psdnode = gstlal_reference_psd_node(psdJob, dag, options.frame_cache, int(seg[0]), int(seg[1]), instrument, inchannels[instrument], injections=None, p_node=[])
			smoothnode[instrument][seg] = gstlal_smooth_reference_psd_node(smoothJob, dag, instrument, psdnode.output_name,  p_node=[psdnode])
			if options.whiten_type == "psdperseg":
				psd[instrument][seg] = smoothnode[instrument][seg].output_name

		# the median node is built for both whitening types and every
		# fake-frames node of this instrument waits for it
		smoothed = list(smoothnode[instrument].values())
		mediannode[instrument] = gstlal_median_psd_node(medianJob, dag, [v.output_name for v in smoothed], "%s_median_psd.xml.gz" % instrument, p_node=smoothed)
		p_node[instrument] = [mediannode[instrument]]
		if options.whiten_type == "medianofpsdperseg":
			psd[instrument] = mediannode[instrument].output_name

elif options.whiten_type is not None:
	# --whiten-type names a PSD file.  (This used to read the nonexistent
	# options.whiten_reference_psd.)  Parse the file once here so a bad
	# file is reported now rather than by every job.
	# NOTE(review): assumes read_psd_xmldoc returns an instrument-keyed
	# mapping, as the previous psd[instrument] lookup implied -- confirm.
	psds_in_file = lalseries.read_psd_xmldoc(ligolw_utils.load_filename(options.whiten_type, verbose = options.verbose, contenthandler = lalseries.PSDContentHandler))
	missing = sorted(set(seglists) - set(psds_in_file))
	if missing:
		raise ValueError("no PSD for %s in %s" % (", ".join(missing), options.whiten_type))
	# The jobs take the PSD as a file name, exactly like the output_name
	# strings used in the branch above, so pass the name and not the
	# parsed series.
	psd = dict((ifo, options.whiten_type) for ifo in seglists)
else:
	psd = dict((ifo, None) for ifo in seglists)

# NOTE: the loop variable has to be called "instrument":
# gstlal_fake_frames_node reads that module-level name when it builds its
# --channel-name argument.
for instrument, seglist in seglists.items():
	try:
		output_path = outpaths[instrument]
	except KeyError:
		output_path = None
	for seg in seglist:
		# per-segment dictionary if there is one, otherwise the single
		# per-instrument value (a filename or None)
		try:
			reference_psd = psd[instrument][seg]
		except TypeError:
			reference_psd = psd[instrument]
		gstlal_fake_frames_node(
			colorJob, dag, options.frame_cache, int(seg[0]), int(seg[1]), inchannels[instrument], reference_psd,
			color_psd = options.color_psd,
			sample_rate = options.sample_rate,
			injections = options.injections,
			output_channel_name = outchannels[instrument],
			output_path = output_path,
			duration = options.frame_duration,
			frame_type = frametypes[instrument],
			shift = options.shift,
			whiten_track_psd = options.whiten_track_psd,
			frames_per_file = options.frames_per_file,
			p_node = p_node[instrument]
		)

dag.write_sub_files()
dag.write_dag()
dag.write_script()
dag.write_cache()
