'''

'''

from gda.analysis.functions import Parameter as _param
from gda.analysis.functions import AFunction as _absfn
from gda.analysis.functions import CompositeFunction as _compfn
from uk.ac.diamond.scisoft.analysis.fitting.Fitter import geneticFit as _genfit
#from uk.ac.diamond.scisoft.analysis.fitting.Fitter import polynomialFit as _polyfit

import scisoftpy as _dnp
_asIterable = _dnp.asIterable
_toList = _dnp.toList
_toDS = _dnp.toDS
_npwrapped = _dnp.ndarraywrapped

import java.lang.Class as _jclass

#from jarray import array as _jarray

#Parameter = _param

import function as _fn

def _createparams(np, params, bounds):
    '''Build np Parameter objects, consuming entries from the front of both lists.

    The first np entries of params are removed and used as initial values.
    Then up to np entries are removed from the front of bounds and applied,
    in order, to the new parameters. A bound is either None (no limits) or
    an iterable (lower[, upper]) whose individual items may also be None.

    np     -- number of parameters to create
    params -- list of initial values (mutated: np items are popped)
    bounds -- list of bounds (mutated: up to np items are popped)
    Returns the list of Parameter objects.
    '''
    plist = []
    for _ in range(np):
        plist.append(_param(params.pop(0)))

    # only as many bounds as there are parameters are consumed; the rest
    # are left in the list for the next function
    for par in plist[:min(len(bounds), np)]:
        bound = bounds.pop(0)
        if bound is None:
            continue
        bound = _asIterable(bound)
        lower = bound[0]
        if lower is not None:
            par.lowerLimit = lower
        if len(bound) > 1 and bound[1] is not None:
            par.upperLimit = bound[1]
    return plist

class fitfunc(_absfn):
    '''Class to wrap an ordinary Jython function for fitting.

    The wrapped function is called as fn(p, coords, *args) where
    p      -- list of parameter values
    coords -- coordinates array (or list of such)
    *args  -- optional extra arguments given to the constructor
    It should return a dataset: val() reads its first element with
    getElementDoubleAbs(0) and makeDataset() sets its name.
    '''
    def __init__(self, fn, name, plist, *args):
        '''Create a fit function from a given Jython function and parameter list.

        Arguments:
        fn     -- function, called as fn(parameter values, coords, *args)
        name   -- function name, used as the name of evaluated datasets
                  and in error messages
        plist  -- list of Parameter objects, passed to the base class
        *args  -- extra positional arguments appended to every call of fn
        '''
        _absfn.__init__(self, plist) #@UndefinedVariable
        self.func = fn
        self.args = args
        self.name = name

    def _problem(self, where=''):
        '''Return the message for a ValueError raised while evaluating fn.

        %-formatting is used because coords and parameterValues are not
        strings: joining them with '+' would raise a TypeError and hide the
        ValueError being reported.
        '''
        return 'Problem with function "%s"%s with params %s' % (self.name, where, self.parameterValues)

    def val(self, coords):
        '''Evaluate function at a single set of coordinates.

        Returns element 0 of fn's result (via getElementDoubleAbs).
        Raises ValueError if fn raises ValueError.
        '''
        try:
            v = self.func(self.parameterValues, _dnp.array(coords), *self.args)
            return v.getElementDoubleAbs(0)
        except ValueError:
            raise ValueError(self._problem(' at coord %s' % (coords,)))

    @_npwrapped
    def makeDataset(self, coords):
        '''Evaluate function across given coordinates.

        Returns fn's result with its name set to this function's name.
        Raises ValueError if fn raises ValueError.
        '''
        try:
            d = self.func(self.parameterValues, coords, *self.args)
            d.name = self.name
            return d
        except ValueError:
            raise ValueError(self._problem())

    def residual(self, allvalues, data, coords):
        '''Find residual as sum of squared differences of function and data

        Arguments:
        allvalues -- boolean, currently ignored
        data      -- used to subtract from evaluated function
        coords    -- coordinates over which the function is evaluated
        Raises ValueError if fn raises ValueError.
        '''
        try:
            vals = self.func(self.parameterValues, coords, *self.args)
            return _dnp.residual(vals, data)
        except ValueError:
            raise ValueError(self._problem())

class cfitfunc(_compfn):
    '''Composite fitting function whose components may be a mixture of
    Jython (fitfunc) and Java functions; evaluation sums the components.
    '''
    def __init__(self):
        _compfn.__init__(self) #@UndefinedVariable

    def val(self, coords):
        '''Return the sum of every component's val() at a single set of coordinates.
        '''
        return sum((self.getFunction(idx).val(coords)
                    for idx in range(self.noOfFunctions)), 0.)

    @_npwrapped
    def makeDataset(self, coords):
        '''Return the sum of every component's dataset over the given coordinates,
        or None when there are no components.
        '''
        count = self.noOfFunctions
        if count < 1:
            return None
        total = self.getFunction(0).makeDataset(coords)
        for idx in range(1, count):
            # NOTE(review): accumulates in place into the dataset returned by
            # the first component
            total += self.getFunction(idx).makeDataset(coords)
        return total

    def residual(self, allvalues, data, coords):
        '''Return the residual (sum of squared differences) between the
        summed components and the data.

        Arguments:
        allvalues -- boolean, currently ignored
        data      -- used to subtract from evaluated function
        coords    -- coordinates over which the function is evaluated
        '''
        evaluated = self.makeDataset(coords)
        return _dnp.residual(evaluated, data)


class fitresult(object):
    '''Container for the results of a fit.

    Holds the fitted function together with the coordinates and data it was
    fitted to, and offers helpers to inspect, evaluate and plot the fit.
    '''
    def __init__(self, func, coords, data):
        '''Arguments:
        func   -- function after fitting has occurred
        coords -- coordinate(s)
        data   -- scalar dataset that was fitted to
        '''
        self.func = func
        self.coords = coords
        self.data = data

    def _calcdelta(self, coords):
        '''Return the product of the average spacing of each coordinate.

        For coordinates of rank > 1, the range (ptp) of the n-th coordinate is
        divided by shape[n]. This presumes the n-th coordinate varies along
        axis n, as in a meshgrid (confirm against callers). For 1D coordinates,
        each range is divided by the number of points.

        Raises ValueError if rank > 1 coordinates do not all share the first
        one's rank, or if rank <= 1 coordinates are not all 1D.
        '''
        delta = 1.
        rank = coords[0].rank
        if rank > 1:
            for n, x in enumerate(coords):
                if x.rank != rank:
                    raise ValueError("Given coordinates are not all of same rank")
                # float() avoids Python 2 integer division, which truncates
                # the spacing (often to 0) for integer-typed coordinates
                delta *= float(x.ptp()) / x.shape[n]
        else:
            for x in coords:
                if x.rank != 1:
                    raise ValueError("Given coordinates are not all 1D")
                delta *= float(x.ptp()) / x.size
        return delta

    def __getitem__(self, key):
        '''Return self.func.getParameter(key).

        Any failure is re-raised as IndexError, which also lets iteration
        over a fitresult stop after the last parameter.
        '''
        try:
            return self.func.getParameter(key)
        except:
            # bare except kept on purpose: the lookup may fail with a Java
            # exception, which may not derive from Python's Exception in
            # Jython (NOTE(review): confirm before narrowing)
            raise IndexError(key)

    def __len__(self):
        '''Number of parameters
        '''
        return self.func.getNoOfParameters()

    def makeplotdata(self):
        '''Make a list of datasets to plot.

        Order: fitted data, fit (composite, or the single function), error
        values (data - fit, shifted by an offset below the data), a constant
        line at that offset, then the individual components when there is
        more than one function.
        '''
        pdata = self.makefuncdata()
        pdata.insert(0, self.data)
        # place the error curve a fifth of the data range below the data minimum
        offset = self.data.min() - ((self.data.max() - self.data.min()) / 5.0)
        edata = self.data - pdata[1] + offset
        edata.name = "Error value"
        odata = _dnp.zeros_like(edata)
        odata.fill(offset)
        odata.name = "Error offset"
        pdata.insert(2, odata)
        pdata.insert(2, edata)
        return pdata

    def makefuncdata(self):
        '''Make a list of datasets for composite fitting function and its components.

        With several functions the list is [composite, component 0, ...];
        with one function it holds just that function's dataset; otherwise
        it is empty.
        '''
        nf = self.func.noOfFunctions
        if nf > 1:
            fdata = [ self.func.makeDataset(self.coords) ]
            fdata[0].name = "Composite function"
            for n in range(nf):
                fdata.append(self.func.getFunction(n).makeDataset(self.coords))
        elif nf == 1:
            fdata = [ self.func.getFunction(0).makeDataset(self.coords) ]
        else:
            fdata = []

        return fdata

    def plot(self, name=None):
        '''Plot fit as 1D against the first coordinate
        '''
        _dnp.plot.line(self.coords[0], self.makeplotdata(), name)

    def _parameters(self):
        '''List of all parameters values
        '''
        return list(self.func.getParameterValues())
    parameters = property(_parameters)

    def _residual(self):
        '''Residual of fit
        '''
        return self.func.residual(True, self.data, self.coords)
    residual = property(_residual)

    def _area(self):
        '''Area or hypervolume under fit assuming coordinates are uniformly spaced
        '''
        deltax = self._calcdelta(self.coords)
        return self.func.makeDataset(self.coords).sum() * deltax
    area = property(_area)

    def __str__(self):
        out = "Fit parameters:\n"
        for n in range(self.func.noOfFunctions):
            f = self.func.getFunction(n)
            p = list(f.getParameterValues())
            out += "    function '%s' (%d) has %d parameters = %s\n" % (f.name, n, len(p), p)
        return out


def fit(func, coords, data, p0, bounds=None, args=None, ptol=1e-4):
    '''Fit function(s) to data with the genetic fitter and return the result.

    Arguments:
    func   -- function(s): a pre-defined (Java) function class, or a
              (jython function, parameter count) tuple, or a list of these
    coords -- coordinate dataset(s)
    data   -- data to fit
    p0     -- initial parameter value, or list/tuple of them; consumed in
              order across all functions
    bounds -- a bound or list of bounds; each bound is a tuple of lower and
              upper values (any can be None) or None, consumed in order across
              parameters like p0
    args   -- extra argument passed to every jython function as one
              additional positional argument (it is passed even when None)
    ptol   -- parameter fit tolerance
    Returns:
    fitresult object

    The caller's p0 and bounds lists are not modified.
    Raises ValueError if a non-Java function is not given as a tuple.
    '''
    import time

    if not isinstance(func, list):
        func = [func]
    # work on copies: _createparams pops items off these lists and must not
    # consume the caller's own lists
    if isinstance(p0, (list, tuple)):
        p0 = list(p0)
    else:
        p0 = [p0]
    if bounds is None:
        bounds = []
    elif isinstance(bounds, list):
        bounds = list(bounds)
    else:
        bounds = [bounds]

    fnlist = []
    mixed = False
    for f in _toList(func):
        if isinstance(f, _jclass):
            # pre-defined Java function class: instantiate it with its parameters
            np = _fn.nparams(f)
            pl = _createparams(np, p0, bounds)
            fnlist.append(f(pl))
        else:
            if not isinstance(f, tuple):
                raise ValueError('jython function must be paired with a parameter count as a tuple')
            fo = f[0]
            np = f[1]
            pl = _createparams(np, p0, bounds)
            fnlist.append(fitfunc(fo, fo.__name__, pl, args))
            mixed = True

    # any jython function requires the Python-side composite
    if mixed:
        cfunc = cfitfunc()
    else:
        cfunc = _compfn()
    for f in fnlist:
        cfunc.addFunction(f)

    coords = _asIterable(coords)

    start = time.time()
    _genfit(ptol, coords, data, cfunc)
    elapsed = time.time() - start
    print("Fit took %fs" % elapsed)

    return fitresult(cfunc, coords, data)

# genfit = _genfit

#def polyfit(x, y, deg, rcond=None, full=False):
#    from gda.analysis.functions import Polynomial
#    poly = Polynomial(deg)
#    x = _asIterable(x)
#    _polyfit(x, y, poly)
#    return fitresult(poly, x, y)

# TODO: add a cubic spline (cspline) fit function
