#! /usr/bin/python3
#
# Copyright (c) 2017-2018, AT&T Intellectual Property.  All rights reserved.
#
# SPDX-License-Identifier: LGPL-2.1-only
#
'''
Move non-dataplane, non-virtualization tasks to cpu 0 to avoid tasks getting
stuck on a real-time CPU and to leave virtualization CPUs free for guests.
'''

import os
import logging
import argparse
import sys
import subprocess

from vyatta import cpuset
from vyatta import configd
from vplaned import Controller

# Root of the cgroup filesystem; its filesystem type selects the cgroup v1 or
# v2 implementation in main(), and all cgroup paths are built from it.
CGROUP_MOUNT = '/sys/fs/cgroup'
# File read by get_online_cpus() to obtain the online cores.
ONLINE_CORES = '/sys/devices/system/cpu/online'
# Tree walked for 'cpumask' files by restrict_kworkers_cpus().
WORKSET_ROOT = '/sys/devices/virtual/workqueue'
# Tree walked for 'smp_affinity' files by restrict_softirqd_cpus().
SOFTIRQD_ROOT = '/proc/irq'

# Names of the cgroup directories this script reads or writes.
SYSTEM_CPUSET = 'system.slice'
DP_SLICE = 'dataplane.slice'
VIRT_SLICE = 'virtualization.slice'
# cpuset directories created by CpushieldCg1.setup_system().
DEFAULT_CPUSET_DIRS = [SYSTEM_CPUSET]

# Directory names skipped (together with DATAPLANE_DIRS) when the cgroup v1
# implementation moves tasks into the system cpuset.
VM_DIRS = [
    'libvirtd.service',
    'libvirt-guests.service',
    ]
DATAPLANE_DIRS = ['vyatta-dataplane.service', DP_SLICE]

# Module logger; only WARNING and above pass the logger's own level, so the
# info/debug messages in this file are filtered out by default.
LOG = logging.getLogger('cpu-shield')
LOG.setLevel(logging.WARNING)

def fs_type_for_mount(mountpoint, mounts_file='/proc/mounts'):
    '''
    Return the filesystem type for a mount point.

    Each line of mounts_file is split on whitespace; the second field is
    compared verbatim against mountpoint and the third field of the first
    matching line is returned. Lines with fewer than three fields (blank or
    truncated lines) are skipped instead of raising IndexError.

    Returns None when no line matches. Raises OSError if mounts_file cannot
    be opened.
    '''
    with open(mounts_file, 'r') as mounts_f:
        for line in mounts_f:
            fields = line.split()
            if len(fields) >= 3 and fields[1] == mountpoint:
                return fields[2]
    return None

def write_to_file(fname, data):
    '''
    Replace the contents of the file fname with the string data.

    The file is opened in text write mode, so any previous content is
    truncated. Errors from opening, writing or closing the file propagate
    to the caller.
    '''
    with open(fname, mode='w') as out:
        out.write(data)

def get_dataplane_cpus():
    '''
    Get the set of cores the dataplane is using for forwarding.

    Sends the "cpu" command to the first local dataplane reported by the
    controller and builds a Cpuset from the 'forwarding_cores' entry of the
    reply.

    This is best effort: if there is no local dataplane, or anything goes
    wrong while talking to it, an empty Cpuset is returned.
    '''
    try:
        with Controller() as controller:
            for dp in controller.get_dataplanes():
                if dp.local:
                    with dp:
                        j = dp.json_command("cpu")
                        # NOTE(review): the second argument presumably marks
                        # the value as a mask rather than a range - confirm
                        # against vyatta.cpuset.
                        return cpuset.Cpuset(j['forwarding_cores'], True)
    except Exception as exc:
        # Deliberately non-fatal, but don't swallow KeyboardInterrupt or
        # SystemExit, and leave a trace of why the lookup failed.
        LOG.debug("Unable to get dataplane forwarding cores: %s", exc)

    return cpuset.Cpuset('')

def get_vm_cpus():
    '''
    Get the set of cores the VMs are configured to use.

    Reads the 'virtualization guest' configuration tree through configd and
    accumulates, for every guest, its 'cpuset' value plus the 'cpuset' of
    each entry under guest['vcpupin']['vcpu']. Entries without a 'cpuset'
    key are ignored.

    This is best effort: if the configuration cannot be read, whatever has
    been accumulated so far (possibly an empty Cpuset) is returned.
    '''
    cset = cpuset.Cpuset('')
    try:
        VM_CONFIG_STRING = 'virtualization guest'
        client = configd.Client()
        cfg = client.tree_get_dict(VM_CONFIG_STRING)
        for guest in cfg['guest']:
            try:
                cset.add_cpuset(cpuset.Cpuset(guest['cpuset']))
            except KeyError:
                pass
            if 'vcpupin' in guest and 'vcpu' in guest['vcpupin']:
                for vcpu in guest['vcpupin']['vcpu']:
                    try:
                        cset.add_cpuset(cpuset.Cpuset(vcpu['cpuset']))
                    except KeyError:
                        pass
    except Exception as exc:
        # This is expected if not on a VNF image. Keep it non-fatal, but
        # don't swallow KeyboardInterrupt or SystemExit.
        LOG.debug("Unable to get VM cpus: %s", exc)
    return cset

def get_online_cpus():
    '''
    Return a Cpuset built from the contents of the ONLINE_CORES file.
    '''
    with open(ONLINE_CORES, 'r') as cores_f:
        return cpuset.Cpuset(cores_f.read())

def get_system_cpus(dp_core_string):
    '''
    Work out the set of cores available to system processes.

    Starts from the online cores and removes the dataplane cores and the
    cores configured for VMs. When dp_core_string is non-empty it is used
    as the dataplane core set; otherwise the dataplane is queried.

    If nothing is left, core 0 is added so that the returned Cpuset always
    contains at least one core.
    '''
    cpus = get_online_cpus()
    if dp_core_string:
        dp = cpuset.Cpuset(dp_core_string, True)
    else:
        dp = get_dataplane_cpus()
    vm = get_vm_cpus()
    LOG.info("System cpus: %s - dp %s - vm %s", cpus.range(),
             dp.range(), vm.range())

    cpus.remove_cpuset(dp)
    cpus.remove_cpuset(vm)
    # Ensure that there is at least one core in the system CPU set
    if not cpus.range():
        cpus.add_cpuset(cpuset.Cpuset("0"))
    # Pass the value as a logging argument, as the other log calls do,
    # rather than formatting the message eagerly with '%'.
    LOG.info("Allowing cpus %s for system processes", cpus.range())

    return cpus

def restrict_kworkers_cpus(cpus):
    '''
    Restrict the kworker workqueues to the set of system cores by writing
    the cpu mask into every 'cpumask' file found under WORKSET_ROOT.
    '''
    LOG.info("Restricting kworker cpus to %s", cpus.range())
    for dirpath, _subdirs, filenames in os.walk(WORKSET_ROOT):
        if 'cpumask' not in filenames:
            continue
        mask_file = os.path.join(dirpath, 'cpumask')
        LOG.debug("Setting %s to %s (%s)", mask_file, cpus.mask(), cpus.range())
        write_to_file(mask_file, cpus.mask())

def restrict_softirqd_cpus(cpus):
    '''
    Restrict the softirqd threads to the set of system cores by writing the
    cpu mask into every 'smp_affinity' file found under SOFTIRQD_ROOT.

    Some IRQs can not be moved, so the write to the file may fail; such
    failures are logged at debug level and otherwise ignored.
    '''
    LOG.info("Restricting softirqd cpus to %s", cpus.range())
    for root, _child_dirs, files in os.walk(SOFTIRQD_ROOT):
        if 'smp_affinity' in files:
            f = os.path.join(root, 'smp_affinity')
            LOG.debug("Setting %s to %s (%s)", f, cpus.mask(), cpus.range())
            try:
                write_to_file(f, cpus.mask())
            except OSError as exc:
                # Report the real error instead of always claiming
                # "permission denied".
                LOG.debug("Setting %s to %s failed: %s", f, cpus.range(), exc)

class CpushieldCg1(object):
    """
    Cpushield using cgroupv1.

    Tasks found in the systemd hierarchy (other than the VM and dataplane
    directories) are moved into a 'system.slice' cpuset, whose cpu list is
    then set to the system cores.
    """

    SYSTEMD_ROOT = CGROUP_MOUNT + '/systemd'
    DATAPLANE_SERVICE_ROOT = CGROUP_MOUNT + '/systemd/system.slice/vyatta-dataplane.service'
    CPUSET_ROOT = CGROUP_MOUNT + '/cpuset'
    DATAPLANE_SLICE_ROOT = CGROUP_MOUNT + '/systemd/' + DP_SLICE

    def _setup_cset_dir(self, dirpath, cpus, mems):
        '''
        Create and populate one cpuset directory. Don't modify if the dir
        already exists.

        cpus and mems are the strings written to the new directory's
        'cpuset.cpus' and 'cpuset.mems' files.
        '''
        if not os.path.exists(dirpath):
            os.makedirs(dirpath, 0o755)
            write_to_file(os.path.join(dirpath, 'cpuset.cpus'), cpus)
            write_to_file(os.path.join(dirpath, 'cpuset.mems'), mems)

    def _move_one(self, tasks_dir, cpuset_dir):
        '''
        Move tasks from tasks_dir to cpuset_dir.

        Every id listed in tasks_dir/tasks is written to cpuset_dir/tasks.
        A task that can't be moved is logged at debug level and skipped.
        '''
        with open(os.path.join(tasks_dir, 'tasks')) as read_f:
            tasks = read_f.read().split()
        if not tasks:
            return
        LOG.debug("Moving from %s to %s", tasks_dir, cpuset_dir)
        target = os.path.join(cpuset_dir, 'tasks')
        # Write one id per open/close cycle: the buffered write is only
        # flushed when the file is closed, so that is where a failure for
        # an individual task is raised.
        for tid in tasks:
            try:
                with open(target, 'w') as write_f:
                    write_f.write(tid)
            except IOError:
                LOG.debug("Task %s can't be moved to %s", tid, target)
            else:
                # Only report success once the close has gone through.
                LOG.debug("Task %s moved to %s", tid, target)

    def _move_tree_tasks(self, start_dir, cpuset_dir, exclude_dirs):
        '''
        Move threads (pids in tasks file) from start_dir and subdirs into
        cpuset_dir.

        Directories whose name is in exclude_dirs are not descended into.
        '''
        LOG.info("Moving from %s to %s", start_dir, cpuset_dir)
        for root, child_dirs, _files in os.walk(start_dir):
            # Pruning child_dirs in place stops os.walk from visiting the
            # excluded directories.
            for exclude_d in exclude_dirs:
                try:
                    child_dirs.remove(exclude_d)
                except ValueError:
                    pass
            self._move_one(root, cpuset_dir)

    def _move_tasks_to_system(self):
        '''
        Move all threads other than the VM/dataplane tasks into the system cpuset.
        '''
        system_cpuset_dir = os.path.join(self.CPUSET_ROOT, SYSTEM_CPUSET)
        self._move_tree_tasks(self.SYSTEMD_ROOT, system_cpuset_dir,
                              VM_DIRS + DATAPLANE_DIRS)

    def _get_cpuset_mems(self):
        '''
        Return the contents of the root cpuset's 'cpuset.mems' file.
        '''
        with open(os.path.join(self.CPUSET_ROOT, 'cpuset.mems')) as mems_f:
            return mems_f.read()

    def set_system_cpus(self, dp_core_string):
        '''
        Set the system cpuset to use the online set of cores minus the dataplane
        cores.

        Returns the Cpuset of system cores. A failure to write the cpuset
        file is logged and otherwise ignored.
        '''
        cpus = get_system_cpus(dp_core_string)
        cpuset_path = os.path.join(self.CPUSET_ROOT, SYSTEM_CPUSET)
        f = os.path.join(cpuset_path, 'cpuset.cpus')
        LOG.info("Setting %s to %s", f, cpus.range())
        try:
            write_to_file(f, cpus.range())
        except OSError as exc:
            LOG.info("Setting %s to %s failed: %s", f, cpus.range(), exc)
        return cpus

    def setup_dataplane(self):
        '''
        Move the dataplane which is currently initialising into the
        correct cpuset (toplevel) if not already there.
        '''
        LOG.info("Setting dataplane into root cpuset")
        if os.path.exists(self.DATAPLANE_SERVICE_ROOT):
            self._move_one(self.DATAPLANE_SERVICE_ROOT, self.CPUSET_ROOT)
        if os.path.exists(self.DATAPLANE_SLICE_ROOT):
            self._move_tree_tasks(self.DATAPLANE_SLICE_ROOT, self.CPUSET_ROOT, [])

    def setup_system(self):
        '''
        Create the cpuset groups and move system tasks into the system
        cpuset
        '''
        cpus = get_online_cpus()
        mems = self._get_cpuset_mems()
        for c in DEFAULT_CPUSET_DIRS:
            cpuset_dir = os.path.join(self.CPUSET_ROOT, c)
            self._setup_cset_dir(cpuset_dir, cpus.range(), mems)
        self._move_tasks_to_system()

    def set_vm_cpus(self, dp_core_string):
        '''
        Set the cpus for the VMs that are running and do not have a cpuset configured.

        The work is delegated to the vyatta-hypervisor helper when it is
        installed. dp_core_string is unused here; it is part of the
        signature shared with CpushieldCg2.set_vm_cpus.
        '''
        if not os.path.exists("/opt/vyatta/sbin/vyatta-hypervisor"):
            return
        subprocess.call(["/opt/vyatta/sbin/vyatta-hypervisor",
                         "--action=update-cpuset"])

class CpushieldCg2(object):
    """
    Cpushield using cgroupv2.

    The system cores are written to the 'cpuset.cpus' file of every
    top-level cgroup directory that is not listed in CGROUP_EXCLUSIONS.
    """

    CPUS_FILE = 'cpuset.cpus'
    CONTROL_FILE = 'cgroup.subtree_control'
    # Top-level cgroups whose cpus are not restricted to the system cores.
    CGROUP_EXCLUSIONS = ['machine.slice', DP_SLICE, VIRT_SLICE]

    def set_system_cpus(self, dp_core_string):
        '''
        Set the system cpuset to use the online set of cores minus the dataplane
        cores.

        Returns the Cpuset of system cores. A failure to write an individual
        cgroup's cpus file is logged and otherwise ignored.
        '''
        cpus = get_system_cpus(dp_core_string)
        cpu_range = cpus.range()
        for entry in os.scandir(CGROUP_MOUNT):
            if entry.is_dir() and entry.name not in self.CGROUP_EXCLUSIONS:
                cpus_file = os.path.join(entry.path, self.CPUS_FILE)
                try:
                    write_to_file(cpus_file, cpu_range)
                except FileNotFoundError:
                    LOG.info("Setting %s to %s failed. cpu-shield service not yet started", cpus_file, cpu_range)
                except OSError as exc:
                    LOG.info("Setting %s to %s failed: %s", cpus_file, cpu_range, exc)
        return cpus

    def setup_dataplane(self):
        '''
        It's in the dataplane slice so nothing to be done
        '''

    def setup_system(self):
        '''
        Enable the cpuset controller by writing "+cpuset" to the root
        cgroup's subtree control file. A failure is logged as an error and
        otherwise ignored.
        '''
        root_control_file = os.path.join(CGROUP_MOUNT, self.CONTROL_FILE)
        try:
            write_to_file(root_control_file, "+cpuset")
        except OSError:
            LOG.error("Enabling cpuset cgroupv2 controller failed. Does your kernel have support for the cpuset controller?")

    def set_vm_cpus(self, dp_core_string):
        '''
        Allow libvirtd and the guests to use the necessary CPUs

        The cpuset update itself is delegated to the vyatta-hypervisor
        helper when it is installed.
        '''
        # NOTE(review): virt_cpus is only logged here; nothing in this file
        # writes it to a cgroup (write_cpuset is never called) - confirm
        # whether that is intended.
        virt_cpus = get_vm_cpus()
        virt_cpus.add_cpuset(get_system_cpus(dp_core_string))
        LOG.debug("Allowing cpus %s for libvirtd and guests", virt_cpus.range())

        # Update the cpuset for the VMs that are running
        if not os.path.exists("/opt/vyatta/sbin/vyatta-hypervisor"):
            return
        subprocess.call(["/opt/vyatta/sbin/vyatta-hypervisor",
                         "--action=update-cpuset"])

    def write_cpuset(self, slice, cpus):
        '''
        Write the range of cpus to the 'cpuset.cpus' file of the cgroup
        named slice, creating the cgroup directory first if it is missing.

        A failure to write the file is logged and otherwise ignored.
        '''
        slice_dir = os.path.join(CGROUP_MOUNT, slice)
        if not os.path.exists(slice_dir):
            os.makedirs(slice_dir, 0o755)
        slice_cpus_file = os.path.join(slice_dir, self.CPUS_FILE)
        try:
            write_to_file(slice_cpus_file, cpus.range())
        except OSError as exc:
            LOG.info("Setting %s to %s failed: %s", slice_cpus_file, cpus.range(), exc)

def main(args):
    '''
    Shield the dataplane forwarding cores from system tasks.

    Picks the cgroup v1 or v2 implementation from the filesystem type
    mounted at CGROUP_MOUNT, performs whichever startup or dataplane-init
    step the parsed command line args ask for, and otherwise updates the
    system, VM, kworker and IRQ cpu restrictions. Called on startup, or
    whenever the set of cores the dataplane is using are modified.

    Exits through sys.exit() after the early-startup and dataplane-init
    steps.
    '''
    if args.startup:
        LOG.info("cpu_shield startup")

    cgroup_fs = fs_type_for_mount(CGROUP_MOUNT)
    if cgroup_fs != 'cgroup2':
        LOG.info("Using cgroupv1")
        if args.early and args.startup:
            LOG.info("Nothing to do for cgroupv1 early init")
            sys.exit()
        cpushield = CpushieldCg1()

        if args.startup:
            # One-off setup, only wanted the first time round.
            cpushield.setup_system()
    else:
        LOG.info("Using cgroupv2")
        cpushield = CpushieldCg2()

        if args.startup and args.early:
            cpushield.setup_system()
            # Stop here: the set of system CPUs can't really be worked out
            # until the configuration has been parsed and the dataplane is
            # up, so the rest is left to the later init point.
            sys.exit()

    if args.dataplane_init:
        cpushield.setup_dataplane()
        sys.exit()

    # Everything is in place; refresh the cpu assignments.
    system_cpus = cpushield.set_system_cpus(args.dp)
    cpushield.set_vm_cpus(args.dp)
    for restrict in (restrict_kworkers_cpus, restrict_softirqd_cpus):
        restrict(system_cpus)

if __name__ == '__main__':
    logging.basicConfig()

    # Command line options, registered in this order.
    options = (
        ("--dataplane_init",
         dict(action='store_true',
              help="Put the dataplane into the correct cpuset at init")),
        ("--update",
         dict(action='store_true',
              help="Only update the cpus, don't create cpusets or move tasks")),
        ("--dp",
         dict(help="Provide the current dataplane fwding cores as a mask")),
        ("--early",
         dict(action='store_true',
              help="Perform early system startup initialisation")),
        ("--startup",
         dict(action='store_true',
              help="Perform system startup initialisation")),
    )
    parser = argparse.ArgumentParser()
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()
    # At least one of the mode switches has to be given.
    if not any((args.startup, args.dataplane_init, args.update)):
        print("One argument of --startup, --dataplane_init or --update is required")
        sys.exit()
    main(args)
