#!/usr/bin/python3

import argparse
import html
import json
import logging
import os
import re
import subprocess
import sys
import textwrap
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Iterator


class OutputFormat(Enum):
    """Rendering mode for the report, picked once from the CLI flags and stdout."""

    TTY = 'tty'    # ANSI-colored text for an interactive terminal
    HTML = 'html'  # <span>-styled text for embedding in a web page
    PIPE = 'pipe'  # plain, uncolored text
    JSON = 'json'  # machine-readable JSON document


@dataclass
class PeerInfo:
    """One WireGuard peer, merged from its wg-quick config and ``wg show``.

    ``name``/``pubkey``/``ip`` come from the config (or are synthesized for a
    peer only seen in ``wg show``); ``ip`` is the address used for pinging.
    ``online`` stays None until a ping result is recorded. ``endpoint``,
    ``latest_handshake``, ``allowed_ips`` and ``persistent_keepalive`` hold
    the text printed by ``wg show``; ``transfer_rx``/``transfer_tx`` are byte
    counts.
    """

    name: str
    pubkey: str
    ip: str
    online: Optional[bool] = None
    endpoint: str = ''
    latest_handshake: str = ''
    transfer_rx: int = 0
    transfer_tx: int = 0
    allowed_ips: str = ''
    persistent_keepalive: str = ''

    # Unannotated class constants, so @dataclass does not treat them as fields.
    _DURATION_RE = re.compile(r'(\d+)\s+([A-Za-z]+)')
    _UNIT_SECONDS = {
        'second': 1,
        'minute': 60,
        'hour': 3600,
        'day': 86400,
        # `wg show` can also report years; 365 days assumed (TODO confirm
        # against wg's own conversion).
        'year': 365 * 86400,
    }

    @property
    def handshake_seconds(self) -> Optional[int]:
        """Return how many seconds ago the latest handshake happened.

        Parses ``wg show`` durations such as ``"1 minute, 30 seconds ago"``
        or ``"1 year, 2 days ago"`` by summing every ``<number> <unit>``
        pair. ``"Now"`` (a handshake within the current second) yields 0.

        Returns:
            The age in seconds, or None when there has been no handshake
            (empty or ``"(none)"``) or no recognizable duration is present.
        """
        text = self.latest_handshake.strip()
        if not text or text == '(none)':
            return None
        if text.lower() == 'now':
            return 0

        total = 0
        recognized = False
        for amount, unit in self._DURATION_RE.findall(text):
            # "seconds" -> "second", "days" -> "day", ...
            per_unit = self._UNIT_SECONDS.get(unit.lower().rstrip('s'))
            if per_unit is None:
                continue
            total += int(amount) * per_unit
            recognized = True

        return total if recognized else None


@dataclass
class InterfaceInfo:
    """A WireGuard interface together with the peers known for it.

    ``peers`` maps each peer's public key to its PeerInfo; every instance
    receives its own fresh dict.
    """

    name: str
    public_key: str = ''      # interface public key, as printed by `wg show`
    listening_port: str = ''  # kept as the text printed by `wg show`
    peers: Dict[str, PeerInfo] = field(default_factory=dict)


class PreserveWhiteSpaceWrapRawTextHelpFormatter(argparse.RawTextHelpFormatter):
    """Help formatter that keeps explicit line breaks but still wraps long lines.

    Each line of help/description text is wrapped to the available width on
    its own. Continuation lines get a hanging indent matching the line's
    leading whitespace plus any numeric/dash list marker (e.g. ``"  1. "``),
    and that indent is counted against the width so no output line is wider
    than ``width``. Blank lines become a single space so paragraph breaks
    survive argparse's formatting.
    """

    # Leading whitespace followed by an optional "1." / "-" style list marker.
    # The pattern can match the empty string, so match() always succeeds.
    _HANGING_INDENT_RE = re.compile(r'\s*[\d\-]*\.?\s*')

    def _split_lines(self, text: str, width: int) -> List[str]:
        """Split *text* into display lines no wider than *width*."""
        lines: List[str] = []
        for line in text.splitlines():
            if not line.strip():
                lines.append(' ')
                continue
            indent = self._HANGING_INDENT_RE.match(line).end()
            # subsequent_indent (rather than padding after wrapping) keeps the
            # indented continuation lines within the requested width.
            lines.extend(textwrap.wrap(line, width, subsequent_indent=' ' * indent))
        return lines


class ColorFormatter:
    """Markup strings used to style output for one OutputFormat.

    Exposes ``red``, ``green``, ``yellow`` and ``blue`` (each also as a
    ``*_bold`` variant), ``bold``, and the closing string ``end``. HTML gets
    ``<span style=...>`` tags, TTY gets ANSI escape sequences, and every
    other format gets empty strings so text is printed unstyled.
    """

    def __init__(self, output_format: OutputFormat):
        self.format = output_format
        self._setup_colors()

    def _setup_colors(self):
        html_palette = {
            'red': '<span style="color: red;">',
            'red_bold': '<span style="color: red; font-weight: bold;">',
            'green': '<span style="color: green;">',
            'green_bold': '<span style="color: green; font-weight: bold;">',
            'yellow': '<span style="color: orange;">',
            'yellow_bold': '<span style="color: orange; font-weight: bold;">',
            'blue': '<span style="color: blue;">',
            'blue_bold': '<span style="color: blue; font-weight: bold;">',
            'bold': '<span style="font-weight: bold;">',
            'end': '</span>',
        }
        ansi_palette = {
            'red': '\033[0;31m',
            'red_bold': '\033[1;31m',
            'green': '\033[0;32m',
            'green_bold': '\033[1;32m',
            'yellow': '\033[0;33m',
            'yellow_bold': '\033[1;33m',
            'blue': '\033[0;34m',
            'blue_bold': '\033[1;34m',
            'bold': '\033[1m',
            'end': '\033[0m',
        }

        if self.format == OutputFormat.HTML:
            palette = html_palette
        elif self.format == OutputFormat.TTY:
            palette = ansi_palette
        else:
            # Same attribute names, all mapped to the empty string.
            palette = dict.fromkeys(ansi_palette, '')

        for attr_name, markup in palette.items():
            setattr(self, attr_name, markup)


class WireGuardInfo:
    """Collect WireGuard interface and peer state and print it.

    Peer names and ping addresses come from the wg-quick configs under
    ``/etc/wireguard``; live state (endpoint, handshake, transfer, ...) comes
    from ``wg show``. ``args`` is the parsed command line from ``main``; it
    selects the output format (TTY/HTML/plain text/JSON), sorting, filtering,
    pinging and watch mode.
    """

    # Handshake age (seconds) below which a peer counts as online.
    ONLINE_THRESHOLD = 300
    # Handshake age (seconds) below which a not-online peer counts as
    # "stale" rather than offline.
    STALE_THRESHOLD = 3600

    # Size units printed by `wg show` in its "transfer:" line.
    _BYTE_UNITS = {
        'B': 1,
        'KiB': 1024,
        'MiB': 1024 ** 2,
        'GiB': 1024 ** 3,
        'TiB': 1024 ** 4,
    }

    # A "# Name = ..." comment (or bare "Name = ...") giving a peer's name.
    _NAME_RE = re.compile(r'#?\s*Name\s*=(.*)', re.IGNORECASE)

    # AllowedIPs values that do not identify the peer itself: catch-all
    # routes (0.0.0.0/0, ::/0) and the "(none)" placeholder of `wg show`.
    _NON_HOST_ADDRS = ('0.0.0.0', '::', '(none)')

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self.interfaces: Dict[str, InterfaceInfo] = {}
        self.config_path = Path('/etc/wireguard')

        # Setup logging
        log_level = logging.DEBUG if args.verbose else logging.WARNING
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        # Explicit --json/--html win; otherwise color only for a terminal
        # (or when --tty forces it).
        if args.json:
            self.output_format = OutputFormat.JSON
        elif args.html:
            self.output_format = OutputFormat.HTML
        elif sys.stdout.isatty() or args.tty:
            self.output_format = OutputFormat.TTY
        else:
            self.output_format = OutputFormat.PIPE

        self.colors = ColorFormatter(self.output_format)

    # ------------------------------------------------------------------
    # Environment checks
    # ------------------------------------------------------------------

    def human_readable_bytes(self, bytes_val: int) -> str:
        """Format a byte count with binary units, e.g. 1536 -> '1.50 KiB'."""
        for unit in ['B', 'KiB', 'MiB', 'GiB', 'TiB']:
            if bytes_val < 1024.0:
                return f"{bytes_val:.2f} {unit}"
            bytes_val /= 1024.0
        return f"{bytes_val:.2f} PiB"

    def check_root(self):
        """Exit with status 1 (message on stderr) unless running as root."""
        if os.getuid() != 0:
            print(f"{self.colors.red_bold}\nERROR: {self.colors.yellow}"
                  f"This script must run as root\n{self.colors.end}",
                  file=sys.stderr)
            sys.exit(1)

    def check_wireguard_installed(self) -> bool:
        """Return True if ``wg --version`` runs; otherwise report on stderr."""
        try:
            subprocess.run(['wg', '--version'],
                           capture_output=True,
                           check=True)
            return True
        except (subprocess.CalledProcessError, FileNotFoundError):
            print(f"{self.colors.red_bold}ERROR: {self.colors.yellow}"
                  f"WireGuard (wg) command not found. Please install wireguard-tools{self.colors.end}",
                  file=sys.stderr)
            return False

    def get_up_interfaces(self, interface_names: List[str]) -> List[str]:
        """Return those of *interface_names* that appear in ``ip link`` output.

        Presence in ``ip link`` is treated as "up". The result follows the
        order of ``ip link``. If ``ip`` fails or is not installed, the error
        is logged and an empty list is returned.
        """
        wanted = set(interface_names)
        try:
            output = subprocess.check_output(['ip', 'link'], stderr=subprocess.DEVNULL)
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            self.logger.error(f"Failed to get interface list: {e}")
            return []

        up = []
        # Header lines look like "5: wg0: <FLAGS> ..."; the name is field 1.
        for line in re.findall(rb"\s?\d+: .*:[^\n]+", output):
            interface = line.split(b':')[1].decode('utf-8', errors='replace').strip()
            if interface in wanted:
                up.append(interface)
        return up

    # ------------------------------------------------------------------
    # Data collection
    # ------------------------------------------------------------------

    def read_config(self, interface: str) -> bool:
        """Load peer names, public keys and ping addresses from a wg-quick config.

        Peers are stored in ``self.interfaces[interface].peers`` keyed by
        public key; ``[Peer]`` sections without a PublicKey are skipped.
        Section headers and keys are matched case-insensitively, and every
        section header ends the previous section. A peer's display name comes
        from a ``# Name = ...`` line inside its section (default
        ``*nameless*``). Its ping address is the first usable AllowedIPs
        entry (see ``_first_host_ip``).

        Returns:
            True if the file was read, False if it is missing or unreadable.
        """
        config_file = self.config_path / f"{interface}.conf"

        if not config_file.exists():
            self.logger.warning(f"Config file not found: {config_file}")
            return False

        try:
            with open(config_file, encoding='utf-8', errors='replace') as cfg:
                lines = cfg.read().splitlines()
        except OSError as e:
            self.logger.error(f"Failed to read config file {config_file}: {e}")
            return False

        peers: List[Dict[str, str]] = []
        current: Optional[Dict[str, str]] = None

        for raw_line in lines:
            line = raw_line.strip()

            if line.startswith('['):
                # Any header closes the previous section; only [Peer] matters.
                current = None
                if line.lower() == '[peer]':
                    current = {'name': '*nameless*', 'pubkey': '', 'ip': ''}
                    peers.append(current)
                continue

            if current is None:
                continue

            name_match = self._NAME_RE.match(line)
            if name_match:
                current['name'] = name_match.group(1).strip()
                continue

            # Drop whole-line and trailing comments before reading key = value.
            key, sep, value = line.split('#', 1)[0].partition('=')
            if not sep:
                continue
            key = key.strip().lower()
            value = value.strip()

            if key == 'publickey':
                current['pubkey'] = value
            elif key == 'allowedips' and not current['ip']:
                current['ip'] = self._first_host_ip(value)

        for peer in peers:
            if not peer['pubkey']:
                continue
            if interface not in self.interfaces:
                self.interfaces[interface] = InterfaceInfo(name=interface)
            self.interfaces[interface].peers[peer['pubkey']] = PeerInfo(
                name=peer['name'],
                pubkey=peer['pubkey'],
                ip=peer['ip'],
            )

        return True

    @classmethod
    def _first_host_ip(cls, allowed_ips: str) -> str:
        """Return the first AllowedIPs address usable as a ping target, or ''.

        Takes a comma-separated list such as ``"10.0.0.2/32, fd00::2/128"``
        and drops prefix lengths. Catch-all addresses (``0.0.0.0``, ``::``)
        and wg's ``(none)`` placeholder are skipped, because pinging them would
        not test the peer. On Linux, pinging 0.0.0.0 typically reaches the
        local host and would mark the peer online.
        """
        for entry in allowed_ips.split(','):
            addr = entry.strip().split('/')[0].strip()
            if addr and addr not in cls._NON_HOST_ADDRS:
                return addr
        return ''

    def ping_peer(self, peer: PeerInfo, timeout: int = 1):
        """Ping *peer* once and record the result in ``peer.online``.

        A peer without an address is marked offline. Addresses containing
        ':' are treated as IPv6 and use ``ping6``, falling back to
        ``ping -6`` when ``ping6`` is not installed. Failures are logged and
        leave the peer marked offline. This runs as a thread target, so it
        never raises.
        """
        if not peer.ip:
            self.logger.warning(f"No IP address for peer {peer.name}, skipping ping")
            peer.online = False
            return

        is_ipv6 = ':' in peer.ip
        attempts = [['ping6' if is_ipv6 else 'ping', '-c1', f'-W{timeout}', peer.ip]]
        if is_ipv6:
            attempts.append(['ping', '-6', '-c1', f'-W{timeout}', peer.ip])

        for cmd in attempts:
            try:
                retcode = subprocess.call(
                    cmd,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL
                )
            except FileNotFoundError:
                continue  # binary missing: try the next variant, if any
            except Exception as e:
                self.logger.error(f"Failed to ping {peer.name} ({peer.ip}): {e}")
                peer.online = False
                return
            peer.online = (retcode == 0)
            return

        self.logger.error("ping command not found")
        peer.online = False

    def ping_all_peers(self):
        """Ping every peer that has an address, one thread per peer, and wait."""
        threads = []
        for interface_info in self.interfaces.values():
            for peer in interface_info.peers.values():
                if not peer.ip:
                    continue  # left as online=None; status falls back to handshake
                th = threading.Thread(
                    target=self.ping_peer,
                    args=(peer, self.args.ping_timeout),
                    daemon=True
                )
                threads.append(th)
                th.start()

        for th in threads:
            th.join()

    def parse_wg_show(self, interface: str) -> bool:
        """Merge live ``wg show <interface>`` state into ``self.interfaces``.

        Peers missing from the config are added under a ``*unknown-XXXXXXXX*``
        name, with their ping address taken from ``allowed ips``.

        Returns:
            True on success, False if ``wg show`` could not be run.
        """
        try:
            output = subprocess.check_output(
                ['wg', 'show', interface],
                stderr=subprocess.DEVNULL,
                # Never let wg emit ANSI colors into the text we parse.
                env={**os.environ, 'WG_COLOR_MODE': 'never'},
            ).decode('utf-8', errors='replace')
        except (subprocess.CalledProcessError, FileNotFoundError) as e:
            self.logger.error(f"Failed to run 'wg show {interface}': {e}")
            return False

        if interface not in self.interfaces:
            self.interfaces[interface] = InterfaceInfo(name=interface)
        iface = self.interfaces[interface]
        current_peer: Optional[PeerInfo] = None

        for raw_line in output.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            # Split on the first ':' only; endpoints like "[::1]:51820" keep theirs.
            key, _, value = line.partition(':')
            key = key.strip()
            value = value.strip()

            if key == 'interface':
                continue  # already keyed by `interface`

            if key == 'peer':
                current_peer = iface.peers.get(value)
                if current_peer is None:
                    self.logger.info(f"Found peer {value} in 'wg show' but not in config - adding dynamically")
                    current_peer = PeerInfo(
                        name=f"*unknown-{value[:8]}*",
                        pubkey=value,
                        ip=""
                    )
                    iface.peers[value] = current_peer
                continue

            if current_peer is None:
                # Interface-level fields precede the first "peer:" line.
                if key == 'public key':
                    iface.public_key = value
                elif key == 'listening port':
                    iface.listening_port = value
                continue

            if key == 'endpoint':
                current_peer.endpoint = value
            elif key == 'latest handshake':
                current_peer.latest_handshake = value
            elif key == 'transfer':
                # e.g. "1.23 MiB received, 4.56 MiB sent": keep number + unit.
                received, sep, sent = value.partition(',')
                if sep:
                    current_peer.transfer_rx = self.parse_bytes(' '.join(received.split()[:2]))
                    current_peer.transfer_tx = self.parse_bytes(' '.join(sent.split()[:2]))
            elif key == 'allowed ips':
                current_peer.allowed_ips = value
                # Peers not in the config have no ping address yet.
                if not current_peer.ip and value:
                    current_peer.ip = self._first_host_ip(value)
            elif key == 'persistent keepalive':
                current_peer.persistent_keepalive = value

        return True

    def parse_bytes(self, value: str) -> int:
        """Convert a ``wg show`` size such as ``"1.23 MiB"`` or ``"512 B"`` to bytes.

        A bare number is taken as bytes. Returns 0 for anything that cannot
        be parsed, including unknown units.
        """
        match = re.fullmatch(r'([0-9]+(?:\.[0-9]+)?)\s*([A-Za-z]*)', value.strip())
        if not match:
            return 0
        number, unit = match.groups()
        multiplier = self._BYTE_UNITS.get(unit or 'B')
        if multiplier is None:
            return 0
        return int(float(number) * multiplier)

    # ------------------------------------------------------------------
    # Status classification
    # ------------------------------------------------------------------

    def _peer_state(self, peer: PeerInfo) -> str:
        """Classify *peer* as 'online', 'stale' or 'offline' for display.

        A ping result wins when --ping was given and the peer was pinged.
        Otherwise the handshake age decides: below ONLINE_THRESHOLD is
        online, below STALE_THRESHOLD is stale, and older or missing is
        offline.
        """
        if self.args.ping and peer.online is not None:
            return 'online' if peer.online else 'offline'
        seconds = peer.handshake_seconds
        if seconds is None:
            return 'offline'
        if seconds < self.ONLINE_THRESHOLD:
            return 'online'
        if seconds < self.STALE_THRESHOLD:
            return 'stale'
        return 'offline'

    def _is_recently_active(self, peer: PeerInfo) -> bool:
        """True if a ping succeeded or the last handshake is under ONLINE_THRESHOLD.

        Used by --filter and the "N/M online" summary; either signal counts.
        ``is not None`` keeps a 0-second handshake from reading as "none".
        """
        if peer.online:
            return True
        seconds = peer.handshake_seconds
        return seconds is not None and seconds < self.ONLINE_THRESHOLD

    def get_peer_status_color(self, peer: PeerInfo) -> Tuple[str, str]:
        """Return (color, bold color) markup for *peer*'s status."""
        state = self._peer_state(peer)
        if state == 'online':
            return self.colors.green, self.colors.green_bold
        if state == 'stale':
            return self.colors.yellow, self.colors.yellow_bold
        return self.colors.red, self.colors.red_bold

    def get_status_symbol(self, peer: PeerInfo) -> str:
        """Return a status glyph for *peer*; only TTY output uses glyphs."""
        if self.output_format != OutputFormat.TTY:
            return ''
        return {'online': '●', 'stale': '◐', 'offline': '○'}[self._peer_state(peer)]

    def sort_peers(self, peers: Dict[str, PeerInfo]) -> List[PeerInfo]:
        """Return the peers as a list ordered by --sort (insertion order if unset).

        Peers without a handshake sort after every peer that has one.
        """
        peer_list = list(peers.values())

        def handshake_age(p: PeerInfo) -> float:
            seconds = p.handshake_seconds
            # `is None` so a 0-second ("Now") handshake is not sorted as missing.
            return float('inf') if seconds is None else seconds

        if self.args.sort == 'name':
            peer_list.sort(key=lambda p: p.name.lower())
        elif self.args.sort == 'handshake':
            peer_list.sort(key=handshake_age)
        elif self.args.sort == 'transfer':
            peer_list.sort(key=lambda p: p.transfer_rx + p.transfer_tx, reverse=True)
        elif self.args.sort == 'status':
            # Ping-confirmed online peers first, then by handshake age.
            peer_list.sort(key=lambda p: (0 if p.online else 1, handshake_age(p)))

        return peer_list

    def filter_peers(self, peers: List[PeerInfo]) -> List[PeerInfo]:
        """Keep only online or offline peers when --filter is given."""
        if self.args.filter == 'online':
            return [p for p in peers if self._is_recently_active(p)]
        if self.args.filter == 'offline':
            return [p for p in peers if not self._is_recently_active(p)]
        return peers

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    def _html_safe(self, text) -> str:
        """Escape *text* for HTML output; return it unchanged for other formats."""
        text = str(text)
        if self.output_format == OutputFormat.HTML:
            # Peer names come from free-form config comments and may contain
            # '<' or '&', which would otherwise corrupt the markup.
            return html.escape(text)
        return text

    def show_interface(self, interface_info: InterfaceInfo):
        """Print one interface, an online summary, and its sorted/filtered peers."""
        c = self.colors
        esc = self._html_safe

        print(f"{c.yellow_bold}interface{c.end}: "
              f"{c.yellow}{esc(interface_info.name)}{c.end}")

        if interface_info.public_key:
            print(f"  {c.bold}public key{c.end}: {esc(interface_info.public_key)}")
        if interface_info.listening_port:
            print(f"  {c.bold}listening port{c.end}: {esc(interface_info.listening_port)}")

        all_peers = list(interface_info.peers.values())
        peers = self.filter_peers(self.sort_peers(interface_info.peers))

        # The summary counts every peer, not only those passing --filter.
        online_count = sum(1 for p in all_peers if self._is_recently_active(p))
        if self.args.ping or any(p.handshake_seconds is not None for p in all_peers):
            print(f"  {c.bold}peers{c.end}: "
                  f"{c.green}{online_count}{c.end}/"
                  f"{len(all_peers)} online")

        for peer in peers:
            color, color_bold = self.get_peer_status_color(peer)
            symbol = self.get_status_symbol(peer)
            symbol_str = f"{symbol} " if symbol else ""

            print(f"\n  {color_bold}peer{c.end}: "
                  f"{color}{symbol_str}{esc(peer.name)} ({esc(peer.pubkey)}){c.end}")

            if peer.endpoint:
                print(f"    {c.bold}endpoint{c.end}: {esc(peer.endpoint)}")
            if peer.latest_handshake:
                print(f"    {c.bold}latest handshake{c.end}: {esc(peer.latest_handshake)}")
            if peer.transfer_rx or peer.transfer_tx:
                rx_human = self.human_readable_bytes(peer.transfer_rx)
                tx_human = self.human_readable_bytes(peer.transfer_tx)
                print(f"    {c.bold}transfer{c.end}: "
                      f"{rx_human} received, {tx_human} sent")
            if peer.allowed_ips:
                print(f"    {c.bold}allowed ips{c.end}: {esc(peer.allowed_ips)}")
            if peer.persistent_keepalive:
                print(f"    {c.bold}persistent keepalive{c.end}: {esc(peer.persistent_keepalive)}")

    def show_all(self):
        """Print every collected interface, blank-line separated (inside <pre> for HTML)."""
        if self.output_format == OutputFormat.HTML:
            print('<pre>')

        interface_names = list(self.interfaces.keys())
        for idx, interface_name in enumerate(interface_names):
            self.show_interface(self.interfaces[interface_name])
            if idx < len(interface_names) - 1:
                print()

        if self.output_format == OutputFormat.HTML:
            print('</pre>')

    def output_json(self):
        """Print all collected data as one JSON object keyed by interface name."""
        data = {}
        for interface_name, interface_info in self.interfaces.items():
            peers_data = []
            for peer in interface_info.peers.values():
                peer_data = {
                    'name': peer.name,
                    'public_key': peer.pubkey,
                    'ip': peer.ip,
                    'endpoint': peer.endpoint,
                    'latest_handshake': peer.latest_handshake,
                    'handshake_seconds': peer.handshake_seconds,
                    'transfer_rx': peer.transfer_rx,
                    'transfer_tx': peer.transfer_tx,
                    'allowed_ips': peer.allowed_ips,
                    'persistent_keepalive': peer.persistent_keepalive
                }
                if self.args.ping:
                    peer_data['online'] = peer.online
                peers_data.append(peer_data)

            data[interface_name] = {
                'public_key': interface_info.public_key,
                'listening_port': interface_info.listening_port,
                'peers': peers_data
            }

        print(json.dumps(data, indent=2))

    # ------------------------------------------------------------------
    # Entry points
    # ------------------------------------------------------------------

    def run_once(self):
        """Collect state for the selected interfaces once and print it."""
        if self.args.interface:
            candidates = [self.args.interface]
        else:
            candidates = [f.stem for f in self.config_path.glob('*.conf')]

        up_interfaces = self.get_up_interfaces(candidates)

        if not up_interfaces:
            if self.output_format == OutputFormat.JSON:
                # Keep stdout a valid JSON document even with nothing to report.
                self.output_json()
            else:
                print(f"{self.colors.yellow}No active WireGuard interfaces found{self.colors.end}")
            return

        for interface in up_interfaces:
            self.read_config(interface)
            self.parse_wg_show(interface)

        if self.args.ping:
            self.ping_all_peers()

        if self.output_format == OutputFormat.JSON:
            self.output_json()
        else:
            if self.args.timestamp:
                now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                print(f"{self.colors.blue_bold}[{now}]{self.colors.end}\n")
            self.show_all()

    def run_watch(self):
        """Re-run the report every ``args.watch`` seconds until Ctrl-C."""
        try:
            while True:
                if self.output_format == OutputFormat.TTY:
                    # Home the cursor and clear the screen with ANSI codes
                    # instead of spawning a shell for `clear` each refresh.
                    print('\033[H\033[2J', end='')

                self.interfaces.clear()
                self.run_once()
                # Make each refresh visible promptly even when piped.
                sys.stdout.flush()

                time.sleep(self.args.watch)
        except KeyboardInterrupt:
            print(f"\n{self.colors.yellow}Exiting...{self.colors.end}")

    def run(self):
        """Check prerequisites, then run once or in watch mode."""
        self.check_root()

        if not self.check_wireguard_installed():
            sys.exit(1)

        if self.args.watch:
            self.run_watch()
        else:
            self.run_once()


def main():
    """Parse the command line and run the WireGuard status tool."""
    ap = argparse.ArgumentParser(
        formatter_class=PreserveWhiteSpaceWrapRawTextHelpFormatter,
        description=textwrap.dedent('''\
            Wireguard Info - Enhanced
            =========================

            This tool enhances the output of 'wg show' to include node names,
            connection status, and various useful metrics.

            It expects you to use wg-quick and reads the wg-quick config at
            /etc/wireguard/INTERFACE.conf

            The human readable peer names are expected in the wg-quick config
            within a comment like this:
            [Peer]
            # Name = Very secret node in antarctica

        '''))

    # Output format options
    output_group = ap.add_argument_group('Output Options')
    output_group.add_argument('--html', action='store_true',
                              help="Format output as HTML")
    output_group.add_argument('--json', action='store_true',
                              help="Format output as JSON")
    output_group.add_argument('--tty', action='store_true',
                              help="Force terminal colors even when writing to pipe")
    output_group.add_argument('--timestamp', action='store_true',
                              help="Show timestamp in output")

    # Interface selection
    interface_group = ap.add_argument_group('Interface Selection')
    interface_group.add_argument('-i', '--interface',
                                 help="Only show status for this interface")

    # Connectivity options
    connectivity_group = ap.add_argument_group('Connectivity Options')
    connectivity_group.add_argument('-p', '--ping', action='store_true',
                                    help="Ping all nodes (in parallel) and show online status")
    connectivity_group.add_argument('--ping-timeout', type=int, default=1,
                                    help="Timeout for ping in seconds (default: 1)")

    # Sorting and filtering
    sort_group = ap.add_argument_group('Sorting and Filtering')
    sort_group.add_argument('--sort', choices=['name', 'handshake', 'transfer', 'status'],
                            help="Sort peers by: name, handshake time, transfer amount, or status")
    sort_group.add_argument('--filter', choices=['online', 'offline'],
                            help="Filter to show only online or offline peers")

    # Watch mode
    watch_group = ap.add_argument_group('Watch Mode')
    watch_group.add_argument('-w', '--watch', type=int, metavar='SECONDS',
                             help="Refresh display every N seconds (watch mode)")

    # Debug options
    debug_group = ap.add_argument_group('Debug Options')
    debug_group.add_argument('-v', '--verbose', action='store_true',
                             help="Enable verbose logging")

    args = ap.parse_args()

    # Reject negative durations here with a usage error. Otherwise
    # time.sleep() raises ValueError only after the first watch refresh, and
    # a negative `ping -W` value is presumably rejected by ping, silently
    # showing every peer as offline. A watch value of 0 keeps its existing
    # meaning (run once).
    if args.watch is not None and args.watch < 0:
        ap.error("--watch must be a non-negative number of seconds")
    if args.ping_timeout < 0:
        ap.error("--ping-timeout must be a non-negative number of seconds")

    # Run the tool
    wg_info = WireGuardInfo(args)
    wg_info.run()


if __name__ == '__main__':
    # Script entry point: main() returns None, so the exit status is 0
    # unless main() itself exits with a different code.
    sys.exit(main())
