#!/usr/bin/env python3

# Copyright (c) 2018, AT&T Intellectual Property.
# All rights reserved.
#
# SPDX-License-Identifier: GPL-2.0-only

import os
import sys
import yaml
import subprocess

class VyattaDiagnostics(object):
    """Registry and runner for hardware diagnostics described in YAML files.

    Every regular file in ``HARDWARE_DIAGNOSTICS`` is parsed as YAML.  A
    file may carry a top-level ``probe`` shell command; its ``diagnostics``
    mapping is registered only when that command exits with status 0, or
    when the file has no probe at all.  Each registered diagnostic is a
    mapping that may contain ``prerun``, ``run`` and ``postrun`` shell
    commands and a ``completions`` mapping.
    """

    # Keys looked up in the parsed YAML documents.
    COMPLETIONS = 'completions'
    DIAGNOSTICS = 'diagnostics'
    PRERUN = 'prerun'
    POSTRUN = 'postrun'
    PROBE = 'probe'
    RUN = 'run'
    # Directory scanned for diagnostics definition files.
    HARDWARE_DIAGNOSTICS = '/opt/vyatta/hardware-diagnostics'

    def __init__(self):
        """Load every diagnostics file and register those whose probe passes.

        Errors from listing the directory, reading a file or parsing its
        YAML are not caught here and propagate to the caller.
        """
        # Maps diagnostic name -> diagnostic definition mapping.
        self.hwDiags = {}
        # os.listdir() returns entries in arbitrary order; sorting makes it
        # deterministic which file wins when two define the same name.
        for filename in sorted(os.listdir(self.HARDWARE_DIAGNOSTICS)):
            filepath = os.path.join(self.HARDWARE_DIAGNOSTICS, filename)
            if not os.path.isfile(filepath):
                # Sub-directories and the like cannot be opened as files.
                continue
            with open(filepath, 'r') as stream:
                # yaml.load() without an explicit Loader is deprecated since
                # PyYAML 5.1 and a TypeError from 6.0.  safe_load() parses
                # plain YAML and cannot construct arbitrary Python objects.
                hwdiags = yaml.safe_load(stream)
            if not isinstance(hwdiags, dict):
                # Empty file or non-mapping document: nothing to register.
                continue
            if self.probeDiags(hwdiags):
                # "or {}" also covers a 'diagnostics:' key with no value.
                self.hwDiags.update(hwdiags.get(self.DIAGNOSTICS) or {})

    def runDiag(self, path):
        """Run the diagnostic named by ``path[0]``.

        The optional ``prerun``, ``run`` and ``postrun`` commands of the
        diagnostic are executed in that order through the shell; their exit
        statuses are ignored.  In the ``run`` command every ``%args%``
        placeholder is replaced by the remaining elements of ``path`` joined
        with single spaces.

        Prints "Unknown hardware diagnostic" and returns when ``path`` is
        empty or names no registered diagnostic.
        """
        diag = self.getDiag(path[0]) if path else None
        if not diag:
            print("Unknown hardware diagnostic")
            return
        if self.PRERUN in diag:
            subprocess.call(diag[self.PRERUN], shell=True)
        if self.RUN in diag:
            # NOTE(review): the arguments are substituted unquoted into a
            # shell command line, so shell metacharacters in them are
            # interpreted by the shell.  Confirm callers only pass trusted
            # words, or quote them (shlex.quote) if that is not guaranteed.
            cmdline = diag[self.RUN].replace("%args%", " ".join(path[1:]))
            subprocess.call(cmdline, shell=True)
        if self.POSTRUN in diag:
            subprocess.call(diag[self.POSTRUN], shell=True)

    def listDiags(self, path):
        """Print completion candidates for the words typed so far.

        With an empty ``path`` the names of all registered diagnostics are
        printed, space separated.  Otherwise ``path[0]`` selects a
        diagnostic and the words of ``path`` joined with single spaces are
        looked up in its ``completions`` mapping; the matching entry is
        printed, and nothing is printed when there is no match or the
        diagnostic has no completions.  An unknown diagnostic name prints
        "Unknown hardware diagnostic".
        """
        if not path:
            print(" ".join(self.hwDiags.keys()))
            return
        diag = self.getDiag(path[0])
        if not diag:
            print("Unknown hardware diagnostic")
            return
        # A diagnostic is not required to define completions.
        completions = diag.get(self.COMPLETIONS) or {}
        completion_path = " ".join(path)
        if completion_path in completions:
            print(completions[completion_path])

    def getDiag(self, diag):
        """Return the definition registered under ``diag``, or None."""
        return self.hwDiags.get(diag)

    def probeDiags(self, hwdiags):
        """Return True when the diagnostics in ``hwdiags`` should be used.

        ``hwdiags`` is the parsed content of one diagnostics file.  Without
        a ``probe`` entry the result is always True; otherwise the probe
        command is run through the shell and the result is True only when
        it exits with status 0.
        """
        if self.PROBE not in hwdiags:
            return True
        return subprocess.call(hwdiags[self.PROBE], shell=True) == 0

def usage():
    """Print the command synopsis to stdout and exit with status 1."""
    synopsis = "usage: vyatta-diag-op [--listDiags [current path]|--runDiag diagname]"
    print(synopsis)
    # Equivalent to sys.exit(1): terminates the script with exit status 1.
    raise SystemExit(1)

# Command-line dispatch.  The diagnostics registry is built first, so the
# diagnostics directory is scanned before the arguments are examined.
diags = VyattaDiagnostics()

argc = len(sys.argv)
if argc == 1:
    # No option given at all.
    usage()
else:
    option = sys.argv[1]
    words = sys.argv[2:]
    if option == "--listDiags":
        # The current path may be empty: that lists all diagnostic names.
        diags.listDiags(words)
    elif option == "--runDiag" and words:
        # Requires at least the diagnostic name after the option.
        diags.runDiag(words)
    else:
        # Unrecognised invocation: echo it back, then show the synopsis.
        print(" ".join(sys.argv))
        usage()
