'''
Wrappers around the GDA plotting functionality (RCP plotter, plot-view GUI beans and NeXus tree viewer)
'''

from uk.ac.diamond.scisoft.analysis.plotserver import GuiParameters as _guiparam #@UnresolvedImport
from gda.analysis import RCPPlotter as _plotter #@UnresolvedImport
from gda.analysis import RCPNexusTreeViewer as _nxsviewer #@UnresolvedImport

import scisoftpy as _np
import scisoftpy.roi as roi

# Default plot view name used when a plotting/bean function is called without
# a view name; change it with setdefname().
_PVNAME = "Plot 1"

# Short aliases for scisoftpy conversion helpers applied to datasets before
# they are handed to the Java plotter: _toList is used for "dataset or list of
# datasets" arguments, _unwrap for single datasets (presumably converting the
# Python wrappers to the underlying Java objects -- confirm in scisoftpy).
_toList = _np.toUnwrappedList
_unwrap = _np.Sciunwrap

class _parameters:
    '''Short attribute names for the GuiParameters keys used to index a GUI
    bean, e.g. bean[parameters.roi] reads the region of interest.'''
    plotmode = _guiparam.PLOTMODE
    title = _guiparam.TITLE
    roi = _guiparam.ROIDATA
    roilist = _guiparam.ROIDATALIST
    plotid = _guiparam.PLOTID
    plotop = _guiparam.PLOTOPERATION
    fileop = _guiparam.FILEOPERATION
    filename = _guiparam.FILENAME
    fileselect = _guiparam.FILESELECTEDLIST
    dispview = _guiparam.DISPLAYFILEONVIEW

# Module-level instance through which the bean keys are accessed
parameters = _parameters()

def setdefname(name):
    '''Change the default plot view name.

    Functions such as line(), image(), stack(), points() and getbean() use
    this name whenever they are called without an explicit view name.
    The initial default is "Plot 1".

    Arguments:
    name -- new default plot view name
    '''
    global _PVNAME
    _PVNAME = name

def line(x, y=None, name=None):
    '''Make a line plot in the named plot view.

    Called as line(y) the single dataset is sent to the plotter on its own;
    called as line(x, y) each dataset in y is plotted against x.

    Arguments:
    x -- dataset for x-axis, or the dataset to plot when y is omitted
    y -- dataset or list of datasets (optional)
    name -- name of plot view to use (if None, use default name)
    '''
    view = name or _PVNAME
    args = [_unwrap(x)]
    if y is not None:
        args.append(_toList(y))
    _plotter.plot(view, *args)

def updateline(x, y=None, name=None):
    '''Update an existing line plot in the named view with new data.

    Called as updateline(y) the single dataset is sent to the plotter on its
    own; called as updateline(x, y) each dataset in y is plotted against x.

    Arguments:
    x -- dataset for x-axis, or the dataset to plot when y is omitted
    y -- dataset or list of datasets (optional)
    name -- name of plot view to use (if None, use default name)
    '''
    if not name:
        name = _PVNAME

    # Identity test, as in line(): `y == None` on a dataset may be evaluated
    # element-wise rather than as "was y given?".
    if y is None:
        _plotter.updatePlot(name, _unwrap(x))
    else:
        _plotter.updatePlot(name, _unwrap(x), _toList(y))

# Alternative names: plot() is line() and updateplot() is updateline()
plot = line
updateplot = updateline

def image(im, x=None, y=None, name=None):
    '''Show a 2D dataset as an image in the named plot view.

    The axis datasets are only used when both x and y are supplied; if either
    is missing the image is plotted without axes.

    Arguments:
    im -- image dataset
    x -- optional dataset for x-axis
    y -- optional dataset for y-axis
    name -- name of plot view to use (if None, use default name)
    '''
    view = name or _PVNAME
    has_axes = x is not None and y is not None
    if not has_axes:
        _plotter.imagePlot(view, _unwrap(im))
        return
    _plotter.imagePlot(view, _unwrap(x), _unwrap(y), _unwrap(im))

def surface(s, x=None, y=None, name=None):
    '''Show a 2D dataset as a surface in the named plot view.

    The axis datasets are only used when both x and y are supplied; if either
    is missing the surface is plotted without axes.

    Arguments:
    s -- surface (height field) dataset
    x -- optional dataset for x-axis
    y -- optional dataset for y-axis
    name -- name of plot view to use (if None, use default name)
    '''
    view = name or _PVNAME
    has_axes = x is not None and y is not None
    if not has_axes:
        _plotter.surfacePlot(view, _unwrap(s))
        return
    _plotter.surfacePlot(view, _unwrap(x), _unwrap(y), _unwrap(s))

def stack(x, y=None, z=None, name=None):
    '''Plot all of the given 1D y datasets against corresponding x as a 3D stack
    with optional z coordinates in the named view

    When y is omitted, x is taken as the dataset (or list of datasets) to
    stack, and a single x-axis _np.arange(N) is generated, where N is the
    largest size among those datasets.

    Arguments:
    x -- dataset or list of datasets for x-axis, or the datasets to stack
         when y is omitted
    y -- dataset or list of datasets (optional)
    z -- optional dataset for z-axis
    name -- name of plot view to use (if None, use default name)
    '''
    if not name:
        name = _PVNAME

    # Explicit None test, consistent with line()/image(): truth-testing a
    # dataset (or an empty list) is not a reliable "was y given?" check.
    if y is None:
        y = _toList(x)
        # [0] keeps the original behaviour for an empty list (axis of length 0)
        largest = max([0] + [d.size for d in y])
        x = [_np.arange(largest)]

    if z is None:
        _plotter.stackPlot(name, _toList(x), _toList(y))
    else:
        _plotter.stackPlot(name, _toList(x), _toList(y), _unwrap(z))

def points(x, y, z=None, size=0, name=None):
    '''Make a scatter plot of points in the named plot view.

    A 3D scatter plot is made when z is given, otherwise a 2D one from x and y.

    Arguments:
    x -- dataset of x coords
    y -- dataset of y coords
    z -- optional dataset of z coords
    size -- integer size or dataset of sizes
    name -- name of plot view to use (if None, use default name)
    '''
    view = name or _PVNAME
    if z is not None:
        _plotter.scatter3DPlot(view, _unwrap(x), _unwrap(y), _unwrap(z), _unwrap(size))
    else:
        _plotter.scatter2DPlot(view, _unwrap(x), _unwrap(y), _unwrap(size))

def addpoints(x, y, z=None, size=0, name=None):
    '''Add points to an existing scatter plot in the named plot view.

    The points are added in 3D when z is given, otherwise in 2D from x and y.

    Arguments:
    x -- dataset of x coords
    y -- dataset of y coords
    z -- optional dataset of z coords
    size -- integer size or dataset of sizes
    name -- name of plot view to use (if None, use default name)
    '''
    view = name or _PVNAME
    if z is not None:
        _plotter.scatter3DPlotOver(view, _unwrap(x), _unwrap(y), _unwrap(z), _unwrap(size))
    else:
        _plotter.scatter2DPlotOver(view, _unwrap(x), _unwrap(y), _unwrap(size))

_IMAGEEXPNAME = "ImageExplorer View"

# Maps the order names accepted by scanforimages() to the plotter's
# image-ordering constants
__orders = { "none": _plotter.IMAGEORDERNONE, "alpha": _plotter.IMAGEORDERALPHANUMERICAL, "chrono": _plotter.IMAGEORDERCHRONOLOGICAL}
def _order(order):
    '''Translate an order name ("none", "alpha" or "chrono") into the
    corresponding plotter constant.

    Raises ValueError for any other name.
    '''
    try:
        return __orders[order]
    except KeyError:
        # Call form of raise: valid on Python 2 and 3, unlike "raise E, msg"
        raise ValueError("Given order not one of none, alpha, chrono")

def scanforimages(path, order="none", suffices=None, columns=-1, rowMajor=True, name=_IMAGEEXPNAME):
    '''Scan path for images and load them into the given image explorer view

    Arguments:
    path -- location to scan (presumably a directory)
    order -- image ordering: "none", "alpha" (alphanumerical) or "chrono"
             (chronological); anything else raises ValueError
    suffices -- passed through to the plotter (presumably file suffixes to match)
    columns -- passed through to the plotter (default -1)
    rowMajor -- passed through to the plotter (default True)
    name -- name of image explorer view to use
    '''
    image_order = _order(order)
    _plotter.scanForImages(name, path, image_order, suffices, columns, rowMajor)

_REMOTEVOLNAME = "Remote Volume Viewer"

def volume(v, name=_REMOTEVOLNAME):
    '''Plot a volume dataset in remote volume view

    The dataset is saved in binary format to a temporary ".dsr" file, whose
    path is passed to the plotter. The file is removed afterwards, even if
    saving or plotting fails.

    Arguments:
    v -- volume dataset
    name -- name of volume view to use
    '''
    import tempfile
    import os
    # convert to byte, int or float as volume viewer cannot cope with boolean, long or double datasets
    # NOTE(review): complex datasets are also cast to float32 -- confirm how
    # scisoftpy's cast treats the imaginary part
    if v.dtype == _np.bool:
        v = _np.cast(v, _np.int8)
    elif v.dtype == _np.int64:
        v = _np.cast(v, _np.int32)
    elif v.dtype == _np.float64 or v.dtype == _np.complex64 or v.dtype == _np.complex128:
        v = _np.cast(v, _np.float32)

    # Create the temp file only after the cast, so a failed cast leaves nothing behind
    fd, vdatafile = tempfile.mkstemp('.dsr')
    os.close(fd)
    try:
        _np.io.save(vdatafile, v, format='binary')
        _plotter.volumePlot(name, vdatafile)
    finally:
        os.remove(vdatafile)

from uk.ac.diamond.scisoft.analysis.plotserver import GuiBean as _guibean #@UnresolvedImport

# Public alias for the GuiBean class imported above (presumably called as
# bean() to create a new GUI bean)
bean = _guibean

def getbean(name=None):
    '''Fetch the GUI bean (information from the view) of the named plot view.

    Arguments:
    name -- name of plot view to use (if None, use default name)
    '''
    return _plotter.getGuiBean(name or _PVNAME)

def setbean(bean, name=None):
    '''Send a GUI bean to the named plot view.

    Arguments:
    bean -- GUI bean to send; nothing is sent when it is None
    name -- name of plot view to use (if None, use default name)
    '''
    # Previously `if not bean:`, which sent only missing/empty beans and
    # silently dropped every real one.
    if bean is not None:
        if not name:
            name = _PVNAME
        _plotter.setGuiBean(name, bean)

def getroi(bean):
    '''Return the region of interest stored in bean, or None when bean is None.'''
    return None if bean is None else bean[parameters.roi]

def setroi(bean, roi):
    '''Set region of interest in bean

    Arguments:
    bean -- GUI bean, modified in place (ignored if None)
    roi -- region of interest to store
    '''
    # Compare with None rather than truth-testing: a freshly created bean has
    # no entries and may be falsy, yet must still accept an ROI.
    if bean is not None:
        bean[parameters.roi] = roi

def delroi(bean):
    '''Remove the region of interest from bean if it holds one.

    Returns the bean itself, or None when bean is None.
    '''
    if bean is not None and parameters.roi in bean:
        bean.remove(parameters.roi)
    return bean

def getrois(bean):
    '''Return the list of regions of interest stored in bean, or None when bean is None.'''
    return None if bean is None else bean[parameters.roilist]

def setrois(bean, roilist):
    '''Set list of regions of interest in bean

    Arguments:
    bean -- GUI bean, modified in place (ignored if None)
    roilist -- list of regions of interest to store
    '''
    # Compare with None rather than truth-testing: a freshly created bean has
    # no entries and may be falsy, yet must still accept an ROI list.
    if bean is not None:
        bean[parameters.roilist] = roilist

def delrois(bean):
    '''Remove the list of regions of interest from bean if it holds one.

    Returns the bean itself, or None when bean is None.
    '''
    if bean is not None and parameters.roilist in bean:
        bean.remove(parameters.roilist)
    return bean

def getline(bean):
    '''Return the region of interest in bean if it is a linear ROI, else None.'''
    candidate = getroi(bean)
    return candidate if isinstance(candidate, roi.line) else None

def getlines(bean):
    '''Get list of linear regions of interest

    A single ROI stored where a list is expected is treated as a one-element
    list. Returns None when bean, or its ROI list, is None.
    '''
    rs = getrois(bean)
    if rs is None:
        return None
    try:
        iter(rs)
    except TypeError:
        # Not iterable: a lone ROI rather than a list of them
        rs = [rs]
    return [r for r in rs if isinstance(r, roi.line)]

def getrect(bean):
    '''Return the region of interest in bean if it is a rectangular ROI, else None.'''
    candidate = getroi(bean)
    return candidate if isinstance(candidate, roi.rect) else None

def getrects(bean):
    '''Get list of rectangular regions of interest

    A single ROI stored where a list is expected is treated as a one-element
    list. Returns None when bean, or its ROI list, is None.
    '''
    rs = getrois(bean)
    if rs is None:
        return None
    try:
        iter(rs)
    except TypeError:
        # Not iterable: a lone ROI rather than a list of them
        rs = [rs]
    return [r for r in rs if isinstance(r, roi.rect)]

def getsect(bean):
    '''Return the region of interest in bean if it is a sector ROI, else None.'''
    candidate = getroi(bean)
    return candidate if isinstance(candidate, roi.sect) else None

def getsects(bean):
    '''Get list of sector regions of interest

    A single ROI stored where a list is expected is treated as a one-element
    list. Returns None when bean, or its ROI list, is None.
    '''
    rs = getrois(bean)
    if rs is None:
        return None
    try:
        iter(rs)
    except TypeError:
        # Not iterable: a lone ROI rather than a list of them
        rs = [rs]
    return [r for r in rs if isinstance(r, roi.sect)]

def getfiles(bean):
    '''Get sorted list of selected files from bean

    Prints a message and returns None when bean is None or holds no file
    selection.
    '''
    fn = None
    if bean is not None:
        try:
            fn = bean[parameters.fileselect]
        except KeyError:
            fn = None
    if fn is None:
        # Function-call form prints the same text on Python 2 and 3
        print("No selection has been made and sent to server")
        return None
    return sorted(fn)

_NTVNAME = "nexusTreeViewer"

def viewnexus(tree, name=_NTVNAME):
    '''Display a NeXus tree in the named NeXus tree viewer.

    Arguments:
    tree -- NeXus tree to display
    name -- name of the viewer to use (defaults to "nexusTreeViewer")
    '''
    _nxsviewer.viewNexusTree(name, tree)
