#!/usr/bin/env python

# --------

import sys, os, tempfile
from distutils.core import setup
from distutils.dir_util import remove_tree
from Cython.Distutils.extension import Extension
from Cython.Distutils import build_ext

# Header search directories passed to every extension that gets built.
# NumPy and mpi4py are both optional: each one that can be imported
# contributes its include directory, and a missing one is skipped.
INCLUDES = []

try:
    from numpy import get_include as npy_get_include
    INCLUDES.append(npy_get_include())
except ImportError:
    # NumPy is unavailable; build without its headers.
    pass

try:
    from mpi4py import get_include as mpi_get_include
    INCLUDES.append(mpi_get_include())
except ImportError:
    # mpi4py is unavailable; build without its headers.
    pass

def main(cy_source, c_sources, lang='c', verb=False):
    """Build a Cython source file into an extension module, in place.

    Drives distutils' ``setup()`` with the ``build`` and
    ``build_ext --inplace`` commands.  A freshly created temporary
    directory is used as the build base and is removed afterwards,
    whether or not the build succeeded.

    Parameters
    ----------
    cy_source : str
        Path of the Cython source file.  The path with its file
        extension stripped is used as the extension name.
    c_sources : iterable of str
        Additional source files built into the same extension.
    lang : str, optional
        Passed through as the extension's ``language`` (default ``'c'``).
    verb : bool or int, optional
        Any true value runs distutils with ``--verbose``; any false
        value runs it with ``--quiet``.
    """
    name, _ = os.path.splitext(cy_source)
    # Choose the flag by truthiness instead of indexing a two-item list,
    # so values other than 0/1 (e.g. 2 or None) cannot raise
    # IndexError/TypeError.
    verbosity = '--verbose' if verb else '--quiet'
    extension = Extension(
        name,
        sources=[cy_source] + list(c_sources),
        depends=[],
        include_dirs=INCLUDES,
        language=lang,
    )
    tmpdir = tempfile.mkdtemp()
    try:
        setup(
            script_args=[
                verbosity,
                'build', '--build-base', tmpdir,
                'build_ext', '--inplace',  #'--force',
            ],
            ext_modules=[extension],
            cmdclass={'build_ext': build_ext},
        )
    finally:
        # Always discard the scratch build tree, even if setup() failed.
        remove_tree(tmpdir)

# --------

if __name__ == '__main__':
    # Command line: [-v] [-c++|-cxx|-cplus] cy_source [c_source ...]
    args = list(sys.argv[1:])
    verbose = 0
    language = 'c'
    cxx_flags = ('-c++', '-cxx', '-cplus')
    # Consume leading option flags (in either order).  Checking ``args``
    # first keeps an empty or flags-only command line from raising
    # IndexError.
    while args and args[0] in ('-v',) + cxx_flags:
        flag = args.pop(0)
        if flag == '-v':
            verbose = 1
        else:
            language = 'c++'
    if not args:
        # No Cython source was given: report usage instead of crashing.
        sys.exit('usage: %s [-v] [-c++|-cxx|-cplus] cy_source [c_source ...]'
                 % os.path.basename(sys.argv[0]))
    #
    # Override the compiler and linker entries in distutils' configuration
    # variables with the MPI wrappers; the MPICC / MPICXX / MPILD
    # environment variables take precedence over the defaults.
    # NOTE(review): this relies on get_config_vars() returning the same
    # dict that distutils consults during the build -- confirm.
    from distutils.sysconfig import get_config_vars
    config = get_config_vars()
    if language == 'c':
        config['CC'] = LD = os.environ.get('MPICC', 'mpicc')
    elif language == 'c++':
        config['CXX'] = LD = os.environ.get('MPICXX', 'mpicxx')
    LD = os.environ.get('MPILD', LD)
    try:
        LDSHARED = config['LDSHARED'].split()
    except KeyError:
        # No LDSHARED entry in the configuration; nothing to rewrite.
        pass
    else:
        # Replace only the first word (the linker executable) and keep
        # the remaining words of the original shared-link command.
        LDSHARED = LD.split() + LDSHARED[1:]
        config['LDSHARED'] = ' '.join(LDSHARED)
    #
    cy_source = args[0]
    c_sources = args[1:]
    main(cy_source, c_sources, language, verbose)

# --------

# Local Variables:
# mode: python
# End:
