#!/usr/bin/python
#  Copyright (C) 2012 by Carnegie Mellon University.
#  
#  @OPENSOURCE_HEADER_START@
#  Use of the Rayon and related source code is subject to the terms of
#  the following licenses:
#  
#  GNU Public License (GPL) Rights pursuant to Version 2, June 1991
#  Government Purpose License Rights (GPLR) pursuant to DFARS 252.227.7013
#  
#  NO WARRANTY
#  
#  ANY INFORMATION, MATERIALS, SERVICES, INTELLECTUAL PROPERTY OR OTHER 
#  PROPERTY OR RIGHTS GRANTED OR PROVIDED BY CARNEGIE MELLON UNIVERSITY 
#  PURSUANT TO THIS LICENSE (HEREINAFTER THE "DELIVERABLES") ARE ON AN 
#  "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY 
#  KIND, EITHER EXPRESS OR IMPLIED AS TO ANY MATTER INCLUDING, BUT NOT 
#  LIMITED TO, WARRANTY OF FITNESS FOR A PARTICULAR PURPOSE, 
#  MERCHANTABILITY, INFORMATIONAL CONTENT, NONINFRINGEMENT, OR ERROR-FREE 
#  OPERATION. CARNEGIE MELLON UNIVERSITY SHALL NOT BE LIABLE FOR INDIRECT, 
#  SPECIAL OR CONSEQUENTIAL DAMAGES, SUCH AS LOSS OF PROFITS OR INABILITY 
#  TO USE SAID INTELLECTUAL PROPERTY, UNDER THIS LICENSE, REGARDLESS OF 
#  WHETHER SUCH PARTY WAS AWARE OF THE POSSIBILITY OF SUCH DAMAGES. 
#  LICENSEE AGREES THAT IT WILL NOT MAKE ANY WARRANTY ON BEHALF OF 
#  CARNEGIE MELLON UNIVERSITY, EXPRESS OR IMPLIED, TO ANY PERSON 
#  CONCERNING THE APPLICATION OF OR THE RESULTS TO BE OBTAINED WITH THE 
#  DELIVERABLES UNDER THIS LICENSE.
#  
#  Licensee hereby agrees to defend, indemnify, and hold harmless Carnegie 
#  Mellon University, its trustees, officers, employees, and agents from 
#  all claims or demands made against them (and any related losses, 
#  expenses, or attorney's fees) arising out of, or relating to Licensee's 
#  and/or its sub licensees' negligent use or willful misuse of or 
#  negligent conduct or willful misconduct regarding the Software, 
#  facilities, or other rights or assistance granted by Carnegie Mellon 
#  University under this License, including, but not limited to, any 
#  claims of product liability, personal injury, death, damage to 
#  property, or violation of any laws or regulations.
#  
#  Carnegie Mellon University Software Engineering Institute authored 
#  documents are sponsored by the U.S. Department of Defense under 
#  Contract FA8721-05-C-0003. Carnegie Mellon University retains 
#  copyrights in all material produced under this contract. The U.S. 
#  Government retains a non-exclusive, royalty-free license to publish or 
#  reproduce these documents, or allow others to do so, for U.S. 
#  Government purposes only pursuant to the copyright license under the 
#  contract clause at 252.227.7013.
#  @OPENSOURCE_HEADER_END@

from __future__ import division
import rayon
from rayon import (
    _hilbert,
    backend_cairo,
    backgrounds,
    charts,
    data,
    datautils,
    miscutils,
    plots,
    raster,
    scales,
)
from rayon.cliutils import (
    mk_tickspec_parser,
    mkcb_input,
    mkcb_ticks,
    mkcb_args,
    mkcb_color,
    mkcb_intfromip
)
import netsa.data.nice

import cairo

from optparse import OptionParser, OptionValueError, SUPPRESS_HELP
import array
import datetime
import os
import math
import pprint
import socket
import struct
import sys

class ConfigurationError(Exception):
    """Raised when the command-line options are invalid or contradictory."""

# Flip to True to get progress/tracing output from debug().
DEBUG = False


def debug(msg):
    """Print *msg* to stdout, but only while the module-level DEBUG is on."""
    if not DEBUG:
        return
    print(msg)

#
# Utility functions
#

def inet_atoi(addr):
    """
    Convert a dotted-quad IPv4 address string to its integer value.

    Parsing is delegated to socket.inet_aton, so the shortened forms it
    accepts (e.g. "127.1") are accepted here as well.
    """
    packed = socket.inet_aton(addr)
    value, = struct.unpack('!I', packed)
    return int(value)

def inet_itoa(num):
    """
    Convert an integer IPv4 address (0 .. 2**32 - 1) to a dotted-quad
    string.
    """
    four_bytes = struct.pack('!I', num)
    return socket.inet_ntoa(four_bytes)

#
# Argument processing
#
# Default value for every command-line option. parse_options() installs
# these with OptionParser.set_defaults() and quotes some of them in the
# --help text.
defaults = dict(
    # -- Basic
    input_path='-',              # '-' means read from STDIN
    output_path=None,
    width=1024,
    height=1024,
    index_input=0,
    val_input=1,
    index_max=None,
    # -- Value floors and ceilings
    val_floor=None,
    # val_floor_pct effectively defaults to 10, but that is applied in
    # parse_options() (and only when no absolute floor is given).
    val_floor_pct=None,
    val_ceiling=None,
    val_ceiling_pct=None,
    # -- Value coloring
    color_floor=miscutils.parse_color_string("#aaaaaaff"),
    color_ceiling=miscutils.parse_color_string("#000000ff"),
    color_num_steps=221,
    color_lo_outliers=None,
    color_hi_outliers=None,
    color_log_scale=False,
    # -- Misc
    background_color=miscutils.parse_color_string("#ffffffff"),
    print_options=False,
    ip_data=True,
    cidr_netmask=20,
    do_overlay=True,
    overlay_file=None,
    binary=False,
    quiet=False,
)



# Largest CIDR netmask accepted (matches the "Max: 24" in --help). A /N
# grid has 2**N cells, and get_data() likewise refuses more than 12 bits
# per dimension (2**24 cells) for non-IP data.
_MAX_CIDR_NETMASK = 24


def _build_option_parser():
    """Create the OptionParser that describes ryhilbert's command line."""
    op = OptionParser(usage="%prog [options] - "
                      "plot points and values along a Hilbert curve")
    op.set_defaults(**defaults)
    op.add_option("--input-path", dest="input_path",
                  help="input file ('-' == STDIN)")
    op.add_option("--output-path", dest="output_path",
                  help="(REQUIRED) output path")
    op.add_option("--width", type="int", dest="width",
                  help="image width (pixels for PNG, otherwise points) "
                  "(Default: %d)" % defaults['width'])
    op.add_option("--height", type="int", dest="height",
                  help="image height (pixels for PNG, otherwise points) "
                  "(Default: %d)" % defaults['height'])
    op.add_option("--index-input", dest="index_input",
                  action="callback", callback=mkcb_input("index_input"),
                  type="str", metavar="NAME_OR_IDX",
                  help="Name or index (from left, starting at 0) of "
                  "column containing index variable (e.g., IP address)")
    op.add_option("--val-input", dest="val_input",
                  action="callback", callback=mkcb_input("val_input"),
                  type="str", metavar="NAME_OR_IDX",
                  help="Name or index (from left, starting at 0) of "
                  "column containing value variable.")
    op.add_option("--index-max", dest="index_max",
                  action="callback", callback=mkcb_intfromip("index_max"),
                  type="str", metavar="INT_OR_IP",
                  help="Largest index to be plotted on the Hilbert curve. "
                  "(Default: 255.255.255.255 for IPs. Next power of 2 "
                  "greater than the largest data index for non-IPs)")
    op.add_option("--floor", dest="val_floor", type="float",
                  help="Smallest value to plot, in absolute terms")
    op.add_option("--floor-pct", dest="val_floor_pct", type="int",
                  help="Smallest value to plot, as a percentile of the data "
                  "(Default: 10. Set to 0 for no floor.)")
    op.add_option("--ceil", dest="val_ceiling", type="float",
                  help="Largest value to plot with normal gradient, "
                  "in absolute terms")
    op.add_option("--ceil-pct", dest="val_ceiling_pct", type="int",
                  help="Largest value to plot with normal gradient, "
                  "as a percentile of the data (Default: no ceiling)")
    op.add_option("--num-colors", dest="color_num_steps", type="int",
                  help="Number of steps in color gradient from floor-color "
                  "to ceil-color (Default: %d)" % defaults['color_num_steps'])
    op.add_option('--floor-color', metavar="COLOR",
                  action="callback", callback=mkcb_color("color_floor"),
                  dest="color_floor", type="str",
                  help="Color of minimum value "
                  "(Default: \"#AAAAAAFF\" (light gray))")
    op.add_option('--ceil-color', metavar="COLOR",
                  action="callback", callback=mkcb_color("color_ceiling"),
                  dest="color_ceiling", type="str",
                  help="Color of maximum value "
                  "(Default: \"#000000FF\" (black))")
    op.add_option('--basement-color', metavar="COLOR",
                  action="callback", callback=mkcb_color("color_lo_outliers"),
                  dest="color_lo_outliers", type="str",
                  help="Color of values below minimum value "
                  "(Default: not plotted)")
    op.add_option('--attic-color', metavar="COLOR",
                  action="callback", callback=mkcb_color("color_hi_outliers"),
                  dest="color_hi_outliers", type="str",
                  help="Color of values above maximum value "
                  "(Default: not plotted)")
    op.add_option('--log-colors', action="store_true", dest="color_log_scale",
                  help="Compute color values along a log scale "
                  "(Default: off)")
    op.add_option('--background-color', metavar="COLOR",
                  action="callback", callback=mkcb_color("background_color"),
                  dest="background_color", type="str",
                  help="Color of background (Default: \"#FFFFFFFF\" (white))")
    op.add_option('--print-options', action="store_true", dest="print_options",
                  help=SUPPRESS_HELP)
    op.add_option('--non-ip-data', action="store_false", dest="ip_data",
                  help=SUPPRESS_HELP)
    op.add_option('--cidr-netmask', dest="cidr_netmask", type="int",
                  help="If input is IPv4 data, size of subnets to display, "
                  "in no. of bits not to mask. Odd values are rounded up. "
                  "(Default: %d, Max: %d)"
                  % (defaults['cidr_netmask'], _MAX_CIDR_NETMASK))
    op.add_option('--no-overlay', action="store_false", dest="do_overlay",
                  help="Don't overlay another image onto the Hilbert plot")
    op.add_option('--overlay-file', dest="overlay_file", metavar="FILE",
                  help="Use FILE as overlay image. (Default: hardcoded map "
                  "of IPv4 space)")
    op.add_option('--binary-plot', action="store_true", dest="binary",
                  help="Plot each value as either on or off, rather "
                  "than as a count, using --ceil-color for point colors "
                  "(Default: no, unless input data contains only one column)")
    op.add_option('--quiet', action="store_true", dest="quiet",
                  help="Suppress warnings")
    return op


def _validate_options(options):
    """
    Check parsed *options* for errors and fill in derived defaults in place.

    Raises ConfigurationError on invalid or contradictory settings.
    """
    # -- Output. --print-options only dumps the options, so it needs none.
    if options.output_path is None and not options.print_options:
        raise ConfigurationError("--output-path is required")

    # -- Image size (the grid resolution is later derived from log2(width))
    if options.width <= 0 or options.height <= 0:
        raise ConfigurationError("--width and --height must be positive")

    # -- Subnet mask
    if not 1 <= options.cidr_netmask <= _MAX_CIDR_NETMASK:
        raise ConfigurationError(
            "invalid netmask (%d). Valid values are 1-%d."
            % (options.cidr_netmask, _MAX_CIDR_NETMASK))

    # The mask's bits are split evenly between the X and Y axes, so odd
    # masks (possible, but odd-looking) are rounded up to the next even
    # value. --quiet suppresses only the warning, never the adjustment.
    if options.cidr_netmask % 2 != 0:
        options.cidr_netmask += 1
        if not options.quiet:
            sys.stderr.write(
                "Warning: odd-numbered subnet masks are not supported at "
                "this time. Generating a Hilbert curve for %d instead\n"
                % options.cidr_netmask)

    # -- Value scales
    if options.val_floor is not None and options.val_floor_pct is not None:
        raise ConfigurationError(
            "Ambiguous input: specify --floor or --floor-pct, not both")
    if options.val_ceiling is not None and options.val_ceiling_pct is not None:
        raise ConfigurationError(
            "Ambiguous input: specify --ceil or --ceil-pct, not both")

    # If a floor is not specified, impose a 10th-percentile floor. This is
    # generally useful for reducing noise.
    if options.val_floor is None and options.val_floor_pct is None:
        options.val_floor_pct = 10

    for attr, optname in (("val_floor_pct", "floor-pct"),
                          ("val_ceiling_pct", "ceil-pct")):
        pct = getattr(options, attr)
        if pct is not None and (pct < 0 or pct > 100 or int(pct) != pct):
            raise ConfigurationError(
                "--%s must be an integer between 0 and 100" % optname)

    # -- Default colors: outliers are drawn in the background color (i.e.
    # not visibly plotted) unless the user chose a color for them.
    if options.color_lo_outliers is None:
        options.color_lo_outliers = options.background_color
    if options.color_hi_outliers is None:
        options.color_hi_outliers = options.background_color

    # -- Overlay
    if options.overlay_file is None:
        options.overlay_file = rayon.data_file("hilbert_overlay")


def parse_options(argv=None):
    """
    Parse and validate the command line.

    argv: argument strings to parse; None (the default) makes OptionParser
    read sys.argv[1:].

    Returns the optparse Values with defaults applied and derived settings
    resolved (default floor percentile, outlier colors, overlay file, even
    CIDR netmask).

    Raises ConfigurationError on invalid, contradictory or unexpected
    arguments.
    """
    op = _build_option_parser()
    options, args = op.parse_args(argv)
    # Positional arguments are not used; silently ignoring them would make
    # e.g. "ryhilbert data.txt ..." sit waiting on STDIN.
    if args:
        raise ConfigurationError(
            "unexpected argument(s): %s" % " ".join(args))
    _validate_options(options)
    return options


#
# Debugging
#

def pprint_opts(opts):
    """Pretty-print every attribute of the parsed options to stdout."""
    pprint.pprint(vars(opts))


#
# Building the visualization
#

def get_context(opts):
    """
    Create the cairo surface and rayon drawing context for the output file.

    The backend is chosen from the extension of opts.output_path:
      ".svg" -> cairo.SVGSurface writing to the path,
      ".pdf" -> cairo.PDFSurface writing to the path,
      ".png" -> in-memory ARGB32 cairo.ImageSurface (the caller must save it
                with write_to_png).
    Sizes come from opts.width / opts.height.

    Returns (surface, backend_cairo.Context).

    Raises ConfigurationError if no output path was given or its extension
    is not one of the supported formats.
    """
    path = opts.output_path
    if not path:
        raise ConfigurationError("--output-path is required")

    if path.endswith(".svg"):
        surface = cairo.SVGSurface(path, opts.width, opts.height)
        raster_output = False
    elif path.endswith(".pdf"):
        surface = cairo.PDFSurface(path, opts.width, opts.height)
        raster_output = False
    elif path.endswith(".png"):
        surface = cairo.ImageSurface(cairo.FORMAT_ARGB32,
                                     opts.width, opts.height)
        raster_output = True
    else:
        # ConfigurationError (an Exception subclass) so the user gets the
        # friendly "Configuration error" message instead of a traceback.
        raise ConfigurationError(
            "Unknown image format for '%s' (expected .svg, .pdf or .png)"
            % path)

    return surface, backend_cairo.Context(cairo.Context(surface),
                                          opts.width, opts.height,
                                          raster_output)
                                          

# NOTE(review): this flag is not read or written anywhere else in this
# file; it appears to be left over from an earlier CIDR-warning scheme.
# Confirm nothing imports it before removing.
printed_cidr_warning = False


def _read_dataset(opts):
    """Load the input Dataset, widening single-column input to (index, 1)."""
    debug("Reading data")
    if opts.input_path == "-":
        input_data = data.Dataset.from_stream(sys.stdin)
    else:
        input_data = data.Dataset.from_file(opts.input_path)

    if input_data.numcols() == 1:
        # Fake a 2-column dataset where every index has value 1. Only
        # presence is meaningful, so plot in binary (on/off) mode.
        index_col = input_data.column(opts.index_input)
        val_col = data.ConstantColumn(1, len(index_col))
        input_data = data.Dataset((index_col, val_col))
        opts.index_input = 0
        opts.val_input = 1
        opts.binary = True
    return input_data


def _bits_per_dimension(input_data, opts, num_dimensions):
    """Choose how many Hilbert-curve bits each grid dimension uses."""
    if opts.ip_data:
        # The netmask's bits are split evenly between the dimensions.
        return int(opts.cidr_netmask / num_dimensions)

    if opts.index_max is None:
        index_max = input_data.column(opts.index_input).max()
    else:
        index_max = opts.index_max
    # Enough bits to address every index in 0..index_max (at least one),
    # shared between the dimensions and rounded *up* so that no index
    # falls off the end of the curve.
    total_bits = max(1, int(index_max).bit_length())
    num_bits = -(-total_bits // num_dimensions)  # ceiling division
    # Don't even try to plot more than 12 bits per dimension.
    if num_bits > 12:
        raise ConfigurationError("Max index must be less than 2**24")
    return num_bits


def _fill_grid(input_data, opts, num_bits, num_dimensions):
    """Sum each row's value into its Hilbert cell on a 2**num_bits grid."""
    grid_width = 2 ** num_bits
    debug("Allocating initial array (2**%d per side)" % num_bits)
    # Repeating a one-element array zero-fills far faster than a generator.
    grid_data = array.array("L", [0]) * (grid_width ** num_dimensions)

    debug("Inserting data into grid")
    index_col = opts.index_input
    val_col = opts.val_input
    capped = False
    for row in input_data:
        # Note that if the user passes in CIDR data, it will be an
        # error (It used to be a warning and ignored.)
        index = int(row[index_col])

        # !! IMPORTANT !!  The "natural" way to do the Hilbert
        # algorithm (which i2c uses) is different from the way we use
        # it. (Specifically, the X and Y coordinates are swapped.)
        # That's because we wanted to make something that would
        # correspond to -- uh, this cartoon: http://xkcd.com/195/
        y, x = _hilbert.i2c(num_bits, num_dimensions, index)
        offset = y * grid_width + x
        try:
            grid_data[offset] += row[val_col]
        except OverflowError:
            # NOTE(review): a negative running total would also land here
            # and be capped at LONG_MAX; values are presumably counts >= 0.
            capped = True
            grid_data[offset] = sys.maxint

    if capped:
        sys.stderr.write("WARNING: some bins have been capped at LONG_MAX\n")
    return grid_data, grid_width


def _rebin_grid(grid_data, num_bits, width):
    """Resample a 2**num_bits square grid to the power of 2 nearest width."""
    grid_width = 2 ** num_bits
    target_bits = int(round(math.log(width, 2)))
    if target_bits == num_bits:
        return grid_data, grid_width

    target_width = 2 ** target_bits
    debug("Rebinning into %d per side" % target_bits)
    # Double precision: summed counts quickly exceed the 24-bit mantissa of
    # a single-precision float.
    target_data = array.array("d", [0.0]) * (target_width * target_width)

    if target_bits > num_bits:
        debug("Target has more bins than grid")
        # Each source cell covers a factor x factor block of target cells.
        factor = 2 ** (target_bits - num_bits)
        for index, val in enumerate(grid_data):
            if not val:
                continue  # spreading zero changes nothing
            y_index, x_index = divmod(index, grid_width)
            base = (y_index * factor) * target_width + x_index * factor
            # NOTE(review): the value is divided by the per-side factor, not
            # by the factor**2 cells it is spread over, so the grid total
            # grows by `factor`. Kept as-is; confirm this is intended.
            share = val / factor
            for dy in range(factor):
                row_start = base + dy * target_width
                for dx in range(factor):
                    target_data[row_start + dx] += share
    else:
        debug("Target has fewer bins than grid")
        # Each target cell absorbs a 2**shift x 2**shift block of sources.
        shift = num_bits - target_bits
        for index, val in enumerate(grid_data):
            if not val:
                continue
            y_index, x_index = divmod(index, grid_width)
            target_data[(y_index >> shift) * target_width
                        + (x_index >> shift)] += val

    return target_data, target_width


def get_data(opts):
    """
    Read the input data and accumulate it onto a square Hilbert-curve grid.

    Every input row adds its value (column opts.val_input) to the grid cell
    that its index (column opts.index_input, converted with int()) maps to
    on the Hilbert curve. The grid is then rebinned to the power of two
    closest to opts.width.

    Side effects on opts:
      - single-column input is treated as a list of indexes that each have
        value 1; opts.index_input / opts.val_input become 0 / 1 and
        opts.binary is switched on;
      - if neither a ceiling nor a ceiling percentile was given,
        opts.val_ceiling is set to the largest cell of the final grid.

    Returns (grid_data, grid_width): a flat, row-major array.array of
    grid_width * grid_width cell values and the grid's side length.

    Raises ConfigurationError if non-IP indexes need more than 12 bits per
    dimension.
    """
    # In practice this is never anything but 2.
    num_dimensions = 2

    input_data = _read_dataset(opts)
    num_bits = _bits_per_dimension(input_data, opts, num_dimensions)
    grid_data, grid_width = _fill_grid(input_data, opts, num_bits,
                                       num_dimensions)
    grid_data, grid_width = _rebin_grid(grid_data, num_bits, opts.width)

    # Derive the default ceiling from the aggregated cells, not the raw
    # rows: many rows sum into one cell, and cells above the ceiling are
    # drawn in the hi-outlier color (the background by default), so a
    # row-based ceiling would hide the busiest cells.
    if opts.val_ceiling is None and opts.val_ceiling_pct is None:
        opts.val_ceiling = max(grid_data)

    return grid_data, grid_width


def colorize_data(grid_data, grid_width, opts):
    """
    Turn a square grid of cell values into a bitmap, one pixel per cell.

    Binary mode (opts.binary): nonzero cells get opts.color_ceiling, zero
    cells opts.background_color.

    Otherwise:
      - zero cells get the background color;
      - cells below opts.val_floor get opts.color_lo_outliers;
      - cells above opts.val_ceiling get opts.color_hi_outliers;
      - everything else is mapped through a gradient from opts.color_floor
        to opts.color_ceiling in opts.color_num_steps steps.
    The gradient is linear, or a counting-log scale with
    opts.color_log_scale.

    Any floor/ceiling percentile still pending on opts is resolved here
    against the nonzero cells. The absolute value is stored back on opts
    and the percentile cleared.

    Returns a raster.Bitmap of grid_width x grid_width pixels.
    """
    pending_pct = (opts.val_floor_pct is not None or
                   opts.val_ceiling_pct is not None)
    if pending_pct and not opts.binary:
        stats = datautils.Statistics([v for v in grid_data if v != 0])
        for pct_attr, val_attr in (("val_floor_pct", "val_floor"),
                                   ("val_ceiling_pct", "val_ceiling")):
            pct = getattr(opts, pct_attr)
            if pct is not None:
                setattr(opts, val_attr, stats.percentile(pct))
                setattr(opts, pct_attr, None)

    background = opts.background_color
    if opts.binary:
        on_color = opts.color_ceiling

        def colorize(val):
            return on_color if val else background
    else:
        floor, ceiling = opts.val_floor, opts.val_ceiling
        lo_color = opts.color_lo_outliers
        hi_color = opts.color_hi_outliers
        gradient_cls = (scales.CountingLogGradientScale
                        if opts.color_log_scale
                        else scales.LinearGradientScale)
        colorscale = gradient_cls(floor, ceiling,
                                  opts.color_floor, opts.color_ceiling,
                                  opts.color_num_steps)

        def colorize(val):
            if val == 0:
                return background
            if floor is not None and val < floor:
                return lo_color
            if ceiling is not None and val > ceiling:
                return hi_color
            return colorscale(val)

    bitmap = raster.Bitmap(grid_width, grid_width, fmt=raster.ARGB32_PM)
    for pixel_index, cell in enumerate(grid_data):
        bitmap.set_pixel(pixel_index, colorize(cell), fmt=raster.RGBA32)
    return bitmap
                

def get_chart(ctx, backend, opts):
    """
    Assemble the square chart that displays the Hilbert image.

    The chart gets no Plot of its own. The picture is entirely its plot
    background, a composite of:
      1. the colorized grid from get_data()/colorize_data(), and
      2. if opts.do_overlay, the PNG at opts.overlay_file.
    Both layers use fill="scale".

    ctx is accepted for interface compatibility but is not used here.
    Returns the charts.SquareChart.
    """
    chart = charts.SquareChart(backend)

    # grid_data is a flat array of cell values, not a rayon Dataset.
    grid_data, grid_width = get_data(opts)
    hilbert_image = colorize_data(grid_data, grid_width, opts)

    layers = [backgrounds.BitmapBackground(hilbert_image, fill="scale")]
    if opts.do_overlay:
        layers.append(backgrounds.PNGBackground(opts.overlay_file,
                                                fill="scale"))

    plot_background = backgrounds.CompositeBackground()
    for layer in layers:
        plot_background.add(layer)
    chart.set_plot_background(plot_background)
    return chart


def main():
    """
    Program entry point: parse options, render the chart, write the file.

    With --print-options, only dump the parsed options and return.
    """
    opts = parse_options()
    if opts.print_options:
        pprint_opts(opts)
        return

    surface, ctx = get_context(opts)
    chart = get_chart(ctx, backend_cairo, opts)
    chart(ctx)

    if opts.output_path.endswith(".png"):
        # Image surfaces live in memory and must be saved explicitly.
        surface.write_to_png(opts.output_path)
    # Finish the surface so SVG/PDF output is flushed to disk now rather
    # than whenever the surface object happens to be garbage-collected.
    surface.finish()

                                        




if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("Application interrupted by user\n")
        sys.exit(1)
    except ConfigurationError as e:
        sys.stderr.write("Configuration error: %s\n" % str(e))
        sys.stderr.write(
            "Type 'ryhilbert --help' for more information on arguments\n")
        sys.exit(1)
    except Exception as e:
        # Set RY_PDB in the environment to get a post-mortem debugger
        # session on unexpected errors; the exception is re-raised either way.
        if os.getenv("RY_PDB") is not None:
            sys.stderr.write("%s\n" % str(e))
            import pdb
            pdb.post_mortem(sys.exc_info()[2])
        raise
