#!/usr/bin/python3

# Copyright (c) 2021, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only

"""
A small daemon to monitor and control GNSS devices.
"""

import importlib
import os
import select
import signal
import sys
import systemd.daemon
import vci
import zmq

from vyatta import configd


class VyattaGnssDaemon:
    """
    This daemon is responsible for loading the GNSS plugins, scanning
    for the GNSS devices, and providing a message based service for
    communication.
    """
    # Timeout handed to poll() in main_loop, expressed in milliseconds
    POLL_TIMEOUT = 10000                # 10 seconds
    # Number that register() gives to the next registered device
    instance_count = 0

    # ZMQ endpoints that __init__ binds the PUB and REP sockets to
    PUB_ENDPOINT = "ipc:///tmp/gnssd_pub.socket"
    REP_ENDPOINT = "ipc:///tmp/gnssd_rep.socket"

    def __init__(self):
        """
        Install the SIGUSR1 handler, bind the PUB and REP sockets, load
        the GNSS plugins and apply the current configuration.
        """
        def wake_poller(signum, frame):
            """
            Raise from the signal handler so a blocking poll returns early
            """
            raise InterruptedError()

        # SIGUSR1 is the request for a configuration update
        signal.signal(signal.SIGUSR1, wake_poller)

        self.ctx = zmq.Context.instance()

        # Publisher side
        self.pub_socket = self.ctx.socket(zmq.PUB)
        self.pub_socket.bind(self.PUB_ENDPOINT)

        # Request/reply side; reuse a descriptor passed in by systemd
        # when there is one
        self.rep_socket = self.ctx.socket(zmq.REP)
        inherited_fds = systemd.daemon.listen_fds()
        if inherited_fds:
            self.rep_socket.set(zmq.USE_FD, inherited_fds[0])
        self.rep_socket.bind(self.REP_ENDPOINT)

        # Open up the permissions on the IPC socket file so that
        # others can talk to us
        ipc_scheme = "ipc://"
        if self.REP_ENDPOINT.startswith(ipc_scheme):
            socket_path = self.REP_ENDPOINT[len(ipc_scheme):]
            os.chmod(socket_path, 0o770)

        self.instances = []
        self.load_gnss_modules()
        self.running = True
        self.client = None
        self.update_configurations()

    def register(self, instance):
        """
        Given to plugins to register their device with the daemon.
        """
        self.instances.append(instance)
        instance.set_instance(self.instance_count)
        instance.update_hardware_status()
        self.instance_count += 1

    def load_gnss_modules(self):
        """
        Load and probe using any discovered GNSS plugins.
        """
        def iterate_modules():
            """
            Search all the python library paths for GNSS plugins.
            """
            for path in sys.path:
                mod_path = os.path.join(path, 'vyatta', 'gnss-plugins')
                if not os.path.isdir(mod_path):
                    continue
                for dentry in os.scandir(mod_path):
                    if dentry.is_file() and \
                      (dentry.name.endswith('.py')
                       or dentry.name.endswith('.pyc')):
                        mod_name = 'vyatta.gnss-plugins.' + \
                          os.path.splitext(dentry.name)[0]
                        yield mod_name

        for module_path in iterate_modules():
            class_name = module_path.split('.')[-1]
            try:
                module = importlib.import_module(module_path)
                driver = getattr(module, class_name)
                print('Probing for', class_name, 'devices')
                device = driver()
                if device:
                    device.probe(self.register)
            except:
                print('Failed to load', class_name)

    def get_rep_fd(self):
        """
        Return the value of the ZMQ FD option of the REP socket, i.e. the
        descriptor the main loop polls on.
        """
        rep_fd = self.rep_socket.get(zmq.FD)
        return rep_fd

    def get_status(self):
        """
        Fetch the status from each GNSS device.
        """
        status = []
        for instance in self.instances:
            status.append(instance.get_status())
        return {'data': status, 'result': 'OK'}

    def do_update_hardware_status(self):
        """
        Call the update_hardware_status method for each device.
        """
        for instance in self.instances:
            instance.update_hardware_status()

    def find_instance(self, instance):
        """
        Find the GNSS device for the registered instance.
        """
        for the_instance in self.instances:
            if the_instance.get_instance() == instance:
                return the_instance
        return None

    def do_stop(self, json):
        """
        Stop a GNSS device.
        """
        if 'instance' in json:
            instance = self.find_instance(json['instance'])
            if not instance:
                print('Failed to find GNSS instance')
                return {'data': 'No such GNSS', 'result': 'FAIL'}
            status = instance.stop()
            if not status:
                print('Failed to stop GNSS instance', instance.get_instance())
                return {'data': 'Failed to stop GNSS', 'result': 'FAIL'}
        return {'data': '', 'result': 'OK'}

    def do_start(self, json):
        """
        Stop a GNSS device.
        """
        if 'instance' in json:
            instance = self.find_instance(json['instance'])
            if not instance:
                print('Failed to find GNSS instance')
                return {'data': 'No such GNSS instance', 'result': 'FAIL'}
            status = instance.start()
            if not status:
                print('Failed to start GNSS instance', instance.get_instance())
                return {'data': 'Failed to start GNSS', 'result': 'FAIL'}
        return {'data': '', 'result': 'OK'}

    def do_shutdown(self):
        """
        Clear the running flag so that main_loop exits, and acknowledge
        the request with an OK reply.
        """
        reply = {'data': '', 'result': 'OK'}
        self.running = False
        return reply

    def handle_rep_fd_event(self):
        """
        Handle the command requests over the ZMQ socket.

        Every request currently readable on the REP socket is read and
        answered with exactly one reply.  A request is expected to be a
        JSON object whose 'command' member is one of STATUS, STOP, START
        or SHUTDOWN.  A message that is not valid JSON, or whose JSON
        value is not an object, is answered with a FAIL reply instead of
        raising out of the daemon.
        """
        while self.rep_socket.getsockopt(zmq.EVENTS) & zmq.POLLIN:
            try:
                request = self.rep_socket.recv_json()
            except ValueError:
                # Undecodable payload.  The message itself has been
                # consumed, so a reply must still be sent below.
                request = None

            if not isinstance(request, dict):
                # Covers the decode failure above as well as JSON values
                # such as lists, strings or numbers, on which the
                # 'command' lookup would otherwise raise
                rep_json = {'data': 'malformed request', 'result': 'FAIL'}
            elif 'command' not in request:
                rep_json = {'data': 'missing command', 'result': 'FAIL'}
            else:
                command = request['command']
                if command == 'STATUS':
                    rep_json = self.get_status()
                elif command == 'STOP':
                    rep_json = self.do_stop(request)
                elif command == 'START':
                    rep_json = self.do_start(request)
                elif command == 'SHUTDOWN':
                    rep_json = self.do_shutdown()
                else:
                    rep_json = {'data': 'unknown command', 'result': 'FAIL'}

            self.rep_socket.send_json(rep_json)

    def main_loop(self, file_evmask_tuple_list):
        """
        Read forever from the ZMQ socket. However, if we timeout because
        the socket is idle, update the hardware status on the GNSS devices.
        """
        poller = select.poll()

        for (fd, evmask) in file_evmask_tuple_list:
            poller.register(fd, evmask)

        while self.running:
            try:
                events = poller.poll(self.POLL_TIMEOUT)
            except InterruptedError:
                events = []

            for (fd, _) in events:
                if fd == self.get_rep_fd():
                    self.handle_rep_fd_event()
                else:
                    raise Exception("Unexpected event on fd {}".format(fd))
            if not events:
                self.update_configurations()
                self.do_update_hardware_status()

    def update_configurations(self):
        SERVICE = 'vyatta-services-v1:service'
        SERVICE_GNSS = 'vyatta-service-gnss-v1:gnss'

        def get_config(instance_number, instances):
            for instance in instances:
                if instance['instance-number'] == instance_number:
                    return instance
            return None

        try:
            if self.client is None:
                self.client = vci.Client()
            state = self.client.state_by_model('net.vyatta.vci.service.gnss.v1')
            configs = state[SERVICE][SERVICE_GNSS]['instance']
        except vci.Exception:
            self.client = None
            print('Unable to connect to VCI service')
            return
        except KeyError:
            print('Unable to find configuration state in VCI')
            return

        for the_instance in self.instances:
            antenna_delay = the_instance.get_antenna_delay()
            instance_number = the_instance.get_instance()

            config = get_config(instance_number, configs)
            if not config:
                continue

            antenna_delay = config['antenna-delay']
            if the_instance.get_antenna_delay() != antenna_delay:
                the_instance.set_antenna_delay(antenna_delay)


if __name__ == "__main__":
    # Build the daemon, then poll its REP socket descriptor until told
    # to shut down
    gnssd = VyattaGnssDaemon()
    watched = [(gnssd.get_rep_fd(), select.POLLIN)]
    gnssd.main_loop(watched)
