#!/usr/bin/python

import os
import sys
import argparse
import time
import math
from gi.repository import Ufo
from tofu.reco import setup_padding
from tofu.util import range_from


def reco(width, height, num_projections, args):
    """Build and run one reconstruction task graph on generated data.

    The graph always starts at a ``dummy-data`` task and ends at a ``null``
    task fed by ``backproject``.  Unless ``args.disable_filtering`` is set, an
    ``fft`` -> ``filter`` -> ``ifft`` stage is inserted in front of the
    backprojection; with ``args.enable_padding`` that stage is additionally
    wrapped by ``pad`` and ``cut-roi`` tasks configured via ``setup_padding``.

    :param width: passed as ``width`` to the data generator; half of it is
        used as the backprojection ``axis_pos``
    :param height: passed as ``number`` to the data generator
    :param num_projections: passed as ``height`` to the data generator and
        used to derive the backprojection ``angle_step`` (pi / num_projections)
    :param args: parsed command line options; reads ``backproject_mode``,
        ``disable_filtering``, ``enable_padding``, ``addresses`` and
        ``enable_tracing``
    :returns: the scheduler's ``time`` property after the run
    """
    plugins = Ufo.PluginManager()

    source = plugins.get_task('dummy-data')
    backproject = plugins.get_task('backproject')
    sink = plugins.get_task('null')

    source.set_properties(number=height, width=width, height=num_projections)
    backproject.set_properties(axis_pos=width / 2,
                               angle_step=math.pi / num_projections)
    sink.set_properties(force_download=True)

    if args.backproject_mode:
        backproject.set_properties(mode=args.backproject_mode)

    graph = Ufo.TaskGraph()

    # Assemble the tasks in processing order; neighbours get connected below.
    if args.disable_filtering:
        chain = [source, backproject, sink]
    else:
        forward = plugins.get_task('fft')
        inverse = plugins.get_task('ifft')
        freq_filter = plugins.get_task('filter')

        forward.set_properties(dimensions=1)
        inverse.set_properties(dimensions=1)

        if args.enable_padding:
            pad = plugins.get_task('pad')
            crop = plugins.get_task('cut-roi')
            setup_padding(pad, crop, width, num_projections)
            chain = [source, pad, forward, freq_filter, inverse, crop,
                     backproject, sink]
        else:
            chain = [source, forward, freq_filter, inverse, backproject, sink]

    for upstream, downstream in zip(chain, chain[1:]):
        graph.connect_nodes(upstream, downstream)

    scheduler = Ufo.Scheduler()

    if args.addresses:
        scheduler.set_properties(remotes=args.addresses)

    if args.enable_tracing:
        scheduler.set_properties(enable_tracing=True)

    scheduler.run(graph)
    return scheduler.props.time


def measure(width, height, num_projections, output, args):
    """Benchmark one parameter combination and write a result line.

    Runs ``reco`` ``args.num_runs`` times, averages the time it reports
    (execution time) and the wall-clock time around each call (total time),
    derives throughput figures from the averages and writes a single
    formatted line to ``output``.

    :param width: width passed on to ``reco``
    :param height: height passed on to ``reco``
    :param num_projections: number of projections passed on to ``reco``
    :param output: writable file-like object receiving the result line; it is
        flushed after the write
    :param args: parsed command line options; ``num_runs`` is read here and
        the object is forwarded to ``reco``
    :raises ValueError: if ``args.num_runs`` is smaller than 1
    """
    # Without at least one run there is nothing to average; fail with a clear
    # message instead of a ZeroDivisionError from the mean computation.
    if args.num_runs < 1:
        raise ValueError(
            "num_runs must be at least 1, got {}".format(args.num_runs))

    exec_times = []
    total_times = []

    for _ in range(args.num_runs):
        start = time.time()
        exec_times.append(reco(width, height, num_projections, args))
        total_times.append(time.time() - start)

    exec_time = sum(exec_times) / len(exec_times)
    total_time = sum(total_times) / len(total_times)

    # Percentage by which the wall-clock time exceeds the reported run time.
    overhead = (total_time / exec_time - 1.0) * 100

    # Four bytes per value, scaled from bytes/s to MB/s.
    input_bandwidth = width * height * num_projections * 4 / exec_time / 1024. / 1024.
    output_bandwidth = width * width * height * 4 / exec_time / 1024. / 1024.
    slice_bandwidth = height / exec_time

    # Four bytes of our output bandwidth constitute one slice pixel, for each
    # pixel we have to do roughly n * 6 floating point ops (2 mad, 1 add, 1
    # interpolation)
    flops = output_bandwidth / 4 * 6 * num_projections / 1024

    msg = ("width={:<6d} height={:<6d} n_proj={:<6d}  "
           "exec={:.4f}s  total={:.4f}s  overhead={:.2f}%  "
           "bandwidth_i={:.2f}MB/s  bandwidth_o={:.2f}MB/s slices={:.2f}/s  "
           "flops={:.2f}GFLOPs\n")

    output.write(msg.format(width, height, num_projections,
                            exec_time, total_time, overhead,
                            input_bandwidth, output_bandwidth, slice_bandwidth,
                            flops))
    output.flush()


def main():
    """Parse the command line and benchmark every requested combination.

    ``--width``, ``--height`` and ``--num-projections`` are handed to
    ``range_from`` and the result is expanded with ``range``; ``measure`` is
    called once for every combination of the three.  Results go to the file
    named by ``--output`` or to standard output when no file is given.  If
    ``--nv-devices`` is set, the expanded device list is exported as
    ``CUDA_VISIBLE_DEVICES`` before any measurement runs.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--width', type=str, default='1024',
                        help="Width or range of width of a generated projection")

    parser.add_argument('--height', type=str, default='1024',
                        help="Height or range of height of a generated projection")

    parser.add_argument('--num-projections', type=str, default='512',
                        help="Number or range of number of projections")

    parser.add_argument('--num-runs', type=int, default=3,
                        help="Number of runs")

    parser.add_argument('--nv-devices', type=str, default=None,
                        help="Range of devices to be scanned over")

    parser.add_argument('-o', '--output', type=str,
                        help="Output file for results")

    parser.add_argument('-a', '--addresses', type=str, nargs='*', metavar='ADDR',
                        help="Addresses of remote servers running `ufod'")

    parser.add_argument('--disable-filtering', action='store_true', default=False,
                        help="Disable filtering of sinograms")

    parser.add_argument('--enable-tracing', action='store_true', default=False,
                        help="Enable tracing")

    parser.add_argument('--backproject-mode', choices=['nearest', 'texture'],
                        help="Backproject mode")

    parser.add_argument('--enable-padding', action='store_true', default=False,
                        help="Enable padding to prevent reconstruction ring")

    args = parser.parse_args()

    if args.nv_devices:
        var = ','.join(str(n) for n in range(*range_from(args.nv_devices)))
        os.environ['CUDA_VISIBLE_DEVICES'] = var

    # Open the result file only once the options above have been processed,
    # and make sure it is closed again even if a measurement fails.
    output = open(args.output, 'w') if args.output else sys.stdout

    try:
        for width in range(*range_from(args.width)):
            for height in range(*range_from(args.height)):
                for num_projections in range(*range_from(args.num_projections)):
                    measure(width, height, num_projections, output, args)
    finally:
        # Never close the interpreter's stdout, only a file we opened.
        if output is not sys.stdout:
            output.close()


# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
