#!/usr/bin/python3
#
# Copyright (c) 2018-2019 AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only

import os
import sys
import syslog
from subprocess import check_output
from collections import OrderedDict
from vyatta import configd

# System-wide OpenSSH client configuration, plus the two generated snippets
# that update_global_config() pulls into it via "Include" directives.
GlobalSSHConfig = '/etc/ssh/ssh_config'
CustomKEXConfig = '/etc/vyatta/ssh_kex_config'
CustomHKAConfig = '/etc/vyatta/ssh_hka_config'

# Key exchange algorithms left out of the generated KexAlgorithms list
# unless the configuration explicitly permits them (see update_kex()).
disallowedKEXList = [
    'diffie-hellman-group1-sha1',
    'diffie-hellman-group14-sha1',
    'diffie-hellman-group-exchange-sha1',
]

# NOTE(review): not referenced by any function visible in this file --
# confirm whether it is still needed.
disallowedHKAList = ['ssh-dss']

def update_global_config(config_path=None, include_paths=None):
    """Make sure the global SSH client config includes the custom snippets.

    Each path in include_paths must appear in config_path as an
    "Include <path>" line; directives that are missing are appended to
    the end of the file.  Directives that are already present are never
    written a second time, so the function is safe to call repeatedly.

    config_path   -- file to update; defaults to GlobalSSHConfig.
    include_paths -- iterable of paths to include; defaults to
                     CustomKEXConfig followed by CustomHKAConfig.

    Raises OSError if config_path cannot be read or appended to.
    """
    if config_path is None:
        config_path = GlobalSSHConfig
    if include_paths is None:
        include_paths = (CustomKEXConfig, CustomHKAConfig)

    # Read everything first so the file is closed again before it is
    # reopened for appending.
    with open(config_path) as global_config:
        content = global_config.read()

    # Compare stripped lines so a final line without a newline still counts.
    present = {line.strip() for line in content.splitlines()}
    missing = [
        directive
        for directive in ("Include {}".format(path) for path in include_paths)
        if directive not in present
    ]
    if not missing:
        return

    with open(config_path, "a") as global_config:
        # Do not glue the first directive onto an unterminated last line.
        if content and not content.endswith("\n"):
            global_config.write("\n")
        for directive in missing:
            global_config.write(directive + "\n")

def get_config(name):
    """Fetch "security ssh-client permit <name>" from the candidate config.

    name -- the node under "security ssh-client permit" to look up.

    Returns the value stored under that node, or None when configd
    raises configd.Exception for the lookup.  If the connection to
    configd cannot be established, a message is printed and the process
    exits with status 1.
    """
    try:
        client = configd.Client()
    except configd.FatalException as err:
        print("can't connect to configd: {}\n".format(err))
        sys.exit(1)

    path = "security ssh-client permit {}".format(name)
    try:
        tree = client.tree_get_dict(path, configd.Client.CANDIDATE, "json")
        return tree[name]
    except configd.Exception:
        return None

#
# Disable by default some legacy SSH Key Exchange Algorithms unless
# explicitly permitted by configuration.
#
# if "default" is True, exclude the default set of weak algorithms
# independent of configuration.
#
def update_kex(default=False):
    """Write the permitted SSH key exchange algorithms to CustomKEXConfig.

    The algorithms reported by "ssh -Q kex" are written as a
    KexAlgorithms list under "Host *", leaving out every entry of
    disallowedKEXList that the configuration does not explicitly permit.

    default -- when True the configuration is not consulted and every
               entry of disallowedKEXList is excluded.

    A warning is sent to syslog when the configuration enables any
    algorithms.  Raises subprocess.CalledProcessError if "ssh -Q kex"
    exits with a non-zero status.
    """
    enabled = None if default else get_config('key-exchange-algorithm')

    # Build the exclusion list locally; the module-level list is left
    # untouched so one call cannot influence the next.
    if enabled is None:
        blocked = list(disallowedKEXList)
    else:
        blocked = [alg for alg in disallowedKEXList if alg not in enabled]

    # Run ssh directly (no shell) and filter by exact algorithm name
    # rather than piping through an unanchored egrep pattern.
    supported = check_output(["ssh", "-Q", "kex"]).decode("utf-8")
    algs = [alg for alg in supported.splitlines()
            if alg and alg not in blocked]
    algstr = ",".join(algs)

    with open(CustomKEXConfig, "w") as custom_config:
        custom_config.write("Host *\n")
        custom_config.write("    KexAlgorithms {}\n".format(algstr))

    if enabled:
        algstr = ",".join(enabled)
        syslog.syslog(syslog.LOG_WARNING,
                      "SSH: Key Exchange Algorithms {} enabled.\n".format(algstr))

#
# Configure some legacy SSH Host Key Algorithms
#
def update_hka():
    """Create or remove the custom SSH host key algorithm snippet.

    When no 'host-key-algorithm' configuration is returned, any existing
    CustomHKAConfig file is removed (a failure to remove it is ignored).
    Otherwise the first configured algorithm is appended to the client's
    HostKeyAlgorithms for all hosts and a warning is sent to syslog.
    """
    configured = get_config('host-key-algorithm')

    if configured is None:
        # Nothing is permitted: drop a previously generated snippet, if any.
        try:
            os.remove(CustomHKAConfig)
        except OSError:
            pass
        return

    # Only the first configured entry is used.
    algorithm = configured[0]

    with open(CustomHKAConfig, "w") as custom_config:
        custom_config.writelines([
            "Host *\n",
            "    HostKeyAlgorithms +{}\n".format(algorithm),
        ])

    message = "SSH: Legacy Host Key Algorithms {} enabled.\n".format(algorithm)
    syslog.syslog(syslog.LOG_WARNING, message)

# Command-line dispatch on the first argument:
#   defaults      -- on first use (no KEX snippet yet) hook the snippets
#                    into the global config and write the default KEX list.
#   update-legacy -- regenerate both snippets from the configuration.
# A missing argument is treated like an unknown one instead of raising
# IndexError.
action = sys.argv[1] if len(sys.argv) > 1 else None

if action == 'defaults':
    if not os.path.isfile(CustomKEXConfig):
        update_global_config()
        update_kex(True)

elif action == 'update-legacy':
    update_hka()
    update_kex()

else:
    print("No handler\n")

