#!/usr/libexec/platform-python -s
#  -*- coding: utf-8 -*-
# *****************************************************************************
# MLZ library of Tango servers
# Copyright (c) 2015-2020 by the authors, see LICENSE
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
# Module authors:
#   Sandra Seger <sandra.seger@frm2.tum.de>
#
# *****************************************************************************

from __future__ import print_function

import sys
import time
import argparse
from os import path

import serial
import PyTango
from PyTango import DevFailed
from serial.serialutil import SerialException

# Add import path for inplace usage
sys.path.insert(0, path.abspath(path.join(path.dirname(__file__), '..')))

# pylint: disable=wrong-import-position
from entangle.lib.toml import Parser
from entangle.lib.pycompat import unescape, str2bytes, bytes2str, list2str


# Fallback value for every option that was neither given on the command line
# nor found in the TOML file.  The list-valued entries are the candidate
# values that get combined into the scan.
DEFAULTS = dict(
    baudrate=[9600, 19200, 38400, 57600, 115200],
    bytesize=[7, 8],
    stopbits=[1, 2],
    parity=['none', 'even', 'odd'],
    flowctrl=['none', 'rtscts', 'xonxoff'],
    timeout=0.2,
    eol='\n',
    tango=False,
    command='*IDN?',
    quick=False,
)

# Every value that is accepted for the list-valued options; the special
# option value 'all' selects the complete list of the respective entry.
VALID_VALUES = dict(
    baudrate=[
        110, 300, 600, 1200, 2400, 4800, 9600, 14400,
        19200, 38400, 57600, 115200, 128000, 256000,
    ],
    bytesize=[5, 7, 8],
    stopbits=[1, 1.5, 2],
    parity=['none', 'even', 'odd', 'mark', 'space'],
    flowctrl=['none', 'rtscts', 'xonxoff', 'dsrdtr'],
)

# Maps the parity names used on the command line to the pyserial constants.
PARITY = dict(
    none=serial.PARITY_NONE,
    even=serial.PARITY_EVEN,
    odd=serial.PARITY_ODD,
    mark=serial.PARITY_MARK,
    space=serial.PARITY_SPACE,
)


class SerialArgumentParser(object):
    """Collect and validate the scan settings.

    Settings are taken, in this order of precedence, from the command line,
    from the ``defaults`` table of the TOML file given with ``--file`` and
    from the module-level ``DEFAULTS``.

    Attributes:
        parser: the underlying ``argparse.ArgumentParser``.
        args: the namespace returned by ``parse_args()``.
        arguments: ``vars(args)``; completed and normalised in place by
            ``validate_values()``.
        valid_values: result of ``validate_values()``; False means that a
            message was printed and ``arguments`` must not be used.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(
            description='Scan port with different combinations of serial '
            'parameters.')
        self.add_arguments()
        self.args = self.parser.parse_args()
        self.arguments = vars(self.args)
        self.valid_values = self.validate_values()

    def get_arguments(self):
        """Return the dictionary of (completed) arguments."""
        return self.arguments

    def validate_values(self):
        """Fill in missing arguments and check the list-valued ones.

        Every argument whose value is falsy (not given, but also e.g. an
        explicit ``--timeout 0``) is taken from the TOML file if available,
        otherwise from the defaults.  Arguments that have an entry in
        ``VALID_VALUES`` must be lists whose items are valid values; the item
        ``'all'`` selects every valid value.  On success the baudrates and
        bytesizes are converted to ``int`` and the stopbits to ``float``.

        Returns:
            True if all arguments are usable, False otherwise (after
            printing a message describing the problem).
        """
        # Work on a private copy: quick mode must not alter the module-level
        # DEFAULTS for parser instances created later in the same process.
        defaults = dict(DEFAULTS)
        constants = {}
        if self.args.file:
            with open(self.arguments['file']) as fp:
                content = fp.read().lstrip()
            config = Parser(self.arguments['file'], content).parse_doc()
            constants = config.pop('defaults', {})
        if self.args.quick:
            defaults.update({
                'bytesize': [8],
                'stopbits': [1],
                'parity': ['none'],
                'flowctrl': ['none'],
            })
        for arg in self.arguments:
            # set parameters from file
            if not self.arguments[arg] and arg in constants:
                self.arguments[arg] = constants[arg]
            # set default parameters
            if not self.arguments[arg]:
                if arg in defaults:
                    self.arguments[arg] = defaults[arg]
                elif arg == 'file':
                    # the parameter file is optional
                    continue
                else:
                    print('You need to configure %s!' % arg)
                    return False
            # check if value is valid (only list-valued options are checked)
            if arg not in VALID_VALUES:
                continue
            valid = VALID_VALUES[arg]
            values = self.arguments[arg]
            if not isinstance(values, type(valid)):
                print('%s must be %s not %s' %
                      (arg, type(valid), type(values)))
                return False
            for argument in values:
                if argument == 'all':
                    # copy, so that callers never share the module constant
                    self.arguments[arg] = list(valid)
                    break
                if argument in valid:
                    continue
                if arg in ('baudrate', 'bytesize', 'stopbits'):
                    # command line values are strings: compare numerically;
                    # TypeError covers non-numeric objects from the file
                    try:
                        if float(argument) in valid:
                            continue
                    except (TypeError, ValueError):
                        pass
                print('%r is not a valid value for %s. Must be '
                      'one of %s!' % (argument, arg, valid))
                return False
        # Go through float() first: the check above accepts strings such as
        # '8.0' or '9600.0', which a plain int() would reject.
        self.arguments['baudrate'] = [int(float(b))
                                      for b in self.arguments['baudrate']]
        self.arguments['bytesize'] = [int(float(b))
                                      for b in self.arguments['bytesize']]
        self.arguments['stopbits'] = [float(b)
                                      for b in self.arguments['stopbits']]
        return True

    def add_arguments(self):
        """Register all command line options on ``self.parser``."""
        self.parser.add_argument('dev', help='Device path.')
        self.parser.add_argument(
            '--file', help='TOML file to read parameters from.')
        self.parser.add_argument(
            '-t', '--tango', action='store_true',
            help='Device is a Tango device.')
        self.parser.add_argument(
            '-c', '--command', help='Command to send (default: %r).' %
            DEFAULTS['command'])
        self.parser.add_argument(
            '-i', '--timeout', type=float,
            help='Timeout for communication in seconds (default: %s).' %
            DEFAULTS['timeout'])
        self.parser.add_argument(
            '-e', '--eol', help='End-of-line character to use (default: %r).' %
            DEFAULTS['eol'])
        self.parser.add_argument(
            '-q', '--quick', action='store_true',
            help='Quick scan (only default baudrates).')
        self.parser.add_argument(
            '-b', '--baudrate', action='append',
            help='Add a baudrate to test (default: %s).' %
            ', '.join(map(str, DEFAULTS['baudrate'])))
        self.parser.add_argument(
            '-p', '--parity', action='append',
            help='Add a parity to test (default: %s).' %
            ', '.join(DEFAULTS['parity']))
        self.parser.add_argument(
            '-y', '--bytesize', action='append',
            help='Add a bytesize to test (default: %s).' %
            ', '.join(map(str, DEFAULTS['bytesize'])))
        self.parser.add_argument(
            '-s', '--stopbits', action='append',
            help='Add a stopbit size to test (default: %s).' %
            ', '.join(map(str, (DEFAULTS['stopbits']))))
        self.parser.add_argument(
            '-f', '--flowctrl', action='append',
            help='Add a flow control method to test (default: %s).' %
            ', '.join(DEFAULTS['flowctrl']))


class CommunicationDevice(object):
    """Common interface of the devices that can be scanned.

    The base class only implements the request/reply cycle in
    ``communicate()``; the remaining methods do nothing here and are meant
    to be overridden by the concrete device classes.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def set_parameters(self, baudrate, bytesize, parity, stopbits, flowctrl):
        """Apply the given line parameters (no-op in the base class)."""

    def write(self, string):
        """Send ``string`` to the device (no-op in the base class)."""

    def read_one_byte(self):
        """Read a single character (the base class returns None)."""

    def communicate(self, cmd, eol, timeout):
        """Send ``cmd`` followed by ``eol`` and collect the reply.

        One character is read unconditionally; further characters are read
        until the collected reply ends with ``eol`` or ``timeout`` seconds
        have passed since the command was written.

        Returns:
            Everything that was read, which may be an empty or an
            incomplete reply.
        """
        self.write(cmd + eol)
        deadline = time.time() + timeout
        answer = self.read_one_byte()
        while time.time() <= deadline:
            answer += self.read_one_byte()
            # an empty answer never counts as complete
            if answer and answer.endswith(eol):
                break
        return answer


class SerialCommunicationDevice(CommunicationDevice):
    """Device that is accessed directly through a pyserial port object."""

    # flow control names understood by set_parameters()
    _FLOW_CONTROLS = ('none', 'rtscts', 'xonxoff', 'dsrdtr')

    def __init__(self, device_path):
        """Create the ``serial.Serial`` object for ``device_path``."""
        CommunicationDevice.__init__(self)
        self.dev = serial.Serial(device_path)
        # short timeout on the port; the overall timeout is handled in
        # communicate()
        self.dev.timeout = 0.01

    def set_parameters(self, baudrate, bytesize, parity, stopbits, flowctrl):
        """Configure the port with one combination of line parameters.

        ``parity`` is a key of ``PARITY``.  For ``flowctrl`` exactly the
        named method is switched on and the other two are switched off; an
        unknown name leaves the flow control settings untouched.
        """
        port = self.dev
        port.baudrate = baudrate
        port.bytesize = bytesize
        port.parity = PARITY[parity]
        port.stopbits = stopbits
        if flowctrl in self._FLOW_CONTROLS:
            port.rtscts = flowctrl == 'rtscts'
            port.xonxoff = flowctrl == 'xonxoff'
            port.dsrdtr = flowctrl == 'dsrdtr'

    def write(self, string):
        """Convert ``string`` with ``str2bytes`` and write it to the port."""
        data = str2bytes(string)
        self.dev.write(data)

    def read_one_byte(self):
        """Read one byte from the port and return it via ``bytes2str``."""
        raw = self.dev.read(1)
        return bytes2str(raw)


class TangoCommunicationDevice(CommunicationDevice):
    """Device that is accessed through a Tango device proxy."""

    def __init__(self, device_path):
        """Create the proxy and detect whether binary I/O is available.

        Binary I/O is used if the device lists a ``BinaryCommunicate``
        command.
        """
        CommunicationDevice.__init__(self)
        self.dev = PyTango.DeviceProxy(device_path)
        self.binary_io = False
        for info in self.dev.command_list_query():
            if info.cmd_name == 'BinaryCommunicate':
                self.binary_io = True
                break

    def set_parameters(self, baudrate, bytesize, parity, stopbits, flowctrl):
        """Pass one combination of line parameters to ``SetProperties``.

        The values are sent as a flat list of alternating names and
        stringified values.
        """
        settings = (
            ('baudrate', baudrate),
            ('parity', parity),
            ('bytesize', bytesize),
            ('stopbits', stopbits),
            ('flowctrl', flowctrl),
        )
        flat = []
        for key, value in settings:
            flat.extend((key, str(value)))
        self.dev.SetProperties(flat)

    def write(self, string):
        """Send ``string``, as character codes when binary I/O is used."""
        if not self.binary_io:
            self.dev.Write(string)
            return
        self.dev.BinaryWrite(list(map(ord, string)))

    def read_one_byte(self):
        """Read one character using the binary or the text command."""
        if not self.binary_io:
            return self.dev.Read(1)
        return list2str(self.dev.BinaryRead(1))


def format_params(combination):
    """Format the first five items of ``combination`` as one table row.

    The items (baudrate, bytesize, parity, stopbits, flow control) are
    left-aligned and padded with spaces to 9, 3, 7, 4 and 9 characters;
    any further items are ignored.
    """
    column_widths = (9, 3, 7, 4, 9)
    cells = []
    for index, width in enumerate(column_widths):
        cells.append('{0: <{1}}'.format(combination[index], width))
    return ''.join(cells)


def main():
    """Try every parameter combination and report the ones with a reply.

    Returns:
        The exit code of the script: 1 if the arguments are invalid or the
        device could not be created, 0 if at least one combination got a
        reply and 2 if none did.
    """
    parser = SerialArgumentParser()
    if not parser.valid_values:
        return 1
    arguments = parser.get_arguments()
    device_path = arguments['dev']
    if arguments['tango']:
        try:
            dev = TangoCommunicationDevice(device_path)
        except DevFailed:
            print('Could not connect to device: %s' % device_path)
            return 1
    else:
        try:
            dev = SerialCommunicationDevice(device_path)
        except SerialException:
            print('Device %s does not exist.' % device_path)
            return 1

    # command and eol may contain escape sequences such as "\r"
    for key in ('command', 'eol'):
        arguments[key] = unescape(arguments[key])

    # cartesian product of all candidate values, baudrate varying slowest
    combinations = [
        [baud, size, par, stop, flow]
        for baud in arguments['baudrate']
        for size in arguments['bytesize']
        for par in arguments['parity']
        for stop in arguments['stopbits']
        for flow in arguments['flowctrl']
    ]
    total = len(combinations)
    working_combinations = []

    print('You are testing %s with these parameters:' % device_path)
    for arg in ('baudrate', 'bytesize', 'parity', 'stopbits', 'flowctrl'):
        listing = ', '.join(str(value) for value in arguments[arg])
        print('  {0: <10} {1}'.format(arg + ':', listing))
    for arg in ('eol', 'command', 'timeout'):
        print('  {0: <10} {1}'.format(arg + ':', repr(arguments[arg])))
    print('')

    for count, combination in enumerate(combinations, 1):
        dev.set_parameters(*combination)
        progress = '[{0:3}/{1}]  Trying  {2}'.format(
            count, total, format_params(combination))
        # no newline yet: the reply (if any) is appended to the same line
        print(progress, end='')
        sys.stdout.flush()
        reply = dev.communicate(arguments['command'], arguments['eol'],
                                arguments['timeout'])
        if not reply:
            print('')
            continue
        working_combinations.append(combination + [reply])
        print('  -> reply: %r' % reply)

    if not working_combinations:
        print('\nGot no reply!')
        return 2

    print('\nThese combinations got a reply:')
    for combi in working_combinations:
        # item 5 is the reply appended to the five parameters
        print('%s-> %r' % (format_params(combi), combi[5]))
    return 0


if __name__ == '__main__':
    # the return value of main() becomes the exit status of the script
    exit_code = main()
    sys.exit(exit_code)
