#!/usr/bin/python3

__author__ = 'Altertech https://www.altertech.com/'
__license__ = 'MIT'
__version__ = '0.0.12'

from pathlib import Path
from pySMART import Device
from collections import OrderedDict
import rapidtables
import neotermcolor
import os
import sys
import argparse
import shutil

# Shorthand for neotermcolor's coloring function, used for the colored
# terminal output produced by this script.
colored = neotermcolor.colored

# Built-in temperature thresholds in degrees Celsius. They are applied only
# when --temp-warn / --temp-crit are not given on the command line.
temp_warn = 40
temp_crit = 45


def c2f(t, nc):
    """Convert a Celsius value to Fahrenheit when conversion is requested.

    Args:
        t: temperature in degrees Celsius, or None when there is no reading.
        nc: truthy to convert to Fahrenheit, falsy to leave the value as is.

    Returns:
        The Fahrenheit value rounded to the nearest integer when ``nc`` is
        truthy and ``t`` is not None; otherwise ``t`` unchanged.
    """
    # Nothing to do when conversion is off or there is no reading at all.
    if not nc or t is None:
        return t
    return round(t * 1.8 + 32)


# Command-line interface. Every option is declared once in the table below as
# (flags, keyword arguments) and registered in table order, which is also the
# order in which the options are listed by --help.
ap = argparse.ArgumentParser(description='Show HDD/SSD list')

_cli_options = (
    (('-V', '--version'),
     {'action': 'store_true', 'help': 'Print version and exit'}),
    (('--temp-warn',),
     {'type': int, 'metavar': 'TEMP',
      'help': 'Warning temperature, default: 40 C'}),
    (('--temp-crit',),
     {'type': int, 'metavar': 'TEMP',
      'help': 'Critical temperature, default: 45 C'}),
    (('--fh', '--fahrenheit'),
     {'dest': 'fh', 'action': 'store_true',
      'help': 'Temperature in Fahrenheit'}),
    (('-R', '--raw'),
     {'action': 'store_true', 'help': 'Suppress colors'}),
    (('-y', '--full'),
     {'action': 'store_true', 'help': 'Display full disk info'}),
    (('-e', '--errors'),
     {'action': 'store_true',
      'help': 'Display only disks with errors / critical temperature'}),
    (('-s', '--no-header'),
     {'action': 'store_true', 'help': 'Suppress header'}),
    (('-J', '--json'),
     {'action': 'store_true', 'help': 'Output as JSON'}),
)

for _flags, _kwargs in _cli_options:
    ap.add_argument(*_flags, **_kwargs)

# Parsed options; the rest of the script reads them through this namespace.
a = ap.parse_args()

# --version is answered first: it needs no thresholds, no privileges and no
# external tools.
if a.version:
    print(__version__)
    sys.exit()

# Thresholds that were not given on the command line fall back to the
# built-in Celsius defaults, converted to Fahrenheit when --fh is in effect.
# Explicitly given values are used unchanged.
if a.temp_warn is None:
    a.temp_warn = c2f(temp_warn, a.fh)

if a.temp_crit is None:
    a.temp_crit = c2f(temp_crit, a.fh)

# A non-zero uid means we are not root. Diagnostics go to stderr so that
# stdout only ever carries the disk table / JSON document.
if os.getuid():
    print('please run me as root', file=sys.stderr)
    sys.exit(5)

if not shutil.which('smartctl'):
    print('smartctl not found, please install smartmontools',
          file=sys.stderr)
    sys.exit(6)

# --raw: clear neotermcolor's tty flag. The scan loop and the final cleanup
# in this script test the same flag before printing progress output and
# terminal escape sequences; presumably colored() also stops emitting color
# codes when it is False (neotermcolor internal - verify if upgrading).
if a.raw:
    neotermcolor._isatty = False

# Per-disk results are kept in three parallel lists (same index = same disk):
#   data   - ordered mapping of column name -> value (table row / JSON object)
#   temps  - temperature reading after the optional Fahrenheit conversion,
#            or None when the device reports none
#   errors - True when one of the monitored SMART error counters is non-zero
data = []
temps = []
errors = []

# Exit status used when no disk has SMART errors.
exit_code = 0

dev_dir = Path('/dev')

devs = []

# A glob character class matches exactly ONE character ('[0-999]' is the same
# as '[0-9]'), so multi-digit NVMe controllers and two-letter sd names need
# patterns of their own. Partitions and namespaces (sda1, nvme0n1) are
# deliberately not matched.
for m in ('nvme[0-9]', 'nvme[0-9][0-9]', 'nvme[0-9][0-9][0-9]', 'sd[a-z]',
          'sd[a-z][a-z]', 'hd[a-z]'):
    devs.extend(dev_dir.glob(m))

# SMART attributes whose non-zero raw value flags the disk as having errors.
error_attrs = (
    'End-to-End_Error',
    'Seek_Error_Rate',
    'Raw_Read_Error_Rate',
    'UDMA_CRC_Error_Count',
)

for devpath in devs:
    # On a terminal, show which device is being queried; the line is erased
    # again at the end of the iteration.
    if neotermcolor._isatty:
        print(': {} '.format(colored(devpath, color='cyan')), end='')
    # don't use is_block_device - it's False for some NVMe SSDs
    if devpath.exists():
        dev = Device(devpath.as_posix())
        d = OrderedDict()
        d['Disk'] = dev.name
        d['Model'] = dev.model
        d['Serial'] = dev.serial
        # Keep the converted reading in a local variable instead of writing
        # it back onto the pySMART object.
        temp = c2f(dev.temperature, a.fh)
        # A reading of None or 0 is shown as an empty cell.
        if temp:
            d['Temp'] = temp if a.json else '{} {}'.format(
                temp, 'F' if a.fh else 'C')
        else:
            d['Temp'] = ''
        lcc = ''
        poh = ''
        err = False
        for attr in dev.attributes:
            try:
                if attr:
                    if attr.name in error_attrs and int(attr.raw) > 0:
                        err = True
                    elif attr.name == 'Load_Cycle_Count':
                        lcc = int(attr.raw)
                    elif attr.name == 'Power_On_Hours':
                        poh = int(attr.raw)
            except (AttributeError, TypeError, ValueError):
                # Best effort: an attribute without a usable name or with a
                # raw value that is not a plain integer is skipped. Other
                # exceptions (e.g. KeyboardInterrupt) propagate.
                pass
        if a.full:
            d['PoH'] = poh
            d['LCC'] = lcc
            d['Int'] = dev.interface
            d['SSD'] = 'ssd' if dev.is_ssd else ''
            d['Capacity'] = dev.capacity
            d['RRate'] = dev.rotation_rate
            d['Firmware'] = dev.firmware
        # With --errors only disks with SMART errors or a critical
        # temperature are reported; otherwise every disk is.
        if not a.errors or err or (temp and temp >= a.temp_crit):
            temps.append(temp)
            data.append(d)
            errors.append(err)
        if neotermcolor._isatty:
            print('[{}]'.format(colored(dev.serial, color='cyan',
                                        attrs='bold')),
                  end='')
    if neotermcolor._isatty:
        # Erase the progress line and return the cursor to column 0.
        sys.stdout.write("\033[K")
        print('\r', end='')
        sys.stdout.flush()

# A critical temperature on any reported disk yields exit status 1. This is
# decided before choosing the output format so that --json and the table
# report the same status. Non-numeric entries (no reading) are ignored.
if any(isinstance(t, (int, float)) and t >= a.temp_crit for t in temps):
    exit_code = 1

if a.json:
    import json
    print(json.dumps(data))
elif data:
    # Column order follows the insertion order of each row mapping:
    # Disk, Model, Serial, Temp and, with --full, PoH, LCC, Int, SSD,
    # Capacity, RRate, Firmware.
    header, rows = rapidtables.format_table(
        data,
        fmt=rapidtables.FORMAT_GENERATOR_COLS,
        align=(rapidtables.ALIGN_CENTER, rapidtables.ALIGN_LEFT,
               rapidtables.ALIGN_LEFT, rapidtables.ALIGN_RIGHT,
               rapidtables.ALIGN_RIGHT, rapidtables.ALIGN_RIGHT,
               rapidtables.ALIGN_CENTER, rapidtables.ALIGN_CENTER,
               rapidtables.ALIGN_RIGHT, rapidtables.ALIGN_RIGHT,
               rapidtables.ALIGN_LEFT))
    spacer = '  '
    if not a.no_header:
        print(colored(spacer.join(header), color='blue'))
        print(colored('-' * sum(len(x) + 2 for x in header), color='grey'))
    for r, temp, e in zip(rows, temps, errors):
        # Disk, model and serial turn red when the disk has SMART errors.
        print(colored(r[0], color='red' if e else 'cyan') + spacer, end='')
        print(colored(r[1], color='red' if e else 'white') + spacer, end='')
        print(colored(r[2], color='red' if e else 'cyan', attrs='bold') +
              spacer,
              end='')
        # Temperature color: red at or above the critical threshold, yellow
        # at or above the warning threshold, green below, uncolored when
        # there is no numeric reading.
        if not isinstance(temp, (int, float)):
            tcolor = None
        elif temp >= a.temp_crit:
            tcolor = 'red'
        elif temp >= a.temp_warn:
            tcolor = 'yellow'
        else:
            tcolor = 'green'
        print(colored(r[3], color=tcolor, attrs='bold') + spacer, end='')
        if a.full:
            print(colored(r[4] + spacer), end='')
            print(colored(r[5] + spacer, color='cyan'), end='')
            print(colored(r[6] + spacer), end='')
            print(colored(r[7] + spacer,
                          color='magenta' if r[7].strip() == 'ssd' else None),
                  end='')
            print(colored(r[8] + spacer, attrs='bold'), end='')
            print(colored(r[9] + spacer, color='cyan'), end='')
            print(colored(r[10] + spacer))
        else:
            print()
else:
    # Nothing to show: just erase whatever is left on the current line.
    if neotermcolor._isatty:
        sys.stdout.write("\033[K")

# Exit status: 2 when any reported disk has SMART errors, otherwise 1 for a
# critical temperature, otherwise 0.
sys.exit(2 if any(errors) else exit_code)
