#!/usr/libexec/platform-python
"""Cluster tools
Usage:
  cluster_tools.py [-hv] [--register]
  cluster_tools.py [-hv] [--update]
  cluster_tools.py [-hv] [--deregister]
  cluster_tools.py --version
Options:
  -h --help            show this help message and exit
  --register           register with the cluster
  --update             check for new nodes to register with
  --deregister         de-register from the cluster
  --version            show version and exit
  -v --verbose         print status messages
"""


import socket
import sys
import os
import re
import time
import logging
from pathlib import Path
import subprocess as sp
import paramiko
from paramiko.client import SSHClient
from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import NoValidConnectionsError
from docopt import docopt
from contextlib import contextmanager
from typing import List
import warnings

# Matches an "ipc_uuid = <value>" line at the start of a merlin config line;
# group 1 captures the value.
UUID_PATTERN = re.compile(r"ipc_uuid *= *(.*)$")

# Name of the log level, read from the LOG_LEVEL environment variable
# (defaults to INFO) and upper-cased before being handed to the logger.
LOGLEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()

# Module logger; records are formatted and written to stdout.
logger = logging.getLogger("cluster_tools")
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s %(name)-13s %(levelname)-8s %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(level=LOGLEVEL)

# Suppress warnings from paramiko on EL7
warnings.filterwarnings("ignore", message="signer and verifier have been deprecated.*", module="paramiko")

class SSHConnectionException(Exception):
    """Signals that an SSH connection to a node could not be established."""


class SSHCommandException(Exception):
    """Signals that a command run over SSH did not complete successfully."""


class SSHTimeoutException(Exception):
    """Signals that an SSH action timed out."""


class RegistrationFailure(Exception):
    """Signals that registration failed and the caller should bail out."""


class DeRegistrationFailure(Exception):
    """Signals that de-registration failed and the caller should bail out."""


class UpdateFailure(Exception):
    """Signals that updating the cluster configuration failed."""


class ShellExecutionFailure(Exception):
    """Signals that a local shell command could not be executed successfully."""


class LockAcquisitionFailure(Exception):
    """Signals that a lock could not be acquired."""


def execute_shell_cmd(cmd: list) -> str:
    """Runs a local command and returns its standard output

    Args:
        cmd: The command and its arguments as a list of strings, e.g.
            ["mon", "node", "show"]. It is executed directly, not via a shell.
    Raises:
        ShellExecutionFailure: if the command could not be started or exited
            with a non-zero status
    Returns:
        The standard output of the command, decoded as UTF-8
    """
    try:
        completed = sp.run(cmd, stdout=sp.PIPE, stderr=sp.PIPE, check=True)
    except sp.CalledProcessError as ex:
        # With check=True the result of sp.run is never assigned on failure;
        # the captured output is only reachable through the exception.
        captured_out = (ex.stdout or b"").decode("utf-8", errors="replace")
        captured_err = (ex.stderr or b"").decode("utf-8", errors="replace")
        logger.error(f"Shell error: {captured_out} {captured_err}")
        raise ShellExecutionFailure from ex
    except OSError as ex:
        # E.g. the executable does not exist or is not executable.
        logger.error(f"Shell error: could not execute {cmd}: {ex}")
        raise ShellExecutionFailure from ex
    return completed.stdout.decode("utf-8")


def get_uuid() -> str:
    """Returns the uuid stored in the merlin config

    The first line matching UUID_PATTERN wins. When no line matches, a new
    uuid is generated, appended to the merlin config and returned.
    """
    with open(MERLIN_CONFIG, "r") as config_file:
        candidates = (UUID_PATTERN.match(entry) for entry in config_file)
        first_hit = next((hit for hit in candidates if hit), None)
    if first_hit is not None:
        return first_hit.group(1)

    # Nothing configured yet: create a uuid and persist it
    fresh_uuid = generate_uuid()
    set_uuid(fresh_uuid)
    return fresh_uuid


def set_uuid(uuid: str):
    """Appends an ipc_uuid entry holding the given uuid to the merlin config

    Args:
        uuid: The uuid to store
    """
    with open(MERLIN_CONFIG, "a") as config_file:
        config_file.write(f"ipc_uuid = {uuid}\n")


def generate_uuid() -> str:
    """Creates a new UUID by running "mon id generate"

    Exits the process with status 1 if the command fails.

    Returns:
        The command output with surrounding whitespace removed
    """
    try:
        raw_uuid = execute_shell_cmd(["mon", "id", "generate"])
    except ShellExecutionFailure:
        logger.error("Failed to generate UUID")
        sys.exit(1)
    else:
        # The command output ends with a newline
        return raw_uuid.strip()


def parse_sync_files(sync_files: str) -> list:
    """Splits a comma separated string of file names into a list

    Surrounding whitespace is stripped from every entry. Empty entries, as
    produced by trailing or doubled commas, are dropped so they never end up
    in the generated sync configuration.

    Args:
        sync_files: Comma separated file names, may be empty
    Returns:
        The file names in their original order, or an empty list
    """
    if not sync_files:
        return []
    stripped = (entry.strip() for entry in sync_files.split(","))
    return [entry for entry in stripped if entry]


def parse_hostgroups(hostgroups: str) -> List[str]:
    """Splits a comma separated string of hostgroup names into a list

    Args:
        hostgroups: Comma separated hostgroup names
    Returns:
        The names in their original order, each stripped of whitespace
    """
    return list(map(str.strip, hostgroups.split(",")))


# Location of the merlin configuration file
MERLIN_CONFIG = "/etc/merlin/merlin.conf"
# Home directory of the current user. Falls back to expanduser when $HOME is
# unset or empty so the path concatenations below always operate on a string.
HOME = os.getenv("HOME") or os.path.expanduser("~")
KNOWN_HOSTS = HOME + "/.ssh/known_hosts"
PRIVATE_KEY = HOME + "/.ssh/id_rsa"
# Directory holding the local encryption key pair and the public keys of peers
DEFAULT_ENCRYPTION_KEYS_LOCATION = Path("/etc/merlin")
# Reads the merlin config; generates and stores a uuid if none is set
UUID = get_uuid()
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
# NOTE: when taken from the environment these two are strings, while the
# defaults are ints.
MASTER_PORT = os.environ.get("MASTER_PORT", 15551)
POLLER_SSH_PORT = os.environ.get("POLLER_SSH_PORT", 22)
FILES_TO_SYNC = parse_sync_files(os.environ.get("FILES_TO_SYNC", ""))

# Required settings; the process exits with status 1 if one is missing or
# LOCK_TIMEOUT is not an integer.
try:
    POLLER_NAME = os.environ["POLLER_NAME"]
    MASTER_ADDRESS = os.environ["MASTER_ADDRESS"]
    MASTER_NAME = os.environ["MASTER_NAME"]
    POLLER_ADDRESS = os.environ["POLLER_ADDRESS"]
    POLLER_HOSTGROUPS = parse_hostgroups(os.environ["POLLER_HOSTGROUPS"])
    LOCK_TIMEOUT = int(os.environ.get("LOCK_TIMEOUT", 300))
except KeyError as err:
    logger.error(f"Required environment variable {err} was not set, exiting")
    sys.exit(1)
except ValueError as err:
    logger.error(f"Conversion error, full error: {err}")
    sys.exit(1)


def add_cluster_update():
    """Adds a cluster_update entry to merlin.conf

    Does nothing if any line of the config already mentions cluster_update.
    Otherwise sed is used to append the entry after each line matching
    "module {".

    Raises:
        RegistrationFailure: if the sed command fails
    """
    with open(MERLIN_CONFIG, "r") as merlin_config:
        if any("cluster_update" in line for line in merlin_config):
            logger.debug("cluster_update already set")
            return

    # sed receives: /module {/a\<TAB>cluster_update = \/usr\/bin\/... --update
    # Every backslash meant for sed is escaped explicitly; a bare "\/" is an
    # invalid Python escape sequence.
    sed_expression = (
        "/module {/a\\\tcluster_update = "
        "\\/usr\\/bin\\/merlin_cluster_tools --update"
    )
    try:
        execute_shell_cmd(["sed", "-i", sed_expression, MERLIN_CONFIG])
    except ShellExecutionFailure as ex:
        logger.error("Failed to write cluster_update to merlin conf")
        raise RegistrationFailure from ex


@contextmanager
def setup_ssh_connection(hostname: str, port: int = 22, timeout: float = 10.0):
    """Context manager providing an SSH connection to a node

    Authentication uses the private key at PRIVATE_KEY, so this node's
    public key must already be present in the remote node's authorized_keys.
    Unknown remote host keys are accepted automatically (AutoAddPolicy).
    The known_hosts file is created if it does not exist yet.

    The connection is closed when the with-block is left.

    Args:
        hostname: Hostname of the node
        port: SSH port of the node, default: 22
        timeout: Timeout in seconds for setting up the connection
    Raises:
        SSHConnectionException: on socket.timeout or NoValidConnectionsError,
            whether raised while connecting or inside the with-block
    Yields:
        The connected SSHClient
    """
    ssh = SSHClient()
    try:
        ssh.load_host_keys(KNOWN_HOSTS)
    except FileNotFoundError:
        logger.debug(f"{KNOWN_HOSTS} not found, creating.")
        # Create an empty known_hosts file and retry
        with open(KNOWN_HOSTS, "w"):
            pass
        ssh.load_host_keys(KNOWN_HOSTS)
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    connect_options = dict(
        hostname=hostname, key_filename=PRIVATE_KEY, port=port, timeout=timeout
    )
    try:
        ssh.connect(**connect_options)
        yield ssh
    except socket.timeout as ex:
        logger.error(
            f"Failed to establish connection to {hostname}, connection timed out."
        )
        raise SSHConnectionException from ex
    except NoValidConnectionsError as ex:
        logger.error(f"Failed to establish connection to {hostname}.")
        raise SSHConnectionException from ex
    finally:
        ssh.close()


def execute_ssh_command(client: SSHClient, cmd: str) -> List[str]:
    """Executes a command over SSH

    Args:
        client: SSH-connection
        cmd: Command to execute
    Raises:
        SSHCommandException: if the command exits with a non-zero status
    Returns:
        The lines the command wrote to stdout, each stripped of surrounding
        whitespace
    """
    _, stdout, stderr = client.exec_command(cmd)
    # Read the output before asking for the exit status: paramiko documents
    # that recv_exit_status can hang when a command produces more output than
    # fits in the channel window and nothing has been read yet.
    out_lines = stdout.readlines()
    err_lines = stderr.readlines()
    exit_code = stdout.channel.recv_exit_status()
    if exit_code != 0:
        logger.error(f"SSH error: {' '.join(out_lines + err_lines)}")
        raise SSHCommandException(f"{cmd!r} exited with status {exit_code}")
    # Strip whitespace and newlines
    return [item.strip() for item in out_lines]


def add_master(master_name, master_address, poller_name, master_port=15551):
    """Appends a config entry for a master node to the merlin config

    The entry enables encryption using the master's public key stored under
    DEFAULT_ENCRYPTION_KEYS_LOCATION, lists FILES_TO_SYNC in its sync block
    and configures the object config fetch command.

    Args:
        master_name: Name of the new master
        master_address: Address of the new master
        poller_name: Name of this poller
        master_port: Port which merlin communicates over, default 15551
    """
    logger.info(f"Adding {master_name} as a master")
    sync_indent = "    "
    entry_lines = [
        f"master {master_name} {{",
        f"  address = {master_address}",
        f"  port = {master_port}",
        "  encrypted = 1",
        f"  publickey = {DEFAULT_ENCRYPTION_KEYS_LOCATION}/{master_name}.pub",
        "  accept_runcmd = 1",
        "  sync {",
        # One file per line, each indented like the first one
        sync_indent + ("\n" + sync_indent).join(FILES_TO_SYNC),
        "  }",
        "  object_config {",
        f"    fetch_name = {poller_name}",
        f"    fetch = mon oconf fetch --sync {master_name}",
        "  }",
        "}",
    ]
    config = "\n".join(entry_lines) + "\n"
    logger.debug(f"Config entry:\n {config}")

    with open(MERLIN_CONFIG, "a") as config_file:
        config_file.write(config)


def get_remote_cluster_info(client: SSHClient):
    """Fetches and parses the cluster info of a remote node

    Runs "mon node show" over the given connection.

    Args:
        client: SSH-connection to the node
    Raises:
        SSHCommandException: if the remote command fails
    Returns:
        The parsed output, see parse_mon_node_show
    """
    logger.debug("Parsing remote cluster info")
    try:
        node_show_lines = execute_ssh_command(client, "mon node show")
    except SSHCommandException:
        logger.error("Failed to fetch remote cluster info")
        raise
    return parse_mon_node_show(node_show_lines)


def get_local_cluster_info() -> dict:
    """Fetches and parses the cluster info of this node

    Runs "mon node show" locally.

    Returns:
        The parsed output, see parse_mon_node_show. An empty dict if the
        command fails or reports that no nodes are configured.
    """
    logger.debug("Parsing local cluster info")
    try:
        raw_output = execute_shell_cmd(["mon", "node", "show"])
    except ShellExecutionFailure:
        # Might not be an issue, will be caught later if it is
        logger.debug("Could not fetch local cluster info")
        return {}

    if raw_output == "No nodes configured\n":
        return {}
    return parse_mon_node_show(raw_output.split("\n"))


def parse_mon_node_show(cluster_info):
    """Parses the output of mon node show

    A line starting with "#" opens a new node section named by the rest of
    the line. Every following "<VAR>=<val>" line is stored for that node with
    the variable name lower-cased. Only the first "=" separates name and
    value, so values may themselves contain "=".

    Blank lines, lines without "=", and lines that appear before the first
    node header are ignored.

    Args:
        cluster_info: An iterable with the lines returned by mon node show
    Returns:
        parsed_info: A dict of the form {node: {var: val, ...}, ...}
    """
    parsed_info = {}
    current_node = None
    for raw_line in cluster_info:
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("#"):
            # New node, looks like "# <name>"
            current_node = line.lstrip("#").strip()
            parsed_info[current_node] = {}
            continue
        var, separator, val = line.partition("=")
        if not separator or current_node is None:
            # Not node data; skip instead of failing the whole parse
            continue
        parsed_info[current_node][var.strip().lower()] = val.strip()
    return parsed_info


def add_poller(name: str, address: str):
    """Registers another poller as a peer in the local merlin config

    Args:
        name: Name of the new poller
        address: Address to the new poller
    Raises:
        RegistrationFailure: if "mon node add" fails
    """
    logger.info(f"Adding {name} as a poller-peer")
    add_peer_cmd = ["mon", "node", "add", name, "type=peer", f"address={address}"]
    try:
        execute_shell_cmd(add_peer_cmd)
    except ShellExecutionFailure as ex:
        logger.error(f"Failed to add poller-peer {name} to local config")
        raise RegistrationFailure from ex


def fetch_config():
    """Fetches the naemon config from the designated master

    Raises:
        RegistrationFailure: if "mon oconf fetch" fails
    """
    logger.info(f"Fetching config from {MASTER_NAME}")
    fetch_cmd = ["mon", "oconf", "fetch", "--sync", MASTER_NAME]
    try:
        execute_shell_cmd(fetch_cmd)
    except ShellExecutionFailure as ex:
        logger.error(f"Could not fetch config from {MASTER_NAME}")
        raise RegistrationFailure from ex


def remove_node(name: str):
    """Removes a node from the local configuration

    Removes the node with "mon node remove" and deletes its public
    encryption key from DEFAULT_ENCRYPTION_KEYS_LOCATION. Both steps are
    best effort: failures are logged and do not raise.

    Args:
        name: Name of the node to be removed
    """
    logger.info(f"Removing node: {name}")
    try:
        execute_shell_cmd(["mon", "node", "remove", name])
    except ShellExecutionFailure:
        logger.error(f"Could not remove node {name} from local config")

    # Keys of other nodes are stored as "<name>.pub" in the keys directory
    public_key_path = DEFAULT_ENCRYPTION_KEYS_LOCATION / f"{name}.pub"
    try:
        execute_shell_cmd(["rm", "-f", str(public_key_path)])
    except ShellExecutionFailure:
        logger.error(f"Could not remove public encryption key belonging to {name}")


def hostgroups_exist(client: SSHClient, hostgroups: List[str]) -> bool:
    """Checks via a livestatus query on a node that hostgroups exist

    Args:
        client: SSH-connection to the node
        hostgroups: Hostgroup names to check
    Returns:
        True if every hostgroup is among the returned names. False if one is
        missing or the query fails.
    """
    logger.info(
        f"Verifying that hostgroups {', '.join(hostgroups)} exists on the master"
    )
    name_pattern = "|".join(hostgroups)
    try:
        found_names = execute_ssh_command(
            client, f"mon query ls hostgroups -c name name -r '{name_pattern}'"
        )
    except SSHCommandException as ex:
        logger.debug(ex, exc_info=True)
        return False

    missing = [group for group in hostgroups if group not in found_names]
    if missing:
        # Only the first missing hostgroup is reported
        logger.error(f"Hostgroup {missing[0]} does not exist")
        return False
    return True


def restart():
    """Runs "mon restart" locally

    A failure is logged and otherwise ignored.
    """
    logger.info("Executing mon restart")
    restart_cmd = ["mon", "restart"]
    try:
        execute_shell_cmd(restart_cmd)
    except ShellExecutionFailure as ex:
        logger.error("Could not execute mon restart")
        logger.debug(ex)


def update():
    """Updates the cluster configuration

    Compares the node list of the designated master with the local one.
    Nodes only known to the master are added: peers of the master are
    added as masters and registered with, pollers sharing this poller's
    hostgroups are added as peers. Nodes only known locally are removed.
    Ends with a local "mon restart" if anything changed.

    Raises:
        UpdateFailure: On failure to contact the master, to register with
            a new master, or to add a new poller.
    """
    logger.info("Updating local cluster config")

    local_cluster_info = get_local_cluster_info()
    # Add our selves to the configuration so we aren't flagged as new
    local_cluster_info[POLLER_NAME] = {}
    # Get remote configuration
    try:
        with setup_ssh_connection(MASTER_ADDRESS) as master:
            remote_cluster_info = get_remote_cluster_info(master)
            # Add the master to its own config so it isn't flagged for removal
            remote_cluster_info[MASTER_NAME] = {}
    except SSHConnectionException as ex:
        logger.error(
            f"Failed to establish an ssh connection to {MASTER_NAME} during update"
        )
        raise UpdateFailure() from ex
    except SSHCommandException as ex:
        logger.error(f"Failed to fetch cluster info from {MASTER_NAME}")
        raise UpdateFailure() from ex

    nodes_to_add = remote_cluster_info.keys() - local_cluster_info.keys()
    nodes_to_remove = local_cluster_info.keys() - remote_cluster_info.keys()
    # Our own entry is only a placeholder and not part of the local config,
    # so it must never be scheduled for removal.
    nodes_to_remove.discard(POLLER_NAME)

    if not (nodes_to_add or nodes_to_remove):
        # Nothing to update
        logger.debug("Update called but there was nothing to do")
        return

    for node in nodes_to_add:
        node_info = remote_cluster_info[node]
        if not node_info:
            # Placeholder entry (the designated master) without node data
            logger.debug(f"No remote node info available for {node}, skipping")
            continue
        node_type = node_info.get("type")
        if node_type == "peer":
            logger.info(f"New master: {node_info['name']}")
            add_master(
                node_info["name"],
                node_info["address"],
                POLLER_NAME,
                node_info["port"],
            )
            try:
                with setup_ssh_connection(node_info["address"]) as client:
                    public_key = exchange_encryption_keys(client, node_info["name"])
                    hostgroups = ",".join(POLLER_HOSTGROUPS)
                    cmd = (
                        f"mon slim-poller register "
                        f"--poller-name={POLLER_NAME} "
                        f"--poller-address={POLLER_ADDRESS} "
                        f"--public-key={public_key} "
                        f"--uuid={UUID} "
                        f"--hostgroups={hostgroups}"
                    )
                    execute_ssh_command(client, cmd)
            except RegistrationFailure as ex:
                logger.error("Failed to register with new master")
                raise UpdateFailure from ex
            except (SSHConnectionException, SSHCommandException) as ex:
                logger.error(f"SSH connection to {node_info['name']} failed")
                raise UpdateFailure from ex

        elif (
            node_type == "poller"
            and parse_hostgroups(node_info["hostgroup"]) == POLLER_HOSTGROUPS
        ):
            logger.info(f"New poller: {node_info['name']}")
            try:
                add_poller(name=node_info["name"], address=node_info["address"])
            except RegistrationFailure as ex:
                logger.error("Failed to update config with new poller")
                raise UpdateFailure from ex

    for node in nodes_to_remove:
        # Fall back to the section name if the node data carries no name
        remove_node(local_cluster_info[node].get("name", node))

    restart()


def exchange_encryption_keys(client: SSHClient, name: str) -> str:
    """Swaps public encryption keys with another node

    The remote key directory is the parent of the ipc_privatekey path found
    in the remote merlin config. If no path is configured, key generation is
    triggered on the remote node and the default directory is used. The
    remote "key.pub" is downloaded as "<name>.pub" and the local "key.pub"
    is uploaded as "<POLLER_NAME>.pub".

    Args:
        client: SSHClient connection to the remote node
        name: Name of the remote node
    Returns:
        Path to this node's public key on the remote node
    Raises:
        RegistrationFailure: On failure to exchange keys
    """
    logger.info(f"Exchanging keys with {name}")
    local_keys_dir = DEFAULT_ENCRYPTION_KEYS_LOCATION
    # Directory holding the encryption keys on the remote node
    remote_keys_dir = DEFAULT_ENCRYPTION_KEYS_LOCATION

    # Check if path to private key has been set
    try:
        configured_paths = execute_ssh_command(
            client, f"awk '/ipc_privatekey/{{print $NF}}' {MERLIN_CONFIG}"
        )
    except SSHCommandException as ex:
        logger.error("Failed to awk the merlin config for location of private key")
        raise RegistrationFailure from ex

    if configured_paths:
        logger.debug(f"Found path to private key in merlin config on {name}")
        # The last reported path decides the directory
        remote_keys_dir = Path(configured_paths.pop()).parent
    else:
        logger.info(f"No path to encryption key found in config file for {name}. Attempting to generate...")
        try:
            execute_ssh_command(client, "mon merlinkey generate --write")
        except SSHCommandException as ex:
            logger.error(f"Couldn't generate encryption keys on {name}")
            raise RegistrationFailure from ex

    our_key_on_remote = remote_keys_dir / f"{POLLER_NAME}.pub"
    try:
        with SFTPClient.from_transport(client.get_transport()) as sftp:
            # Download the remote public key, then upload ours
            sftp.get(
                str(remote_keys_dir / "key.pub"),
                str(local_keys_dir / f"{name}.pub"),
            )
            sftp.put(str(local_keys_dir / "key.pub"), str(our_key_on_remote))
    except FileNotFoundError as ex:
        logger.error(f"Encryption keys not found on remote master {name}")
        raise RegistrationFailure from ex
    return str(our_key_on_remote)


def generate_encryption_keys():
    """Makes sure a local encryption key pair exists

    Nothing is done when both "key.priv" and "key.pub" already exist in
    DEFAULT_ENCRYPTION_KEYS_LOCATION. Otherwise the keys are generated and
    the location of the private key is appended to the merlin config.

    TODO: When the changes to mon merlinkey generate have
    been merged this whole function can be reduced to
    a few lines.

    Raises:
        RegistrationFailure: if only one key of the pair exists, or if key
            generation fails
    """
    logger.debug("Generating local encryption keys")
    keys_dir = DEFAULT_ENCRYPTION_KEYS_LOCATION
    present = [(keys_dir / key_file).exists() for key_file in ("key.priv", "key.pub")]
    if all(present):
        # Complete pair already in place
        return
    if any(present):
        logger.error(
            "Only one of keys in encryption key pair exists, manual intervention needed"
        )
        raise RegistrationFailure

    try:
        execute_shell_cmd(["mon", "merlinkey", "generate"])
    except ShellExecutionFailure as ex:
        logger.error("Could not generate local encryption keys")
        raise RegistrationFailure from ex

    # Write location of private key to merlin config
    with open(MERLIN_CONFIG, "a") as config_file:
        config_file.write(f"ipc_privatekey = {keys_dir}/key.priv\n")


def clean_up_local_state():
    """Removes every node from the local configuration

    Does nothing when no nodes are configured locally.

    Raises:
        DeRegistrationFailure: if "mon node remove" fails
    """
    logger.debug("Cleaning up local state")
    configured_nodes = get_local_cluster_info()
    if configured_nodes:
        try:
            execute_shell_cmd(["mon", "node", "remove", *configured_nodes])
        except ShellExecutionFailure:
            raise DeRegistrationFailure


def deregister():
    """De-registers this poller from the designated master

    Runs the deregister command on the master over SSH. Any SSH connection
    or command failure is retried every 5 seconds until LOCK_TIMEOUT seconds
    of waiting have accumulated.

    Raises:
        LockAcquisitionFailure: if de-registration did not succeed in time
    """
    retry_interval = 5
    waited = 0
    last_error = None
    command = f"mon slim-poller deregister --poller-name={POLLER_NAME}"

    while waited < LOCK_TIMEOUT:
        try:
            with setup_ssh_connection(MASTER_ADDRESS) as master:
                execute_ssh_command(master, command)
        except (SSHConnectionException, SSHCommandException) as ex:
            logger.info("Could not aquire lock for deregistration, sleeping")
            last_error = ex
            waited += retry_interval
            time.sleep(retry_interval)
        else:
            return

    logger.critical(
        f"Lock could not be aquired within {LOCK_TIMEOUT} seconds for deregistration. "
        "The slim-poller will now do a hard exit potentially leaving nodes in the cluster with bad configs."
    )
    raise LockAcquisitionFailure from last_error


def register():
    """Registers this poller with the designated master and its cluster

    Verifies that the poller's hostgroups exist on the master, exchanges
    encryption keys and runs the register command on the master, retrying
    every 5 seconds while the command fails. The master is then added to
    the local config, followed by every node reported by the register
    command: peers of the master are added as masters (including a key
    exchange), pollers sharing this poller's hostgroups are added as peers.

    Raises:
        RegistrationFailure: if the hostgroups are missing, the master
            cannot be reached, or a node cannot be added
        LockAcquisitionFailure: if the register command did not succeed
            within LOCK_TIMEOUT seconds of waiting
    """
    time_spent_waiting = 0
    wait_time = 5
    try:
        with setup_ssh_connection(MASTER_ADDRESS) as master:
            if not hostgroups_exist(master, POLLER_HOSTGROUPS):
                logger.error(f"Hostgroups not found on master: {POLLER_HOSTGROUPS}")
                raise RegistrationFailure

            # Register poller
            public_key = exchange_encryption_keys(master, MASTER_NAME)
            hostgroups = ",".join(POLLER_HOSTGROUPS)
            cmd = (
                f"mon slim-poller register "
                f"--poller-name={POLLER_NAME} "
                f"--poller-address={POLLER_ADDRESS} "
                f"--public-key={public_key} "
                f"--uuid={UUID} "
                f"--hostgroups={hostgroups} "
                "--cluster"
            )
            while time_spent_waiting < LOCK_TIMEOUT:
                try:
                    output = execute_ssh_command(master, cmd)
                    break
                except SSHCommandException:
                    logger.info("Could not aquire lock for registration, sleeping")
                    time_spent_waiting += wait_time
                    time.sleep(wait_time)
            else:
                # Loop ended without a successful registration
                logger.error(f"Lock could not be aquired within {LOCK_TIMEOUT} seconds")
                raise LockAcquisitionFailure
    except (SSHConnectionException, SSHCommandException) as ex:
        logger.error("Could not register with the master")
        raise RegistrationFailure from ex

    # Add the designated master
    add_master(MASTER_NAME, MASTER_ADDRESS, POLLER_NAME)

    # Registration cmd ends with the output of mon node show.
    # We'll use that to get information about all master peers
    # and add them to our config
    cluster_info = parse_mon_node_show(output)
    for node in cluster_info.values():
        if node["type"] == "peer":
            logger.info(f"New master: {node['name']}")
            add_master(
                node["name"],
                node["address"],
                POLLER_NAME,
                node["port"],
            )

            try:
                with setup_ssh_connection(node["address"]) as client:
                    exchange_encryption_keys(client, node["name"])
            except (SSHConnectionException, RegistrationFailure) as ex:
                logger.error("Failed to exchange encryption keys with new master")
                raise RegistrationFailure from ex

        elif (
            node["type"] == "poller"
            and node["name"] != POLLER_NAME
            and parse_hostgroups(node["hostgroup"]) == POLLER_HOSTGROUPS
        ):
            logger.info(f"New poller: {node['name']}")
            try:
                add_poller(name=node["name"], address=node["address"])
            except RegistrationFailure as ex:
                logger.error("Failed to add new poller to local config")
                raise ex


def main():
    """Command line entry point

    Parses the arguments described in the module docstring and runs the
    requested action: --deregister, --update or --register. Failures are
    logged and lead to exit status 1. With no action given, only a message
    is logged.
    """
    arguments = docopt(__doc__, version="1.0")
    if arguments.get("-v") or arguments.get("--verbose"):
        logger.setLevel(logging.DEBUG)
    if arguments.get("--deregister"):
        try:
            deregister()
            clean_up_local_state()
        except (DeRegistrationFailure, LockAcquisitionFailure) as ex:
            logger.error("De-registration failed, manual intervention needed")
            logger.debug(ex, exc_info=True)
            sys.exit(1)
        except Exception as ex:
            logger.error("Unknown error encountered during de-registration")
            logger.exception(ex)
            sys.exit(1)

    elif arguments.get("--update"):
        try:
            update()
        except UpdateFailure as ex:
            logger.error("Update failed")
            logger.debug(ex, exc_info=True)
            sys.exit(1)
        except Exception as ex:
            logger.error("Unknown error encountered during update")
            logger.exception(ex)
            sys.exit(1)

    elif arguments.get("--register"):
        if MASTER_NAME in get_local_cluster_info():
            logger.info(
                "It appears this node is already configured, no registration needed."
            )
            sys.exit(0)
        try:
            generate_encryption_keys()

            register()

            fetch_config()

            add_cluster_update()
        except (SSHConnectionException, SSHCommandException) as ex:
            logger.info("Could not aquire lock, exiting")
            logger.debug(ex, exc_info=True)
            sys.exit(1)
        except RegistrationFailure as ex:
            # Bail and clean up as much as possible
            logger.error(
                "Error occurred during registration, bailing out and cleaning up"
            )
            logger.debug(ex, exc_info=True)
            try:
                deregister()
            except LockAcquisitionFailure as cleanup_ex:
                # The clean up itself failed; report it instead of letting
                # the exception escape from inside this handler.
                logger.error(
                    "Clean up after failed registration did not succeed, "
                    "manual intervention needed"
                )
                logger.debug(cleanup_ex, exc_info=True)
            sys.exit(1)
        except Exception as ex:
            logger.error("Unknown error occured during registration")
            logger.exception(ex)
            sys.exit(1)

        # Sanity check that the poller is now added
        if MASTER_NAME not in get_local_cluster_info():
            logger.error(
                "Something went wrong during setup. "
                f"{MASTER_NAME} not in local configuration"
            )
        else:
            logger.info(f"Poller {POLLER_NAME} successfully registered with cluster")
    else:
        logger.info("Nothing to do")


# Run the command line interface only when executed as a script.
if __name__ == "__main__":
    main()
