#!/usr/bin/python3

"""Python script to plot tally data generated by OpenMC."""

import os
import sys
import argparse

import six.moves.tkinter as tk
import six.moves.tkinter_filedialog as filedialog
import six.moves.tkinter_font as font
import six.moves.tkinter_messagebox as messagebox
import six.moves.tkinter_ttk as ttk
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.backends.backend_tkagg import NavigationToolbar2TkAgg
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np

from openmc import StatePoint, MeshFilter


class MeshPlotter(tk.Frame):
    """Tk frame that plots 2D slices of OpenMC mesh tallies.

    The frame embeds a matplotlib figure with a navigation toolbar and a
    column of comboboxes used to choose the tally, the slice plane (basis),
    the level along the remaining axis, the plotted quantity, the score, and
    one bin for every non-mesh filter of the selected tally.

    Parameters
    ----------
    parent : tkinter widget
        Parent widget the frame is created in.
    filename : str
        Path of the statepoint file to read tallies from.

    Attributes
    ----------
    labels : dict
        Maps a filter's ``short_name`` to the label text shown next to its
        combobox.
    filterBoxes : dict
        Maps a filter's ``short_name`` to the combobox currently displayed
        for that filter.
    datafile : openmc.StatePoint
        Statepoint object the tallies are read from.
    meshTallies : list
        Keys of ``datafile.tallies`` whose tally has a mesh filter.

    """

    def __init__(self, parent, filename):
        tk.Frame.__init__(self, parent)

        self.labels = {'Cell': 'Cell:', 'Cellborn': 'Cell born:',
                       'Surface': 'Surface:', 'Material': 'Material:',
                       'Universe': 'Universe:', 'Energy': 'Energy in:',
                       'Energyout': 'Energy out:'}

        self.filterBoxes = {}

        # Read tally data from the statepoint file (exits if the file has no
        # mesh tallies)
        self.get_file_data(filename)

        # Set up top-level window
        top = self.winfo_toplevel()
        top.title('Mesh Tally Plotter: ' + filename)
        top.rowconfigure(0, weight=1)
        top.columnconfigure(0, weight=1)
        self.grid(sticky=tk.W+tk.N)

        # Create widgets and draw to screen
        self.create_widgets()
        self.update()

    def create_widgets(self):
        """Create the figure canvas, toolbar and the fixed comboboxes.

        Filter comboboxes are not created here; they depend on the selected
        tally and are (re)built by :meth:`update`.

        """
        figureFrame = tk.Frame(self)
        figureFrame.grid(row=0, column=0)

        # Create the Figure and Canvas
        self.dpi = 100
        self.fig = Figure((5.0, 5.0), dpi=self.dpi)
        self.canvas = FigureCanvasTkAgg(self.fig, master=figureFrame)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

        # Create the navigation toolbar, tied to the canvas
        self.mpl_toolbar = NavigationToolbar2TkAgg(self.canvas, figureFrame)
        self.mpl_toolbar.update()
        # Re-pack the canvas widget after the toolbar has been created; the
        # public accessor is used instead of the private ``_tkcanvas``
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)

        # Create frame for comboboxes
        self.selectFrame = tk.Frame(self)
        self.selectFrame.grid(row=1, column=0, sticky=tk.W+tk.E)

        # Tally selection
        labelTally = tk.Label(self.selectFrame, text='Tally:')
        labelTally.grid(row=0, column=0, sticky=tk.W)
        self.tallyBox = ttk.Combobox(self.selectFrame, state='readonly')
        self.tallyBox['values'] = [self.datafile.tallies[i].id
                                   for i in self.meshTallies]
        self.tallyBox.current(0)
        self.tallyBox.grid(row=0, column=1, sticky=tk.W+tk.E)
        self.tallyBox.bind('<<ComboboxSelected>>', self.update)

        # Planar basis selection
        labelBasis = tk.Label(self.selectFrame, text='Basis:')
        labelBasis.grid(row=1, column=0, sticky=tk.W)
        self.basisBox = ttk.Combobox(self.selectFrame, state='readonly')
        self.basisBox['values'] = ('xy', 'yz', 'xz')
        self.basisBox.current(0)
        self.basisBox.grid(row=1, column=1, sticky=tk.W+tk.E)
        self.basisBox.bind('<<ComboboxSelected>>', self.update)

        # Axial level
        labelAxial = tk.Label(self.selectFrame, text='Axial level:')
        labelAxial.grid(row=2, column=0, sticky=tk.W)
        self.axialBox = ttk.Combobox(self.selectFrame, state='readonly')
        self.axialBox.grid(row=2, column=1, sticky=tk.W+tk.E)
        self.axialBox.bind('<<ComboboxSelected>>', self.redraw)

        # Option for mean/uncertainty
        labelMean = tk.Label(self.selectFrame, text='Mean/Uncertainty:')
        labelMean.grid(row=3, column=0, sticky=tk.W)
        self.meanBox = ttk.Combobox(self.selectFrame, state='readonly')
        self.meanBox['values'] = ('Mean', 'Absolute uncertainty',
                                  'Relative uncertainty')
        self.meanBox.current(0)
        self.meanBox.grid(row=3, column=1, sticky=tk.W+tk.E)
        # Only the plotted quantity changes, so a redraw is sufficient;
        # update() would reset the axial level, score and filter selections
        self.meanBox.bind('<<ComboboxSelected>>', self.redraw)

        # Scores
        labelScore = tk.Label(self.selectFrame, text='Score:')
        labelScore.grid(row=4, column=0, sticky=tk.W)
        self.scoreBox = ttk.Combobox(self.selectFrame, state='readonly')
        self.scoreBox.grid(row=4, column=1, sticky=tk.W+tk.E)
        self.scoreBox.bind('<<ComboboxSelected>>', self.redraw)

        # Filter label
        boldfont = font.Font(weight='bold')
        labelFilters = tk.Label(self.selectFrame, text='Filters:',
                                font=boldfont)
        labelFilters.grid(row=5, column=0, sticky=tk.W)

    def update(self, event=None):
        """Refresh the selection widgets for the chosen tally and redraw.

        NOTE: this method shadows the ``update()`` method that tkinter
        widgets provide; the name is kept because the combobox bindings and
        ``__init__`` call it.

        Parameters
        ----------
        event : tkinter event or None
            Event that triggered the call, or None when called directly.
            When the event comes from the basis combobox only the axial
            levels are repopulated; otherwise the scores and the filter
            comboboxes are rebuilt as well.

        """
        widget = None if event is None else event.widget

        tally_id = self.meshTallies[self.tallyBox.current()]
        selectedTally = self.datafile.tallies[tally_id]

        # Get mesh for selected tally
        self.mesh = selectedTally.find_filter(MeshFilter).mesh

        # Get mesh dimensions; a 2D mesh is treated as a single z level
        if len(self.mesh.dimension) == 2:
            self.nx, self.ny = self.mesh.dimension
            self.nz = 1
        else:
            self.nx, self.ny, self.nz = self.mesh.dimension

        # Repopulate the axial levels based on current basis selection: the
        # levels run along the axis that is not part of the plotted plane
        basis = self.basisBox.get()
        if basis == 'xy':
            n_levels = self.nz
        elif basis == 'yz':
            n_levels = self.nx
        else:
            n_levels = self.ny
        self.axialBox['values'] = [str(i + 1) for i in range(n_levels)]
        self.axialBox.current(0)

        # If update() was called by a change in the basis combobox, we don't
        # need to repopulate the filters
        if widget is self.basisBox:
            self.redraw()
            return

        # Update scores
        self.scoreBox['values'] = selectedTally.scores
        self.scoreBox.current(0)

        # Remove any filter labels/comboboxes that exist and forget the
        # references to the destroyed comboboxes
        for row in range(6, self.selectFrame.grid_size()[1]):
            for w in self.selectFrame.grid_slaves(row=row):
                w.grid_forget()
                w.destroy()
        self.filterBoxes.clear()

        # Create a label/combobox for each non-mesh filter in selected tally
        count = 0
        for f in selectedTally.filters:
            filterType = f.short_name
            if filterType == 'Mesh':
                continue
            count += 1

            # Filter types without an entry in self.labels fall back to
            # their short name instead of raising KeyError
            label_text = self.labels.get(filterType, filterType + ':')
            label = tk.Label(self.selectFrame, text=label_text)
            label.grid(row=count+6, column=0, sticky=tk.W)
            combobox = ttk.Combobox(self.selectFrame, state='readonly')
            self.filterBoxes[filterType] = combobox

            # Set combobox items; energy bins are shown as "low to high"
            # ranges built from consecutive entries of f.bins
            if filterType in ('Energy', 'Energyout'):
                combobox['values'] = ['{0} to {1}'.format(*f.bins[i:i+2])
                                      for i in range(len(f.bins) - 1)]
            else:
                combobox['values'] = [str(i) for i in f.bins]

            combobox.current(0)
            combobox.grid(row=count+6, column=1, sticky=tk.W+tk.E)
            combobox.bind('<<ComboboxSelected>>', self.redraw)

        # If there are no filters, leave a 'None Available' message
        if count == 0:
            count += 1
            label = tk.Label(self.selectFrame, text="None Available")
            label.grid(row=count+6, column=0, sticky=tk.W)

        self.redraw()

    @staticmethod
    def _scalar(values):
        """Return the first element of an array-like as a Python float."""
        return float(np.ravel(values)[0])

    def redraw(self, event=None):
        """Collect the selected slice of tally data and plot it.

        Parameters
        ----------
        event : tkinter event or None
            Ignored; present so the method can be used as an event handler.

        """
        plane = self.basisBox.get()
        axial_level = self.axialBox.current() + 1
        mbvalue = self.meanBox.get()
        score = self.scoreBox.get()

        # Get selected tally
        tally_id = self.meshTallies[self.tallyBox.current()]
        selectedTally = self.datafile.tallies[tally_id]

        # Build the (filter type, bins) specification for every non-mesh
        # filter from the current combobox selections
        spec_list = []
        mesh_filter = None
        for f in selectedTally.filters:
            if f.short_name == 'Mesh':
                mesh_filter = f
                continue
            index = self.filterBoxes[f.short_name].current()
            if f.short_name in ('Energy', 'Energyout'):
                selected_bin = (f.bins[index], f.bins[index + 1])
            else:
                # The combobox lists f.bins, so translate the selected
                # position back to the bin value rather than passing the
                # position itself
                selected_bin = f.bins[index]
            spec_list.append((type(f), (selected_bin,)))

        # These parts of the get_values() arguments do not depend on the
        # mesh cell, so build them once
        filters = tuple(ftype for ftype, _ in spec_list) + \
            (type(mesh_filter),)
        fixed_bins = tuple(bins for _, bins in spec_list)

        # Only fetch the quantities that are actually plotted
        need_mean = mbvalue != 'Absolute uncertainty'
        need_stdev = mbvalue != 'Mean'

        if plane == 'xy':
            dims = (self.nx, self.ny)
        elif plane == 'yz':
            dims = (self.ny, self.nz)
        else:
            dims = (self.nx, self.nz)
        matrix = np.zeros(dims)
        for i in range(dims[0]):
            for j in range(dims[1]):
                # Mesh indices are 1-based (x, y, z) tuples
                if plane == 'xy':
                    meshtuple = (i + 1, j + 1, axial_level)
                elif plane == 'yz':
                    meshtuple = (axial_level, i + 1, j + 1)
                else:
                    meshtuple = (i + 1, axial_level, j + 1)
                filter_bins = fixed_bins + ((meshtuple,),)

                mean = stdev = 0.
                if need_mean:
                    mean = self._scalar(selectedTally.get_values(
                        [score], filters, filter_bins))
                if need_stdev:
                    stdev = self._scalar(selectedTally.get_values(
                        [score], filters, filter_bins, value='std_dev'))

                if mbvalue == 'Mean':
                    matrix[i, j] = mean
                elif mbvalue == 'Absolute uncertainty':
                    matrix[i, j] = stdev
                elif mean > 0.:
                    matrix[i, j] = stdev/mean
                else:
                    # Relative uncertainty is reported as zero where the
                    # mean is not positive
                    matrix[i, j] = 0.

        # Clear the figure
        self.fig.clear()

        # Make figure, set up color bar
        self.axes = self.fig.add_subplot(111)
        cax = self.axes.imshow(matrix.transpose(), vmin=0.0, vmax=matrix.max(),
                               interpolation='none', origin='lower')
        self.fig.colorbar(cax)

        self.axes.set_xticks([])
        self.axes.set_yticks([])
        self.axes.set_aspect('equal')

        # Draw canvas
        self.canvas.draw()

    def get_file_data(self, filename):
        """Open the statepoint file and record which tallies use a mesh.

        Shows an error dialog and exits with status 1 if the file contains
        no tally with a mesh filter.

        Parameters
        ----------
        filename : str
            Path of the statepoint file.

        """
        # Create StatePoint object and read in data
        self.datafile = StatePoint(filename)

        # Find which tallies are mesh tallies
        self.meshTallies = [
            itally for itally, tally in self.datafile.tallies.items()
            if any(isinstance(f, MeshFilter) for f in tally.filters)]

        if not self.meshTallies:
            messagebox.showerror("Invalid StatePoint File",
                                 "File does not contain mesh tallies!")
            sys.exit(1)


if __name__ == '__main__':
    # The statepoint path is an optional positional argument
    cli = argparse.ArgumentParser()
    cli.add_argument('statepoint', nargs='?', help='Statepoint file')
    args = cli.parse_args()

    # Keep the root window hidden until there is something to show
    root = tk.Tk()
    root.withdraw()

    # Fall back to a file dialog when no path was given on the command line
    filename = args.statepoint
    if filename is None:
        filename = filedialog.askopenfilename(initialdir='.',
                                              title='Select statepoint file')

    # An empty selection (e.g. a cancelled dialog) ends the script quietly
    if filename:
        # Refuse anything that is not an existing regular file
        file_exists = os.path.isfile(filename)
        if not file_exists:
            messagebox.showerror("File not found",
                                 "Could not find regular file: " + filename)
            sys.exit(1)

        # Build the plotter, reveal the window and hand control to Tk
        app = MeshPlotter(root, filename)
        root.deiconify()
        root.mainloop()
