#!/usr/bin/env python3

"""NEST Server with MPI support.

Usage:
  nest-server-mpi --help
  mpirun -np N nest-server-mpi [--host HOST] [--port PORT]

Options:
  -h --help     display usage information and exit
  --host HOST   use hostname/IP address HOST for server [default: 127.0.0.1]
  --port PORT   use port PORT for opening the socket [default: 52425]

"""

from docopt import docopt
from mpi4py import MPI

if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring.
    # This runs ahead of the `nest` imports further down -- presumably so that
    # `--help` can print and exit without loading them.
    opt = docopt(__doc__)
else:
    # `opt` is read unconditionally when rank 0 starts the server. Bind an
    # empty mapping when this file is imported rather than executed, so the
    # later `opt.get(..., HOST/PORT)` lookups use their fallbacks instead of
    # ending in a NameError.
    opt = {}

import time
import sys

import nest
import nest.server

import os

# Fallback host and port for the server socket, read from the environment.
# Both values are strings; PORT is not converted to an integer here.
HOST = os.environ.get("NEST_SERVER_HOST", "127.0.0.1")
PORT = os.environ.get("NEST_SERVER_PORT", "52425")

# Work on a clone of the world communicator rather than on COMM_WORLD itself
# (presumably to keep the server's broadcasts/gathers apart from other MPI
# traffic -- TODO confirm), and remember this process's rank within it.
comm = MPI.COMM_WORLD.Clone()
rank = comm.rank


def log(call_name, msg):
    """Print one trace line for this process and flush stdout immediately.

    The line carries the module-level ``rank``, the current ``time.time()``
    value with seven decimals, the call name and the message.

    Args:
        call_name: Label of the call the message belongs to.
        msg: Text to append after the prefix.
    """
    print(f'==> WORKER {rank}/{time.time():.7f} ({call_name}): {msg}', flush=True)


# Every rank registers the cloned communicator with the server module and then
# takes its role: rank 0 runs the server application, all other ranks loop
# over the calls that rank 0 broadcasts.
if rank == 0:
    print("==> Starting NEST Server Master on rank 0", flush=True)
    nest.server.set_mpi_comm(comm)
    # NOTE(review): docopt fills in the `[default: ...]` values from the usage
    # text, so both keys are presumably always present and the HOST/PORT
    # environment fallbacks are never reached -- confirm whether the
    # environment variables are meant to take effect here.
    nest.server.run_mpi_app(host=opt.get('--host', HOST), port=opt.get('--port', PORT))

else:
    print(f"==> Starting NEST Server Worker on rank {rank}", flush=True)
    nest.server.set_mpi_comm(comm)
    while True:

        # A request arrives as two collective broadcasts rooted at rank 0:
        # first the call name, then the (args, kwargs) pair.
        log('spinwait', 'waiting for call bcast')
        call_name = comm.bcast(None, root=0)

        log(call_name, 'received call bcast, waiting for data bcast')
        data = comm.bcast(None, root=0)

        log(call_name, f'received data bcast, data={data}')
        args, kwargs = data

        if call_name == 'exec':
            response = nest.server.do_exec(args, kwargs)
        else:
            call, args, kwargs = nest.server.nestify(call_name, args, kwargs)
            log(call_name, f'local call, args={args}, kwargs={kwargs}')

            # The following exception handler is useful if an error
            # occurs simultaneously on all processes. If only a
            # subset of processes raises an exception, a deadlock due
            # to mismatching MPI communication calls is inevitable on
            # the next call.
            #
            # The error is logged before the gather is skipped so that the
            # failure leaves a trace in this worker's output instead of
            # disappearing silently.
            try:
                response = call(*args, **kwargs)
            except Exception as exc:
                log(call_name, f'local call raised {exc!r}, skipping response gather')
                continue

        # Hand the serialized local result back to rank 0.
        log(call_name, f'sending reponse gather, data={response}')
        comm.gather(nest.serializable(response), root=0)
