#!/usr/bin/python3
"""Update OpenMC's deprecated multi-group cross section XML files to the latest
HDF5-based format.

"""

from __future__ import print_function
import os
import warnings
import xml.etree.ElementTree as ET

import argparse
import numpy as np

import openmc.mgxs_library

description = """\
Update OpenMC's deprecated multi-group cross section XML files to the latest
HDF5-based format."""


def parse_args():
    """Read the input and output file names from the command line.

    Returns
    -------
    dict
        ``'input'`` holds the readable file object argparse opened for the
        XML file; ``'output'`` holds the path of the HDF5 file to write.
        When no output is given, it is derived from the input name by
        replacing a ``.xml`` extension with ``.h5`` (or appending ``.h5``
        when the input has some other extension).

    """
    # Create argument parser
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.RawTextHelpFormatter)
    # The input is mandatory: without it there is nothing to convert, and the
    # default output name below is derived from it.
    parser.add_argument('-i', '--input', type=argparse.FileType('r'),
                        required=True, help='input XML file')
    parser.add_argument('-o', '--output', nargs='?', default='',
                        help='output file, in HDF5 format')
    args = vars(parser.parse_args())

    # A bare ``-o`` yields None (nargs='?'), an omitted ``-o`` yields ''; both
    # mean "pick a default".
    if not args['output']:
        filename = args['input'].name
        # splitext returns a (root, extension) pair, not just the extension.
        root, extension = os.path.splitext(filename)
        if extension.lower() == '.xml':
            args['output'] = root + '.h5'
        else:
            # Never reuse the input name itself, otherwise the export would
            # be written over the file being converted.
            args['output'] = filename + '.h5'

    # Return the parsed commandline arguments.
    return args


def get_data(element, entry):
    """Return the stripped text stored under *entry* on *element*.

    A child element named *entry* takes precedence; if there is none, the
    attribute of that name is used instead.

    Parameters
    ----------
    element : xml.etree.ElementTree.Element
        Element to search.
    entry : str
        Name of the child element or attribute to read.

    Returns
    -------
    str or None
        The value with surrounding whitespace removed, an empty string for
        a child element that has no text, or None if *entry* is present
        neither as a child element nor as an attribute.

    """
    child = element.find(entry)
    if child is not None:
        # An empty child such as ``<total/>`` has ``text`` set to None;
        # treat it as an empty value instead of failing on ``None.strip()``.
        return (child.text or '').strip()

    value = element.attrib.get(entry)
    if value is None:
        return None
    return value.strip()


def _to_float_array(text):
    """Convert whitespace-separated numbers in *text* to a 1-D float array."""
    # The builtin ``float`` is used as the dtype because the ``np.float``
    # alias was removed from NumPy.
    return np.array(text.split(), dtype=float)


def _get_required(element, entry, name):
    """Return ``get_data(element, entry)``; raise ValueError if it is absent."""
    value = get_data(element, entry)
    if value is None:
        raise ValueError("xsdata '{}' is missing the required '{}' "
                         "entry".format(name, entry))
    return value


def main():
    """Convert the XML library named on the command line to an HDF5 file.

    Reads the group structure and the optional inverse velocities from the
    top level of the XML file, then walks every ``xsdata`` element. Elements
    that share a name are merged into a single ``openmc.XSdata`` object with
    one temperature per element. The collected objects are finally written
    through ``openmc.MGXSLibrary.export_to_hdf5``.

    Raises
    ------
    ValueError
        If an ``xsdata`` element lacks one of the entries this script cannot
        work without (``order``, ``absorption``, ``scatter`` and, for the
        ``angle`` representation, ``num_azimuthal`` / ``num_polar``).

    """
    args = parse_args()

    # Parse the XML data; the handle opened by argparse is closed once read.
    with args['input'] as xml_file:
        tree = ET.parse(xml_file)
    root = tree.getroot()

    # Get old metadata
    group_structure = _to_float_array(tree.find('group_structure').text)
    # Convert from MeV to eV
    group_structure *= 1.e6
    energy_groups = openmc.mgxs.EnergyGroups(group_structure)
    num_groups = energy_groups.num_groups

    inverse_velocity = None
    inv_vel_elem = tree.find('inverse-velocity')
    if inv_vel_elem is not None:
        inverse_velocity = _to_float_array(inv_vel_elem.text)

    # XSdata objects in order of first appearance, plus a lookup by name so
    # further temperatures of an already-seen name land on the same object.
    xsdatas = []
    xsdata_by_name = {}

    # Now move on to the cross section data itself
    for xsdata_elem in root.iter('xsdata'):
        name = get_data(xsdata_elem, 'name')

        kT = get_data(xsdata_elem, 'kT')
        if kT is not None:
            # NOTE(review): presumably kT is in MeV and K_BOLTZMANN in eV/K,
            # hence the 1e6 factor -- confirm against openmc.data.
            temperature = float(kT) / openmc.data.K_BOLTZMANN * 1.E6
        else:
            temperature = 294.

        awr = get_data(xsdata_elem, 'awr')
        if awr is not None:
            awr = float(awr)

        representation = get_data(xsdata_elem, 'representation')
        if representation is None:
            representation = 'isotropic'
        is_angle = representation == 'angle'
        if is_angle:
            n_azi = int(_get_required(xsdata_elem, 'num_azimuthal', name))
            n_pol = int(_get_required(xsdata_elem, 'num_polar', name))

        scatter_format = get_data(xsdata_elem, 'scatt_type')
        if scatter_format is None:
            scatter_format = 'legendre'

        order = int(_get_required(xsdata_elem, 'order', name))

        if get_data(xsdata_elem, 'tabular_legendre') is not None:
            warnings.warn('The tabular_legendre option has moved to the '
                          'settings.xml file and must be added manually')

        # Either add the data to a previously existing xsdata (if it is
        # for the same 'name' but a different temperature), or create a
        # new one.
        xs = xsdata_by_name.get(name)
        if xs is None:
            xs = openmc.XSdata(name, energy_groups,
                               temperatures=[temperature],
                               representation=representation)
            if awr is not None:
                xs.atomic_weight_ratio = awr
            if is_angle:
                xs.num_azimuthal = n_azi
                xs.num_polar = n_pol
            xs.scatter_format = scatter_format
            xs.order = order
            xsdatas.append(xs)
            xsdata_by_name[name] = xs
        else:
            xs.add_temperature(temperature)

        # A Legendre expansion of a given order carries order + 1 entries;
        # the other scattering formats carry exactly ``order`` entries.
        if scatter_format == 'legendre':
            order_dim = order + 1
        else:
            order_dim = order

        text = get_data(xsdata_elem, 'total')
        if text is not None:
            total = _to_float_array(text).reshape(xs.xs_shapes['[G]'])
            xs.set_total(total, temperature)

        if inverse_velocity is not None:
            xs.set_inverse_velocity(inverse_velocity, temperature)

        text = _get_required(xsdata_elem, 'absorption', name)
        absorption = _to_float_array(text).reshape(xs.xs_shapes['[G]'])
        xs.set_absorption(absorption, temperature)

        text = _get_required(xsdata_elem, 'scatter', name)
        # This is a flattened array of something that started with a shape
        # of [Order][G][G'] (preceded by [polar][azimuthal] for the angle
        # representation); unflatten it and move the order axis to the end.
        in_shape = (order_dim, num_groups, num_groups)
        if is_angle:
            in_shape = (n_pol, n_azi) + in_shape
        scatter = _to_float_array(text).reshape(in_shape)
        scatter = np.moveaxis(scatter, -3, -1)
        xs.set_scatter_matrix(scatter, temperature)

        text = get_data(xsdata_elem, 'multiplicity')
        if text is not None:
            multiplicity = \
                _to_float_array(text).reshape(xs.xs_shapes["[G][G']"])
            xs.set_multiplicity_matrix(multiplicity, temperature)

        text = get_data(xsdata_elem, 'fission')
        if text is not None:
            fission = _to_float_array(text).reshape(xs.xs_shapes['[G]'])
            xs.set_fission(fission, temperature)

        text = get_data(xsdata_elem, 'kappa_fission')
        if text is not None:
            kappa_fission = \
                _to_float_array(text).reshape(xs.xs_shapes['[G]'])
            xs.set_kappa_fission(kappa_fission, temperature)

        text = get_data(xsdata_elem, 'chi')
        has_chi = text is not None
        if has_chi:
            chi = _to_float_array(text).reshape(xs.xs_shapes['[G]'])
            xs.set_chi(chi, temperature)

        text = get_data(xsdata_elem, 'nu_fission')
        if text is not None:
            # With chi present nu-fission is a per-group vector; without it
            # the data is read as a full [G][G'] matrix.
            shape_key = '[G]' if has_chi else "[G][G']"
            nu_fission = \
                _to_float_array(text).reshape(xs.xs_shapes[shape_key])
            xs.set_nu_fission(nu_fission, temperature)

    # Every xsdata element has been read; assemble the library and write it.
    lib = openmc.MGXSLibrary(energy_groups)
    lib.add_xsdatas(xsdatas)
    lib.export_to_hdf5(args['output'])


if __name__ == '__main__':
    main()
