#!/usr/bin/python
#
# Copyright (C) 2012  Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


### Plot the horizon distance as a function of time from PSDs computed from
### gstlal_reference_psd


import sys
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot
import numpy
import lal
import lal.series
from glue.ligolw import utils as ligolw_utils
from gstlal import plotutil
from gstlal import reference_psd

# Command line: <output image name> <psd file> [<psd file> ...]
#
# An output name and at least one PSD file are required.  Passing the message
# to sys.exit() writes it to stderr and terminates with a non-zero status, so
# that a calling shell script or workflow can tell the invocation failed.
if len(sys.argv) < 3:
	sys.exit("USAGE gstlal_plot_psd output_name psd_file1 psd_file2 ...")

outname = sys.argv[1]

# Per-instrument lists, filled in lock-step:  times[ifo][i] is the integer
# epoch of the i-th PSD found for that instrument and horizons[ifo][i] is the
# horizon distance computed from that same PSD.
horizons = {}
times = {}
for f in sys.argv[2:]:
	psds = lal.series.read_psd_xmldoc(ligolw_utils.load_filename(f, verbose = True, contenthandler = lal.series.PSDContentHandler))
	for ifo, psd in psds.items():
		# instruments without a PSD in this file are skipped
		if psd is None:
			continue
		times.setdefault(ifo, []).append(int(psd.epoch))
		# NOTE(review):  the arguments are presumably f_min = 10 Hz,
		# f_max = 2048 Hz, the PSD's frequency resolution and
		# component masses of 1.4 + 1.4 solar masses, evaluated at an
		# SNR of 8 -- confirm against reference_psd.HorizonDistance.
		# Only the first element of the result is kept.
		horizons.setdefault(ifo, []).append(reference_psd.HorizonDistance(10., 2048., psd.deltaF, 1.4, 1.4)(psd, 8.)[0])

# Without at least one PSD there is nothing to plot, and the min() calls
# below would fail with an unhelpful "empty sequence" error.
if not horizons:
	sys.exit("no PSDs found in input files: %s" % ", ".join(sys.argv[2:]))

pyplot.figure(figsize=(12,4))

#
# left panel:  horizon distance vs. time
#

pyplot.subplot(121)
minh, maxh = (float("inf"), 0)
# Earliest epoch over all instruments, used as the origin of the time axis.
# Every list was created by an append, so none of them can be empty.
mint = min(min(t) for t in times.values())
for ifo in horizons:
	# time axis is in units of 1000 s relative to the earliest epoch
	pyplot.semilogy((numpy.array(times[ifo]) - mint) / 1000., horizons[ifo], 'x', color = plotutil.colour_from_instruments([ifo]), label = ifo)
	# track the global range so both instruments share histogram bins
	maxh = max(maxh, max(horizons[ifo]))
	minh = min(minh, min(horizons[ifo]))
# no legend on this panel;  the histogram panel's legend carries the same
# instrument labels and colours
pyplot.xlabel('Time (ks) from GPS %d' % mint)
pyplot.ylabel('Mpc')
pyplot.grid()

#
# right panel:  histogram of horizon distances
#

pyplot.subplot(122)
# 25 bin edges (24 bins) spanning the range seen across all instruments
binvec = numpy.linspace(minh, maxh, 25)
for ifo in horizons:
	pyplot.hist(horizons[ifo], binvec, color = plotutil.colour_from_instruments([ifo]), alpha = 0.5, label = ifo)
pyplot.legend()
pyplot.xlabel("Mpc")
pyplot.ylabel("Count")
pyplot.savefig(outname)
