# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
import numpy as np
from nipy.testing import (assert_true, assert_equal, assert_raises, 
                          assert_false, assert_array_equal,
                          assert_almost_equal, parametric)

# this import line is a little ridiculous...
from nipy.core.reference.coordinate_map import (CoordinateMap, AffineTransform, 
                                                compose, CoordinateSystem, product,
                                                append_io_dim,
                                                equivalent, shifted_domain_origin,
                                                shifted_range_origin,
                                                _as_coordinate_map)


class empty:
    """Bare attribute container used as a namespace for shared fixtures."""


# Module-level namespace; ``setup`` attaches the test fixtures to it.
E = empty()


def setup():
    """Populate the shared ``E`` namespace with the maps used by the tests.

    Fixtures:

    * ``E.a`` -- 1-D map that doubles its input, no inverse function.
    * ``E.b`` -- doubles its input, with halving as inverse function.
    * ``E.c`` -- halves its input, no inverse function.
    * ``E.d`` -- halves its input, with doubling as inverse function.
    * ``E.e`` -- identity ``AffineTransform`` on 'ijk'.
    * ``E.mapping`` -- 'ijk' -> 'xyz' affine with a random 3x4 top block.
    * ``E.singular`` -- 'ijk' -> 'xyzt' affine whose third and fourth
      output rows are identical.
    """
    def double(v):
        return 2*v

    def halve(v):
        return v/2.0

    line = CoordinateSystem('x', 'x')
    E.a = CoordinateMap(line, line, double)
    E.b = CoordinateMap(line, line, double, inverse_function=halve)
    E.c = CoordinateMap(line, line, halve)
    E.d = CoordinateMap(line, line, halve, inverse_function=double)
    E.e = AffineTransform.identity('ijk')

    # Random top block; the bottom row stays [0, 0, 0, 1] from the identity.
    random_affine = np.identity(4)
    random_affine[:3] = np.random.standard_normal((3, 4))
    E.mapping = AffineTransform.from_params('ijk', 'xyz', random_affine)

    # Rows 0..2 are 0..11 laid out row-major; row 2 is then repeated,
    # followed by the homogeneous row.
    top = np.arange(12).reshape((3, 4))
    singular_affine = np.vstack((top, top[-1], [0, 0, 0, 1]))
    E.singular = AffineTransform.from_params('ijk', 'xyzt', singular_affine)

    
def test_shift_origin():
    """Check shifted_domain_origin / shifted_range_origin against the original.

    Each relationship is checked for an AffineTransform and for its generic
    CoordinateMap form (via ``_as_coordinate_map``).  The random draws below
    happen in a fixed order; keep it that way when editing.
    """
    CS = CoordinateSystem

    # Random 5-input / 4-output affine with a homogeneous bottom row.
    A = np.random.standard_normal((5,6))
    A[-1] = [0,0,0,0,0,1]

    aff1 = AffineTransform(CS('ijklm', 'oldorigin'), CS('xyzt'), A)
    difference = np.random.standard_normal(5)
    point_in_old_basis = np.random.standard_normal(5)

    for aff in [aff1, _as_coordinate_map(aff1)]:
        # The same affine transformation with a different origin for its domain

        shifted_aff = shifted_domain_origin(aff, difference, 'neworigin')

        # This is the relationship between coordinates in old and new origins:
        # evaluating the shifted map at p equals evaluating the original map
        # at p + difference.

        yield assert_true, np.allclose(shifted_aff(point_in_old_basis), aff(point_in_old_basis+difference))

        yield assert_true, np.allclose(shifted_aff(point_in_old_basis-difference), aff(point_in_old_basis))

    # OK, now for the range

    A = np.random.standard_normal((5,6))
    A[-1] = [0,0,0,0,0,1]
    aff2 = AffineTransform(CS('ijklm', 'oldorigin'), CS('xyzt'), A)

    # One offset per output coordinate ('xyzt').
    difference = np.random.standard_normal(4)

    for aff in [aff2, _as_coordinate_map(aff2)]:
        # The same affine transformation with a different origin for its range

        shifted_aff = shifted_range_origin(aff, difference, 'neworigin')

        # Let's check that things work

        point_in_old_basis = np.random.standard_normal(5)

        # This is the relationship between coordinates in old and new origins:
        # the shifted map's output equals the original output minus difference.

        yield assert_true, np.allclose(shifted_aff(point_in_old_basis), aff(point_in_old_basis)-difference)

        yield assert_true, np.allclose(shifted_aff(point_in_old_basis)+difference, aff(point_in_old_basis))



def test_renamed():
    """renamed_domain / renamed_range accept coordinate names or indices.

    Checked for both an AffineTransform and a generic CoordinateMap.
    Renaming a coordinate that does not exist must raise ValueError.
    """
    affine = AffineTransform.from_params('ijk', 'xyz', np.identity(4))

    ijk = CoordinateSystem('ijk')
    xyz = CoordinateSystem('xyz')
    generic = CoordinateMap(ijk, xyz, np.log)

    for cmap in (affine, generic):
        # (renaming method, name mapping, system to inspect, expected names)
        cases = (
            (cmap.renamed_domain, {'i': 'foo'},
             'function_domain', ('foo', 'j', 'k')),
            (cmap.renamed_domain, {'i': 'foo', 'j': 'bar'},
             'function_domain', ('foo', 'bar', 'k')),
            (cmap.renamed_range, {'y': 'foo'},
             'function_range', ('x', 'foo', 'z')),
            (cmap.renamed_range, {0: 'foo', 1: 'bar'},
             'function_range', ('foo', 'bar', 'z')),
            (cmap.renamed_domain, {0: 'foo', 1: 'bar'},
             'function_domain', ('foo', 'bar', 'k')),
            (cmap.renamed_range, {'y': 'foo', 'x': 'bar'},
             'function_range', ('bar', 'foo', 'z')),
        )
        for rename, names, system, expected in cases:
            renamed = rename(names)
            yield assert_equal, getattr(renamed, system).coord_names, expected

        # 'foo' is not a coordinate of either system.
        yield assert_raises, ValueError, cmap.renamed_range, {'foo': 'y'}
        yield assert_raises, ValueError, cmap.renamed_domain, {'foo': 'y'}



def test_call():
    """Calling the fixture maps applies their forward functions."""
    scalar = 10
    doubled = 2*scalar
    halved = scalar/2
    yield assert_true, np.allclose(E.a(scalar), doubled)
    yield assert_true, np.allclose(E.b(scalar), doubled)
    # FIXME: this shape just below is not
    # really expected for a CoordinateMap
    yield assert_true, np.allclose(E.b([scalar]), doubled)
    for halving_map in (E.c, E.d):
        yield assert_true, np.allclose(halving_map(scalar), halved)
    vec = np.array([1., 2., 3.])
    yield assert_true, np.allclose(E.e(vec), vec)


def test_compose():
    """compose(f, g) applies g first, then f, and tracks invertibility.

    Composition with a factor that has no inverse function has no inverse.
    Composing affines multiplies their matrices. Mismatched coordinate
    systems, or a non-map argument, raise ValueError.
    """
    value = np.array([[1., 2., 3.]]).T
    aa = compose(E.a, E.a)
    yield assert_true, aa.inverse() is None
    yield assert_true, np.allclose(aa(value), 4*value)
    ab = compose(E.a, E.b)
    yield assert_true, ab.inverse() is None
    # Both factors double, so the composition quadruples.
    yield assert_true, np.allclose(ab(value), 4*value)
    ac = compose(E.a, E.c)
    yield assert_true, ac.inverse() is None
    yield assert_true, np.allclose(ac(value), value)
    # Composing two invertible maps must at least succeed.
    # NOTE(review): an assertion that compose(E.b, E.b).inverse() is not
    # None was commented out here; confirm the intended behavior before
    # enabling it.
    compose(E.b, E.b)
    aff1 = np.diag([1,2,3,1])
    affine1 = AffineTransform.from_params('ijk', 'xyz', aff1)
    aff2 = np.diag([4,5,6,1])
    affine2 = AffineTransform.from_params('xyz', 'abc', aff2)
    # compose mapping from 'ijk' to 'abc'
    compcm = compose(affine2, affine1)
    yield assert_equal, compcm.function_domain.coord_names, ('i', 'j', 'k')
    yield assert_equal, compcm.function_range.coord_names, ('a', 'b', 'c')
    yield assert_equal, compcm.affine, np.dot(aff2, aff1)
    # check invalid coordinate mappings: affine2's range 'abc' does not
    # match affine1's domain 'ijk'
    yield assert_raises, ValueError, compose, affine1, affine2

    yield assert_raises, ValueError, compose, affine1, 'foo'

    cm1 = CoordinateMap(CoordinateSystem('ijk'),
                        CoordinateSystem('xyz'), np.log)
    cm2 = CoordinateMap(CoordinateSystem('xyz'),
                        CoordinateSystem('abc'), np.exp)
    yield assert_raises, ValueError, compose, cm1, cm2


def test__eq__():
    """``==`` and ``!=`` give consistent answers for equal and unequal maps."""
    # (left, right, whether they should compare equal)
    cases = ((E.a, E.a, True),
             (E.a, E.b, False),
             (E.singular, E.singular, True))
    for left, right, same in cases:
        if same:
            yield assert_true, left == right
            yield assert_false, left != right
        else:
            yield assert_false, left == right
            yield assert_true, left != right

    # Separately built affines with identical matrices are equal.
    A = AffineTransform.from_params('ijk', 'xyz', np.diag([4, 3, 2, 1]))
    B = AffineTransform.from_params('ijk', 'xyz', np.diag([4, 3, 2, 1]))

    yield assert_true, A == B
    yield assert_false, A != B

def test_isinvertible():
    """inverse() is falsy exactly for the fixtures that cannot be inverted."""
    expectations = (
        (E.a, assert_false),        # no inverse function supplied
        (E.b, assert_true),
        (E.c, assert_false),        # no inverse function supplied
        (E.d, assert_true),
        (E.e, assert_true),
        (E.mapping, assert_true),
        (E.singular, assert_false),
    )
    for cmap, check in expectations:
        yield check, cmap.inverse()


def test_inverse1():
    """Check inverse() with and without an inverse function.

    inverse() is None when no inverse function was given. Otherwise,
    composing a map with its inverse acts as the identity.
    """
    for no_inverse in (E.a, E.c):
        yield assert_true, no_inverse.inverse() is None
    identities = [compose(cmap.inverse(), cmap) for cmap in (E.b, E.d)]
    column = np.array([[1., 2., 3.]]).T
    for ident in identities:
        yield assert_true, np.allclose(ident(column), column)
        
      
def test_compose_cmap():
    """Composing the identity affine with itself is still the identity."""
    point = np.array([1., 2., 3.])
    twice = compose(E.e, E.e)
    assert_true(np.allclose(twice(point), point))

    
def test_inverse2():
    """Inverting an AffineTransform twice recovers the original affine."""
    original = E.e.affine
    round_trip = E.e.inverse().inverse()
    assert_true(np.allclose(original, round_trip.affine))


def voxel_to_world():
    """Return ``(incs, outcs, forward, inverse)`` for a trivial CoordinateMap.

    ``incs`` is the 'ijk' system named 'voxels' and ``outcs`` is the 'xyz'
    system named 'world'. ``forward`` adds 1 and ``inverse`` subtracts 1.
    """
    def forward(x):
        return x + 1

    def backward(x):
        return x - 1

    return (CoordinateSystem('ijk', 'voxels'),
            CoordinateSystem('xyz', 'world'),
            forward,
            backward)


def test_comap_init():
    """CoordinateMap stores its arguments and rejects non-callable functions."""
    incs, outcs, fwd, bwd = voxel_to_world()
    cm = CoordinateMap(incs, outcs, fwd, bwd)
    stored = (('function', fwd),
              ('function_domain', incs),
              ('function_range', outcs),
              ('inverse_function', bwd))
    for attr, expected in stored:
        yield assert_equal, getattr(cm, attr), expected
    # A string in place of the forward or inverse callable is rejected.
    yield assert_raises, ValueError, CoordinateMap, incs, outcs, 'foo', bwd
    yield assert_raises, ValueError, CoordinateMap, incs, outcs, fwd, 'bar'


def test_comap_copy():
    """A shallow copy of a CoordinateMap keeps all four components."""
    import copy
    incs, outcs, fwd, bwd = voxel_to_world()
    # Forward and inverse are passed swapped here; irrelevant for copying.
    original = CoordinateMap(incs, outcs, bwd, fwd)
    duplicate = copy.copy(original)
    for attr in ('function', 'function_domain', 'function_range',
                 'inverse_function'):
        yield assert_equal, getattr(duplicate, attr), getattr(original, attr)


#
# AffineTransform tests
#

def affine_v2w():
    """Return ``(incs, outcs, aff)`` for a simple voxel-to-world affine.

    ``incs`` is the 'ijk' system named 'voxels', and ``outcs`` is the 'xyz'
    system named 'world'. ``aff`` scales by (1, 2, 4) and translates by
    (11, 12, 13)::

        [[ 1,  0,  0, 11],
         [ 0,  2,  0, 12],
         [ 0,  0,  4, 13],
         [ 0,  0,  0,  1]]
    """
    aff = np.array([[1, 0, 0, 11],
                    [0, 2, 0, 12],
                    [0, 0, 4, 13],
                    [0, 0, 0, 1]])
    return (CoordinateSystem('ijk', 'voxels'),
            CoordinateSystem('xyz', 'world'),
            aff)


def test_affine_init():
    """AffineTransform stores its arguments and rejects a mis-sized affine."""
    incs, outcs, aff = affine_v2w()
    cm = AffineTransform(incs, outcs, aff)
    yield assert_equal, cm.function_domain, incs
    yield assert_equal, cm.function_range, outcs
    yield assert_equal, cm.affine, aff
    # A 2x2 matrix cannot be the homogeneous affine of a 3 -> 3 mapping.
    badaff = np.diag([1,2])
    yield assert_raises, ValueError, AffineTransform, incs, outcs, badaff


def test_affine_bottom_row():
    """from_params rejects a matrix whose last row is not [0, ..., 0, 1]."""
    # The bottom row [1, 1, 1] is not a valid homogeneous row.
    not_homogeneous = np.array([[3, 4, 5],
                                [1, 1, 1]])
    yield (assert_raises, ValueError, AffineTransform.from_params,
           'ij', 'x', not_homogeneous)


def test_affine_inverse():
    """inverse() round-trips points, and is None for a 2-in / 1-out affine."""
    incs, outcs, aff = affine_v2w()
    cm = AffineTransform(incs, outcs, aff)
    # Builtin ``float``: the ``np.float`` alias was removed in NumPy 1.24.
    x = np.array([10, 20, 30], float)
    x_roundtrip = cm(cm.inverse()(x))
    # NOTE: exact comparison; relies on the round trip being exact for the
    # power-of-two scalings (1, 2, 4) of affine_v2w.
    yield assert_equal, x_roundtrip, x
    # Two inputs mapped to one output: no inverse exists.
    badaff = np.array([[1,2,3],[0,0,1]])
    badcm = AffineTransform(CoordinateSystem('ij'),
                            CoordinateSystem('x'),
                            badaff)
    yield assert_equal, badcm.inverse(), None


def test_affine_from_params():
    """from_params keeps the given affine and rejects a wrongly-sized one."""
    aff = affine_v2w()[2]
    built = AffineTransform.from_params('ijk', 'xyz', aff)
    yield assert_equal, built.affine, aff
    # A 2x3 matrix cannot describe a 3 -> 3 affine mapping.
    wrong_shape = np.array([[1, 2, 3],
                            [4, 5, 6]])
    yield (assert_raises, ValueError, AffineTransform.from_params,
           'ijk', 'xyz', wrong_shape)


def test_affine_start_step():
    """from_start_step rebuilds the diagonal affine from offsets and steps."""
    incs, outcs, aff = affine_v2w()
    # Translation column and diagonal scalings of the 3-D affine.
    start, step = aff[:3, 3], aff.diagonal()[:3]
    rebuilt = AffineTransform.from_start_step(incs.coord_names,
                                              outcs.coord_names,
                                              start, step)
    yield assert_equal, rebuilt.affine, aff
    # Two output names for three input names is rejected.
    yield (assert_raises, ValueError, AffineTransform.from_start_step,
           'ijk', 'xy', start, step)


def test_affine_identity():
    """AffineTransform.identity gives an eye affine with equal domain/range."""
    aff = AffineTransform.identity('ijk')
    yield assert_equal, aff.affine, np.eye(4)
    yield assert_equal, aff.function_domain, aff.function_range
    # AffineTransforms aren't CoordinateMaps, so
    # they don't have "function" attributes
    yield assert_false, hasattr(aff, 'function')


def test_affine_copy():
    """A shallow copy of an AffineTransform keeps affine, domain and range."""
    import copy
    incs, outcs, aff = affine_v2w()
    source = AffineTransform(incs, outcs, aff)
    clone = copy.copy(source)
    for attr in ('affine', 'function_domain', 'function_range'):
        yield assert_equal, getattr(clone, attr), getattr(source, attr)


#
# Module level functions
#

def test_reordered_domain():
    """reordered_domain permutes input coordinates by letters or indices."""
    incs, outcs, fwd, bwd = voxel_to_world()
    cm = CoordinateMap(incs, outcs, fwd, bwd)

    by_name = cm.reordered_domain('jki')
    yield assert_equal, by_name.function_domain.coord_names, ('j', 'k', 'i')
    yield assert_equal, by_name.function_range.coord_names, outcs.coord_names
    # The coordinate system names survive the reordering.
    yield assert_equal, by_name.function_domain.name, incs.name
    yield assert_equal, by_name.function_range.name, outcs.name

    # With no argument the order is reversed.
    reversed_cm = cm.reordered_domain()
    yield assert_equal, reversed_cm.function_domain.coord_names, ('k', 'j', 'i')
    # The new order may also be given as indices.
    indexed = cm.reordered_domain([2, 0, 1])
    yield assert_equal, indexed.function_domain.coord_names, ('k', 'i', 'j')


def test_str():
    """Check the textual forms of AffineTransform and CoordinateMap.

    Covers str() of an AffineTransform, and repr() of a CoordinateMap both
    with and without an inverse function.
    """
    result = """AffineTransform(
   function_domain=CoordinateSystem(coord_names=('i', 'j', 'k'), name='', coord_dtype=float64),
   function_range=CoordinateSystem(coord_names=('x', 'y', 'z'), name='', coord_dtype=float64),
   affine=array([[ 1.,  0.,  0.,  0.],
                 [ 0.,  1.,  0.,  0.],
                 [ 0.,  0.,  1.,  0.],
                 [ 0.,  0.,  0.,  1.]])
)"""
    domain = CoordinateSystem('ijk')
    range_cs = CoordinateSystem('xyz')
    affine = np.identity(4)
    affine_mapping = AffineTransform(domain, range_cs, affine)
    yield assert_equal, result, str(affine_mapping)

    # With an inverse function the repr gains an ``inverse_function`` entry.
    cmap = CoordinateMap(domain, range_cs, np.exp, np.log)
    result="""CoordinateMap(
   function_domain=CoordinateSystem(coord_names=('i', 'j', 'k'), name='', coord_dtype=float64),
   function_range=CoordinateSystem(coord_names=('x', 'y', 'z'), name='', coord_dtype=float64),
   function=<ufunc 'exp'>,
   inverse_function=<ufunc 'log'>
  )"""
    yield assert_equal, result, repr(cmap)

    cmap = CoordinateMap(domain, range_cs, np.exp)
    result="""CoordinateMap(
   function_domain=CoordinateSystem(coord_names=('i', 'j', 'k'), name='', coord_dtype=float64),
   function_range=CoordinateSystem(coord_names=('x', 'y', 'z'), name='', coord_dtype=float64),
   function=<ufunc 'exp'>
  )"""
    yield assert_equal, result, repr(cmap)


def test_reordered_range():
    """reordered_range permutes output coordinates by letters or indices."""
    incs, outcs, fwd, bwd = voxel_to_world()
    # Forward and inverse are passed swapped; only names are examined here.
    cm = CoordinateMap(incs, outcs, bwd, fwd)

    by_name = cm.reordered_range('yzx')
    yield assert_equal, by_name.function_domain.coord_names, incs.coord_names
    yield assert_equal, by_name.function_range.coord_names, ('y', 'z', 'x')
    # The coordinate system names survive the reordering.
    yield assert_equal, by_name.function_domain.name, incs.name
    yield assert_equal, by_name.function_range.name, outcs.name

    # With no argument the order is reversed.
    reversed_cm = cm.reordered_range()
    yield assert_equal, reversed_cm.function_range.coord_names, ('z', 'y', 'x')
    # The new order may also be given as indices.
    indexed = cm.reordered_range([2, 0, 1])
    yield assert_equal, indexed.function_range.coord_names, ('z', 'x', 'y')


def test_product():
    """product() concatenates the domains and ranges of its factors.

    Checked for generic CoordinateMaps, where each factor's function is
    applied to its own coordinates, and for AffineTransforms, where the
    result is block diagonal.
    """
    affine1 = AffineTransform.from_params('i', 'x', np.diag([2, 1]))
    affine2 = AffineTransform.from_params('j', 'y', np.diag([3, 1]))
    affine = product(affine1, affine2)

    cm1 = CoordinateMap(CoordinateSystem('i'),
                        CoordinateSystem('x'),
                        np.log)

    cm2 = CoordinateMap(CoordinateSystem('j'),
                        CoordinateSystem('y'),
                        np.log)
    cm = product(cm1, cm2)

    # Generic product: names concatenated, log applied coordinate-wise.
    yield assert_equal, cm.function_domain.coord_names, ('i', 'j')
    yield assert_equal, cm.function_range.coord_names, ('x', 'y')
    yield assert_almost_equal, cm([3,4]), np.log([3,4])
    yield assert_almost_equal, cm.function([[3,4],[5,6]]), np.log([[3,4],[5,6]])

    # Affine product: names concatenated, scalings combined on the diagonal.
    yield assert_equal, affine.function_domain.coord_names, ('i', 'j')
    yield assert_equal, affine.function_range.coord_names, ('x', 'y')
    yield assert_equal, affine.affine, np.diag([2, 3, 1])


def test_equivalent():
    """Check equivalent() under reordering and renaming.

    Reordering domain and/or range coordinates gives an equivalent mapping;
    renaming a coordinate does not.
    """
    import itertools

    ijk = CoordinateSystem('ijk')
    xyz = CoordinateSystem('xyz')
    T = np.random.standard_normal((4,4))
    T[-1] = [0,0,0,1]
    A = AffineTransform(ijk, xyz, T)

    yield assert_false, equivalent(A, A.renamed_domain({'i':'foo'}))

    # Cycle through all permutations of 'ijk' and 'xyz' and confirm that
    # each reordered mapping is equivalent to A. ``itertools.permutations``
    # needs Python 2.6+. On older versions its absence is an AttributeError,
    # not an ImportError, so test for the attribute and fall back to a subset.
    if hasattr(itertools, 'permutations'):
        domain_orders = list(itertools.permutations('ijk'))
        range_orders = list(itertools.permutations('xyz'))
    else:
        domain_orders = ['ikj', 'kij']
        range_orders = ['xzy', 'yxz']

    for pijk in domain_orders:
        for pxyz in range_orders:
            B = A.reordered_domain(pijk).reordered_range(pxyz)
            yield assert_true, equivalent(A, B)


def test_as_coordinate_map():
    """_as_coordinate_map turns an AffineTransform into a CoordinateMap.

    The result keeps an inverse function only when the affine is invertible.
    """
    ijk = CoordinateSystem('ijk')
    xyz = CoordinateSystem('xyz')

    A = np.random.standard_normal((4,4))

    # bottom row of A is not [0,0,0,1]. Pass a copy: A is fixed up in place
    # below, and a runner that collects all yielded tests before running
    # them would otherwise construct the transform from the corrected matrix.
    yield assert_raises, ValueError, AffineTransform, ijk, xyz, A.copy()

    A[-1] = [0,0,0,1]

    aff = AffineTransform(ijk, xyz, A)
    _cmapA = _as_coordinate_map(aff)
    yield assert_true, isinstance(_cmapA, CoordinateMap)
    yield assert_true, _cmapA.inverse_function is not None

    # a non-invertible one: 3 inputs mapped to 2 outputs

    B = A[1:]
    xy = CoordinateSystem('xy')
    affB = AffineTransform(ijk, xy, B)
    _cmapB = _as_coordinate_map(affB)

    yield assert_true, isinstance(_cmapB, CoordinateMap)
    yield assert_true, _cmapB.inverse_function is None


def test_cm__setattr__raise_error():
    """Assigning to a CoordinateMap attribute raises AttributeError.

    Only CoordinateMap is checked. AffineTransform exposes some attributes
    as properties, and the same ``__setattr__`` guard does not seem to
    apply to it.
    """
    xyz = CoordinateSystem('xyz')
    cm = CoordinateMap(CoordinateSystem('ijk'), xyz, np.exp)

    yield assert_raises, AttributeError, cm.__setattr__, "function_range", xyz


@parametric
def test_append_io_dim():
    """append_io_dim adds one input and one output axis to an affine.

    With only the two new names, the new axes get a unit scaling and no
    offset. The extra positional arguments (9 and 5 here) become the new
    row's offset and scaling, respectively.
    """
    in_dims = list('ijk')
    out_dims = list('xyz')
    cm = AffineTransform.from_params(in_dims, out_dims, np.diag([1, 2, 3, 1]))
    grown_in = in_dims + ['l']
    grown_out = out_dims + ['t']

    # Defaults: the appended axis is a unit, unshifted copy.
    plain = append_io_dim(cm, 'l', 't')
    yield assert_array_equal(plain.affine, np.diag([1, 2, 3, 1, 1]))
    yield assert_equal(plain.function_range.coord_names, grown_out)
    yield assert_equal(plain.function_domain.coord_names, grown_in)

    # Offset 9 and scaling 5 for the new axis.
    shifted = append_io_dim(cm, 'l', 't', 9, 5)
    expected = np.diag([1, 2, 3, 5, 1])
    expected[3, 4] = 9
    yield assert_array_equal(shifted.affine, expected)
    yield assert_equal(shifted.function_range.coord_names, grown_out)
    yield assert_equal(shifted.function_domain.coord_names, grown_in)

    # non square case: 2 inputs -> 3 outputs
    tall = AffineTransform.from_params('ij', 'xyz', np.array([[2, 0, 0],
                                                               [0, 3, 0],
                                                               [0, 0, 1],
                                                               [0, 0, 1]]))
    grown = append_io_dim(tall, 'q', 't', 9, 5)
    expected = np.array([[2, 0, 0, 0],
                         [0, 3, 0, 0],
                         [0, 0, 0, 1],
                         [0, 0, 5, 9],
                         [0, 0, 0, 1]])
    yield assert_array_equal(grown.affine, expected)
    yield assert_equal(grown.function_range.coord_names, list('xyzt'))
    yield assert_equal(grown.function_domain.coord_names, list('ijq'))
    
