#!/usr/bin/python3 -s
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2019)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Check whether a list of channel records has changed between two reference
times
"""

import re
import os.path

import numpy

import gwdatafind
from gwpy.table import Table
from gwpy.io import gwf as io_gwf

from gwdetchar import cli
from gwdetchar.io.datafind import get_data

__author__ = 'Alex Urban <alexander.urban@ligo.org>'
__credits__ = 'Andrew Lundgren <andrew.lundgren@ligo.org>, ' \
              'Joshua Smith <joshua.smith@ligo.org>, ' \
              'Duncan Macleod <duncan.macleod@ligo.org>'


# -- parse command line -------------------------------------------------------

# Build the command-line interface: GPS start/stop, IFO, frametype and
# nproc come from the shared gwdetchar helpers, the rest are specific to
# this script.
parser = cli.create_parser(description=__doc__)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
cli.add_frametype_option(parser,
                         help='the frametype name, defaults to second trends '
                              'for the selected interferometer')
cli.add_nproc_option(parser)
parser.add_argument('-o', '--output', default='changes.csv',
                    help='Path to output data file, default: %(default)s')
parser.add_argument('-c', '--channels', default=None, required=False,
                    help="file containing columnar list of channels to "
                         "process, default is to find all relevant channels "
                         "from frames")
# action='append' collects one list of patterns per occurrence of the
# option, so that repeating -s/--search accumulates patterns (as the help
# text promises) instead of silently keeping only the last occurrence
parser.add_argument('-s', '--search', nargs='*', action='append', default=[],
                    help='process channels matching these regex patterns, '
                         'can be given multiple times, default is to analyze '
                         'all relevant channels from frames')
parser.add_argument('-p', '--preview', default=10, type=int,
                    help='Time (seconds) to test that channel is typically '
                         'kept constant, default: %(default)s')

args = parser.parse_args()

# flatten the per-occurrence groups back into a single flat list of
# patterns, which is the shape the rest of the script works with
args.search = [pattern for group in args.search for pattern in group]

# set up logger, named after this script
logger = cli.logger(name=os.path.basename(__file__))

# get IFO and frametype
ifo = args.ifo.upper()
obs = ifo[0]  # first character of the IFO name, passed to gwdatafind
# fall back to '<IFO>_T' (second trends, per the --frametype help text)
frametype = args.frametype or '{}_T'.format(ifo)
# always look at a window of at least one second before the start time
preview_time = max(1, args.preview)

# get paths to frame files: one set covering the preview window that ends
# at the start time, one covering the first second after the end time
cache1 = gwdatafind.find_urls(obs, frametype, args.gpsstart-preview_time,
                              args.gpsstart)
cache2 = gwdatafind.find_urls(obs, frametype, args.gpsend, args.gpsend+1)

# fail early with a descriptive message if either lookup came back empty;
# the rest of the script indexes into both caches and would otherwise die
# with an uninformative IndexError
missing = [str(gps) for gps, cache in ((args.gpsstart, cache1),
                                       (args.gpsend, cache2)) if not cache]
if missing:
    raise RuntimeError(
        "No '{}' frame files found for observatory '{}' near GPS "
        "time(s): {}".format(frametype, obs, ', '.join(missing)))

# work out which channels to analyze
logger.info('Determining channels to analyze')

# a channel can only be compared if it is recorded both in the last frame
# file of the initial cache and in the first frame file of the final cache
names_before = set(io_gwf.iter_channel_names(cache1[-1]))
names_after = set(io_gwf.iter_channel_names(cache2[0]))
available = names_before & names_after

if not args.channels:
    # no channel file given: keep every available channel whose name ends
    # in '.mean'
    channels = [name for name in available if name.endswith('.mean')]
else:
    # channel file given: keep the listed channels that are available
    requested = set(numpy.loadtxt(args.channels, dtype=str, ndmin=1))
    channels = list(requested & available)

# optionally narrow the selection to channels matching any search pattern
if args.search:
    pattern = '({})'.format('|'.join(args.search))
    re_requested = re.compile(pattern)
    channels = list(filter(re_requested.search, channels))

logger.debug('Found {} channels in frames'.format(len(channels)))

# read the selected channels from the frame files found earlier

# initial data: the preview window that ends at the start time
initial_start = args.gpsstart - preview_time
initial_label = 'Reading initial data:'.rjust(30)
data1 = get_data(channels, start=initial_start, end=args.gpsstart,
                 source=cache1, nproc=args.nproc, verbose=initial_label)

# final data: the one second that begins at the end time
final_end = args.gpsend + 1
final_label = 'Reading final data:'.rjust(30)
data2 = get_data(channels, start=args.gpsend, end=final_end,
                 source=cache2, nproc=args.nproc, verbose=final_label)

logger.debug('Analyzing {} channels'.format(len(data1)))

# output columns, filled in lockstep: one entry per changed channel
changes, value1, value2, diff = [], [], [], []

# A channel is reported when its initial samples are all identical (it was
# being held constant) and the first final sample differs from the last
# initial sample.
for name in data1:
    before = data1[name].value
    after = data2[name].value
    # the comparison of end-points is only evaluated for channels that were
    # constant throughout the initial window
    if numpy.all(numpy.diff(before) == 0) and before[-1] != after[0]:
        changes.append(name)
        value1.append(before[-1])
        value2.append(after[0])
        diff.append(after[0] - before[-1])

# collect the results into a table, one row per changed channel
logger.debug('Analysis complete')
columns = [changes, value1, value2, diff]
headers = ('channel', 'initial_value', 'final_value', 'difference')
table = Table(columns, names=headers)

# report the results: summary line through the logger, table on stdout
summary = ('The following {} channels record a state change between '
           '{} and {}:\n\n')
logger.info(summary.format(len(changes), args.gpsstart, args.gpsend))
print(table)
print('\n\n')

# write the table to the requested path, replacing any existing file
table.write(args.output, overwrite=True)
logger.info('Output written to {}'.format(args.output))
