#!/usr/bin/env bash

# Wrapper that launches an executable under an MPI launcher.
#
# Options consumed by the wrapper itself:
#   --mpirun MPIRUN         launcher command (default: mpirun)
#   --executable EXECUTABLE program to launch (required)
#   --np NCPUS              value passed to the launcher's -np option (default: 1)
#   --ntemps NTEMPS         forwarded to the executable as "--ntemps NTEMPS" and
#                           used to derive OMP_NUM_THREADS = ceil(NTEMPS / NCPUS)
# Every other argument is forwarded to the executable unchanged and in order.
#
# NTEMPS and EXE are intentionally not initialised here, so values inherited
# from the environment are honoured when the matching option is not given.
#
# Exit status: the launcher's exit status, or 1 on a usage error.
USAGE="Usage: lalinference_mpi_wrapper [--mpirun MPIRUN] --executable EXECUTABLE [--np NCPUS] [--ntemps NTEMPS] [arg1] [arg2] ...
          Runs MPIRUN -np NCPUS EXECUTABLE [arg1] [arg2] ... [--ntemps NTEMPS]"

# Print an error message followed by the usage text on stderr, then exit 1.
die() {
  printf 'lalinference_mpi_wrapper: %s\n' "$1" >&2
  printf '%s\n' "$USAGE" >&2
  exit 1
}

MPIRUN=mpirun
# Pass-through arguments are kept in an array so that arguments containing
# whitespace or glob characters reach the executable exactly as given.
MCMCARGS=()
NP=1

while (( $# >= 1 )); do
  case "$1" in
    --ntemps | --np | --executable | --mpirun)
      # "shift 2" with a single remaining argument fails without shifting,
      # which would loop forever; reject the missing value explicitly.
      (( $# >= 2 )) || die "option $1 requires a value"
      case "$1" in
        --ntemps) NTEMPS=$2 ;;
        --np) NP=$2 ;;
        --executable) EXE=$2 ;;
        --mpirun) MPIRUN=$2 ;;
      esac
      shift 2
      ;;
    *)
      MCMCARGS+=("$1")
      shift
      ;;
  esac
done

# If NTEMPS is set, forward it and size the OpenMP thread pool from it.
if [[ -n "${NTEMPS-}" ]]; then
  # Validate before doing arithmetic: arithmetic expansion would otherwise
  # evaluate arbitrary expressions and fail on a zero divisor.
  if ! [[ "$NTEMPS" =~ ^[0-9]+$ ]]; then
    die "--ntemps must be a non-negative integer, got '$NTEMPS'"
  fi
  if ! [[ "$NP" =~ ^[0-9]+$ ]] || (( 10#$NP == 0 )); then
    die "--np must be a positive integer when --ntemps is used, got '$NP'"
  fi

  MCMCARGS+=(--ntemps "$NTEMPS")
  # Number of OpenMP threads is ceil(NTEMPS/NP); 10# forces base 10 so that
  # values with leading zeros are not parsed as octal.
  export OMP_NUM_THREADS=$(( (10#$NTEMPS + 10#$NP - 1) / 10#$NP ))
  echo "Running with OMP_NUM_THREADS=${OMP_NUM_THREADS}"
fi

if [[ -z "${EXE-}" ]]; then
  die "--executable not set"
fi

echo "$MPIRUN -np $NP $EXE ${MCMCARGS[*]}"
# MPIRUN is deliberately left unquoted: callers may pass a launcher together
# with extra flags as a single string, which must be split into words.
# shellcheck disable=SC2086
$MPIRUN -np "$NP" "$EXE" "${MCMCARGS[@]}"
MPI_EXIT_CODE=$?
echo "MPI exited with code ${MPI_EXIT_CODE}"
exit "$MPI_EXIT_CODE"

