#!/usr/bin/python3

from __future__ import division, print_function
import struct
import sys
from argparse import ArgumentParser

import numpy as np
import h5py


def _write_vtk(output, dimension, width, lower_left, voxel_data):
    """Write the voxel values as cell data of a VTK image-data file.

    Parameters
    ----------
    output : str
        Destination path; ``.vti`` is appended when it is missing.
    dimension : sequence of int
        Number of voxels along x, y and z.
    width : sequence of float
        Voxel width along x, y and z, used as the grid spacing.
    lower_left : sequence of float
        Coordinates used as the grid origin.
    voxel_data : numpy.ndarray
        Array of shape ``(nx, ny, nz)`` indexed as ``[x, y, z]``.
    """
    # Imported lazily so the SILO path does not require VTK to be installed.
    import vtk
    from vtk.util import numpy_support

    nx, ny, nz = dimension

    grid = vtk.vtkImageData()
    # One more grid point than voxels per axis: the values are attached to
    # the cells between the points, not to the points themselves.
    grid.SetDimensions(nx + 1, ny + 1, nz + 1)
    grid.SetOrigin(*lower_left)
    grid.SetSpacing(*width)

    # Fortran-order flattening makes x the fastest-varying index, i.e. the
    # value for voxel (x, y, z) lands at flat index z*nx*ny + y*nx + x.
    # Doing this in NumPy replaces one VTK call per voxel with one bulk copy.
    flat = np.asarray(voxel_data, dtype=np.float64).ravel(order='F')
    # deep=True copies the buffer so the VTK array owns its own memory.
    data = numpy_support.numpy_to_vtk(flat, deep=True)
    data.SetName("id")
    grid.GetCellData().AddArray(data)

    writer = vtk.vtkXMLImageDataWriter()
    # VTK 6 renamed SetInput to SetInputData.
    if vtk.vtkVersion.GetVTKMajorVersion() > 5:
        writer.SetInputData(grid)
    else:
        writer.SetInput(grid)
    if not output.endswith(".vti"):
        output += ".vti"
    writer.SetFileName(output)
    writer.Write()


def _write_silo(output, dimension, lower_left, upper_right, voxel_data):
    """Write the voxel values to a SILO file through ``silomesh``.

    Parameters
    ----------
    output : str
        Destination path; ``.silo`` is appended when it is missing.
    dimension : sequence of int
        Number of voxels along x, y and z.
    lower_left, upper_right : sequence of float
        Coordinates of the two opposite corners of the mesh.
    voxel_data : numpy.ndarray
        Array of shape ``(nx, ny, nz)`` indexed as ``[x, y, z]``.
    """
    # Imported lazily so the VTK path does not require silomesh.
    import silomesh

    nx, ny, nz = dimension

    if not output.endswith(".silo"):
        output += ".silo"
    silomesh.init_silo(output)
    meshparams = (list(map(int, dimension)) + list(map(float, lower_left)) +
                  list(map(float, upper_right)))
    silomesh.init_mesh('plot', *meshparams)
    silomesh.init_var("id")
    for x in range(nx):
        # Progress indicator, rewritten in place on a single terminal line.
        sys.stdout.write(" {}%\r".format(int(x/nx*100)))
        sys.stdout.flush()
        for y in range(ny):
            for z in range(nz):
                # set_value is called with 1-based indices.
                silomesh.set_value(float(voxel_data[x, y, z]),
                                   x + 1, y + 1, z + 1)
    # Move past the progress line.
    print()
    silomesh.finalize_var()
    silomesh.finalize_mesh()
    silomesh.finalize_silo()


def main():
    """Convert an HDF5 voxel file to a VTK (.vti) or SILO (.silo) file.

    Command-line arguments: the path of the voxel file, ``-o/--output`` for
    the output path (default ``plot``; the format's extension is appended
    when missing) and ``-s/--silo`` to write SILO instead of VTK.

    The voxel file must provide the attributes ``num_voxels``,
    ``voxel_width`` and ``lower_left`` and a dataset named ``data``.

    Raises
    ------
    ValueError
        If the shape of ``data`` does not match ``num_voxels``.
    """
    # Process command line arguments
    parser = ArgumentParser()
    parser.add_argument('voxel_file', help='Path to voxel file')
    parser.add_argument('-o', '--output', action='store',
                        default='plot', help='Path to output SILO or VTK file.')
    parser.add_argument('-s', '--silo', action='store_true',
                        default=False, help='Flag to convert to SILO instead of VTK.')
    args = parser.parse_args()

    # Read everything into memory, then release the file handle.
    with h5py.File(args.voxel_file, 'r') as fh:
        dimension = fh.attrs['num_voxels']
        width = fh.attrs['voxel_width']
        lower_left = fh.attrs['lower_left']
        # `[()]` reads the whole dataset; the older `.value` accessor is
        # no longer available in h5py 3.
        voxel_data = fh['data'][()]

    # Plain Python ints for the loop bounds and the writer calls.
    num_voxels = tuple(int(n) for n in dimension)
    if voxel_data.shape != num_voxels:
        raise ValueError(
            "Dataset 'data' has shape {} but 'num_voxels' is {}".format(
                voxel_data.shape, num_voxels))

    if not args.silo:
        _write_vtk(args.output, num_voxels, width, lower_left, voxel_data)
    else:
        upper_right = lower_left + width*dimension
        _write_silo(args.output, num_voxels, lower_left, upper_right,
                    voxel_data)


# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
