#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#    Project: Azimuthal integration
#             https://github.com/kif/pyFAI
#
#    File: "$Id$"
#
#    Copyright (C) European Synchrotron Radiation Facility, Grenoble, France
#
#    Principal author:       Jérôme Kieffer (Jerome.Kieffer@ESRF.eu)
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

"""

MX-calibrate is a tool to calibrate the distance of a detector from a set of powder diffraction patterns

Idea:

MX-calibrate -e 12 --spacing dSpacing.D file1.edf file2.edf file3.edf

Calibrate by hand the most distant frame, then subsequently calibrate all the other frames.

"""

# Module metadata (the status attribute was previously misspelled "__satus__")
__author__ = "Jerome Kieffer"
__contact__ = "Jerome.Kieffer@ESRF.eu"
__license__ = "GPLv3+"
__copyright__ = "European Synchrotron Radiation Facility, Grenoble, France"
__date__ = "04/04/2012"
__status__ = "development"

import os, types
import pyFAI, pyFAI.calibration
import numpy
import pylab
import fabio
import logging
from optparse import OptionParser
logger = logging.getLogger("MX-calibrate")
from pyFAI.units import hc
from pyFAI.detectors import Detector, detector_factory
# Optional remote debugging console: when the rfoo package can be imported,
# a console server is spawned as soon as this module is loaded; when it is
# missing, the failure is only reported through the logging module.
try:
    from rfoo.utils import rconsole
    rconsole.spawn_server()
except ImportError:
    logging.info("No socket opened for debugging. Please install rfoo")
from scipy.stats import linregress

class MultiCalib(object):
    """Calibrate a set of frames recorded at several sample-detector distances.

    Every frame is recalibrated on its own with
    ``pyFAI.calibration.Recalibration`` (see :meth:`process`); each refined
    geometry parameter is then fitted linearly against the nominal distance
    (see :meth:`regression`).
    """

    def __init__(self, dataFiles=None, darkFiles=None, flatFiles=None, pixelSize=None, splineFile=None, detector=None):
        """
        :param dataFiles: list of calibration images (default: empty list)
        :param darkFiles: list of dark images (default: empty list)
        :param flatFiles: list of flat images (default: empty list)
        :param pixelSize: either a single number or a sequence of at least two
            numbers; it is stored unscaled on the detector as (pixel1, pixel2)
            -- NOTE(review): no micron->meter conversion is applied here,
            unlike in get_pixelSize; confirm the expected unit with callers
        :param splineFile: name of a spline file; ignored if it does not exist
        :param detector: a detector name (string), a Detector instance or
            None (a default Detector is then created)
        """
        self.dataFiles = dataFiles or []
        self.darkFiles = darkFiles or []
        self.flatFiles = flatFiles or []
        self.data = {}

        # types.StringTypes only exists on Python 2: fall back on str elsewhere
        string_types = getattr(types, "StringTypes", (str,))
        if isinstance(detector, string_types):
            self.detector = detector_factory(detector)
        elif isinstance(detector, Detector):
            self.detector = detector
        else:
            self.detector = Detector()

        if splineFile and os.path.isfile(splineFile):
            self.detector.splineFile = os.path.abspath(splineFile)
        if pixelSize:
            # A string also has a length, but it must be parsed as one number
            is_pair = (hasattr(pixelSize, "__len__")
                       and not isinstance(pixelSize, string_types)
                       and len(pixelSize) >= 2)
            if is_pair:
                self.detector.pixel1 = float(pixelSize[0])
                self.detector.pixel2 = float(pixelSize[1])
            else:
                self.detector.pixel1 = self.detector.pixel2 = float(pixelSize)
        self.cutBackground = None
        self.outfile = "merged.edf"
        self.peakPicker = None
        self.basename = None
        self.geoRef = None
        self.mask = None
        self.max_iter = 1000
        self.filter = "mean"
        self.saturation = 0.1
        self.spacing_file = None
        self.wavelength = None  # in meter (command line values are scaled by 1e-10)
        self.weighted = False
        self.polarization_factor = 0
        self.results = {}  # filename -> {"wavelength", "dist", "ai"}, filled by process()
        self.gui = True
        self.interactive = True
        self.centerX = None
        self.centerY = None
        self.distance = None  # in meter, like the value handed to rec.ai.dist
        self.fixed = []  # names of the parameters excluded from the refinement
        self.max_rings = None

    def __repr__(self):
        """Return a multi-line summary: file lists followed by the detector."""
        lst = ["Multi-Calibration object:",
               "data= " + ", ".join(self.dataFiles),
               "dark= " + ", ".join(self.darkFiles),
               "flat= " + ", ".join(self.flatFiles),
               repr(self.detector)]
        return os.linesep.join(lst)

    @staticmethod
    def _prompt(message):
        """Ask a question on the terminal; works on Python 2 and Python 3."""
        try:
            reader = raw_input  # Python 2: input() would eval() the answer
        except NameError:
            reader = input  # Python 3
        return reader(message)

    @staticmethod
    def _distance_from_filename(filename):
        """Return the distance (in meter) encoded in the name of a file.

        The first run of consecutive digits found in the basename is read as
        a distance in millimeter.

        :param filename: path of the image
        :return: distance in meter
        :raise ValueError: when the basename does not contain any digit
        """
        digits = ""
        for char in os.path.basename(filename):
            if char.isdigit():
                digits += char
            elif digits:
                # first non-digit after the first run of digits
                break
        if not digits:
            raise ValueError("No distance provided and none can be guessed from the filename %s" % filename)
        return int(digits) * 0.001

    def parse(self, args=None):
        """
        Parse options from command line and store them on the instance.

        :param args: list of arguments to parse; None (default) means that
            the command line of the program (sys.argv[1:]) is used
        :raise RuntimeError: when none of the positional arguments is an
            existing file
        """
        usage = "usage: %prog file1.edf file2.edf ..."
        version = "%prog " + pyFAI.version
        description = """
        Calibrate automatically a set of frames taken at various sample-detector distance.
        """
        epilog = """This tool has been developed for ESRF MX-beamlines where an acceptable calibration is
        usually present is the header of the image. PyFAI reads it and does a "recalib" on
        each of them before exporting a linear regression of all parameters versus this distance.
        """
        parser = OptionParser(usage=usage, version=version,
                              description=description, epilog=epilog)

        parser.add_option("-v", "--verbose",
                          action="store_true", dest="debug", default=False,
                          help="switch to debug/verbose mode")
        parser.add_option("-S", "--spacing", dest="spacing", metavar="FILE",
                          help="file containing d-spacing of the reference sample (MANDATORY)", default=None)
        parser.add_option("-w", "--wavelength", dest="wavelength", type="float",
                          help="wavelength of the X-Ray beam in Angstrom", default=None)
        parser.add_option("-e", "--energy", dest="energy", type="float",
                          help="energy of the X-Ray beam in keV (hc=%skeV.A)" % hc, default=None)
        parser.add_option("-P", "--polarization", dest="polarization_factor",
                          type="float", default=0.0,
                          help="polarization factor, from -1 (vertical) to +1 (horizontal), default is 0, synchrotrons are around 0.95")
        parser.add_option("-b", "--background", dest="background",
                          help="Automatic background subtraction if no value are provided", default=None)
        parser.add_option("-d", "--dark", dest="dark",
                          help="list of dark images to average and subtract", default=None)
        parser.add_option("-f", "--flat", dest="flat",
                          help="list of flat images to average and divide", default=None)
        parser.add_option("-s", "--spline", dest="spline",
                          help="spline file describing the detector distortion", default=None)
        parser.add_option("-p", "--pixel", dest="pixel",
                          help="size of the pixel in micron", default=None)
        parser.add_option("-D", "--detector", dest="detector_name",
                          help="Detector name (instead of pixel size+spline)", default=None)
        parser.add_option("-m", "--mask", dest="mask",
                          help="file containing the mask (for image reconstruction)", default=None)
        parser.add_option("--filter", dest="filter",
                          help="select the filter, either mean(default), max or median",
                          default="mean")
        parser.add_option("--saturation", dest="saturation",
                          help="consider all pixel>max*(1-saturation) as saturated and reconstruct them",
                          default=0.1, type="float")
        parser.add_option("-r", "--ring", dest="max_rings", type="float",
                          help="maximum number of rings to extract", default=None)

        parser.add_option("--weighted", dest="weighted",
                          help="weight fit by intensity",
                          default=False, action="store_true")
        parser.add_option("-l", "--distance", dest="distance", type="float",
                          help="sample-detector distance in millimeter", default=None)
        # The help used to read "refine the detector tilt", i.e. the opposite
        # of what this store_false flag does.
        parser.add_option("--no-tilt", dest="tilt",
                          help="do not refine the detector tilt (fix rot1, rot2 and rot3)",
                          default=True, action="store_false")
        # NOTE(review): --poni1/--poni2/--rot1/--rot2/--rot3 are accepted but
        # their values are not used anywhere in this class.
        for name in ("poni1", "poni2"):
            parser.add_option("--" + name, dest=name, type="float",
                              help="%s coordinate in meter" % name, default=None)
        for name in ("rot1", "rot2", "rot3"):
            parser.add_option("--" + name, dest=name, type="float",
                              help="%s in radians" % name, default=None)

        # One --fix-XXX/--free-XXX pair per refinable parameter; only the
        # wavelength is fixed by default.
        refinable = (("dist", "distance", None),
                     ("poni1", "poni1", None),
                     ("poni2", "poni2", None),
                     ("rot1", "rot1", None),
                     ("rot2", "rot2", None),
                     ("rot3", "rot3", None),
                     ("wavelength", "wavelength", True))
        for name, label, default in refinable:
            parser.add_option("--fix-" + name, dest="fix_" + name,
                              help="fix the %s parameter" % label,
                              default=default, action="store_true")
            parser.add_option("--free-" + name, dest="fix_" + name,
                              help="free the %s parameter" % label,
                              default=default, action="store_false")

        parser.add_option("--no-gui", dest="gui",
                          help="force the program to run without a Graphical interface",
                          default=True, action="store_false")
        parser.add_option("--gui", dest="gui",
                          help="force the program to run with a Graphical interface",
                          default=True, action="store_true")

        parser.add_option("--no-interactive", dest="interactive",
                          help="force the program to run and exit without prompting for refinements",
                          default=True, action="store_false")
        parser.add_option("--interactive", dest="interactive",
                          help="force the program to prompt for refinements",
                          default=True, action="store_true")

        (options, args) = parser.parse_args(args)

        # Analyse arguments and options
        if options.debug:
            logger.setLevel(logging.DEBUG)
        if options.background is not None:
            try:
                self.cutBackground = float(options.background)
            except ValueError:
                # not a number: only switch the background subtraction on
                self.cutBackground = True
        if options.dark:
            self.darkFiles = [f for f in options.dark.split(",") if os.path.isfile(f)]
        if options.flat:
            self.flatFiles = [f for f in options.flat.split(",") if os.path.isfile(f)]
        if options.mask and os.path.isfile(options.mask):
            self.mask = fabio.open(options.mask).data

        if options.detector_name:
            self.detector = detector_factory(options.detector_name)
        if options.spline:
            if os.path.isfile(options.spline):
                self.detector.splineFile = os.path.abspath(options.spline)
            else:
                logger.error("Unknown spline file %s", options.spline)
        if options.pixel is not None:
            self.get_pixelSize(options.pixel)
        self.filter = options.filter
        self.saturation = options.saturation
        # Wavelength is given in Angstrom (or derived from the energy in keV)
        # and stored in meter.
        if options.wavelength:
            self.wavelength = 1e-10 * options.wavelength
        elif options.energy:
            self.wavelength = 1e-10 * hc / options.energy
        if options.distance:
            # Given in millimeter, stored in meter. This option used to be
            # parsed and then silently dropped.
            self.distance = 1e-3 * options.distance
        self.spacing_file = options.spacing
        self.polarization_factor = options.polarization_factor
        self.gui = options.gui
        self.interactive = options.interactive
        self.max_rings = options.max_rings
        self.fixed = []
        if not options.tilt:
            self.fixed += ["rot1", "rot2", "rot3"]
        for name, _label, _default in refinable:
            # the membership test avoids duplicates when --no-tilt is combined
            # with --fix-rotN
            if getattr(options, "fix_" + name) and name not in self.fixed:
                self.fixed.append(name)

        self.dataFiles = [f for f in args if os.path.isfile(f)]
        if not self.dataFiles:
            raise RuntimeError("Please provide some calibration images ... if you want to analyze them. Try also the --help option to see all options!")
        self.weighted = options.weighted

    def get_pixelSize(self, ans):
        """Convert a comma separated string into pixel size.

        The string holds one value (square pixel) or "X,Y" values in micron;
        they are stored in meter as detector.pixel2 (X) and detector.pixel1
        (Y). When a value cannot be read as a number, an error is logged and
        the detector is left untouched.

        :param ans: string like "100" or "100,200"
        """
        sp = ans.split(",")
        # str.split always returns at least one field: either "X,Y[,...]" or
        # a single value used for both directions
        fields = sp[:2] if len(sp) >= 2 else [sp[0], sp[0]]
        try:
            pixelSizeXY = [float(i) * 1e-6 for i in fields]
        except ValueError:
            logger.error("error in reading pixel size_%d", min(len(sp), 2))
            return
        self.detector.pixel1 = pixelSizeXY[1]
        self.detector.pixel2 = pixelSizeXY[0]

    def read_pixelsSize(self):
        """Read the pixel size from prompt if not available.

        The answer is either the name of an existing spline file or a pixel
        size understood by get_pixelSize.
        """
        if (self.detector.pixel1 is None) and (self.detector.splineFile is None):
            pixelSize = [15, 15]
            ans = self._prompt("Please enter the pixel size (in micron, comma separated X, Y , i.e. %.2e,%.2e) or a spline file: " % tuple(pixelSize)).strip()
            if os.path.isfile(ans):
                # absolute path, as everywhere else a spline file is set
                self.detector.splineFile = os.path.abspath(ans)
            else:
                self.get_pixelSize(ans)

    def read_dSpacingFile(self):
        """Read the name of the file with d-spacing from prompt if not yet known.

        The question is repeated until the answer is the name of an existing file.
        """
        if (self.spacing_file is None):
            # One entry per printed line (commas were missing between some of
            # them, which merged the lines and dropped the empty one).
            comments = ["pyFAI calib has changed !!!",
                        "Instead of entering the 2theta value, which was tedious,",
                        "the program takes a d-spacing file in input (just a serie of number representing the inter planar distance in Angstrom)",
                        "and an associated wavelength",
                        "",
                        "You will be asked to enter the ring number, which is usually a simpler than the 2theta value."]
            print(os.linesep.join(comments))
            ans = ""
            while not os.path.isfile(ans):
                ans = self._prompt("Please enter the name of the file containing the d-spacing:\t").strip()
            self.spacing_file = ans

    def read_wavelength(self):
        """Read the wavelength (in Angstrom) from prompt until a non-null number is given.

        The value is stored in meter in self.wavelength.
        """
        while not self.wavelength:
            ans = self._prompt("Please enter wavelength in Angstrom:\t").strip()
            try:
                self.wavelength = 1e-10 * float(ans)
            except ValueError:
                # Only invalid numbers are retried: a bare "except" would also
                # swallow KeyboardInterrupt and make the loop impossible to quit
                self.wavelength = None

    def process(self):
        """
        Recalibrate every data file and store the refined geometries.

        For each file (in sorted order) the starting distance, wavelength and
        beam center are taken from the instance attributes, overridden by the
        "_array_data.header_contents" header entry of the image when present.
        A missing distance is guessed from the filename and a missing beam
        center defaults to the middle of the image. The refined geometry is
        stored in self.results[filename]["ai"].

        :raise ValueError: when the distance is unknown and the filename does
            not contain any digit
        """
        self.dataFiles.sort()
        for fn in self.dataFiles:
            fabimg = fabio.open(fn)
            wavelength = self.wavelength
            dist = self.distance
            centerX = self.centerX
            centerY = self.centerY
            if "_array_data.header_contents" in fabimg.header:
                # Whitespace separated "keyword value" list
                # NOTE(review): presumably a Pilatus/CBF mini-header -- confirm
                headers = fabimg.header["_array_data.header_contents"].split()
                if "Detector_distance" in headers:
                    dist = float(headers[headers.index("Detector_distance") + 1])
                if "Wavelength" in headers:
                    wavelength = float(headers[headers.index("Wavelength") + 1]) * 1e-10
                if "Beam_xy" in headers:
                    pos = headers.index("Beam_xy")
                    # the two tokens look like "(x," and "y)": strip the decoration
                    centerX = float(headers[pos + 1][1:-1])
                    centerY = float(headers[pos + 2][:-1])
            if dist is None:
                dist = self._distance_from_filename(fn)
            if centerX is None:
                centerX = fabimg.data.shape[1] // 2
            if centerY is None:
                centerY = fabimg.data.shape[0] // 2
            self.results[fn] = {"wavelength": wavelength, "dist": dist}
            rec = pyFAI.calibration.Recalibration(dataFiles=[fn], darkFiles=self.darkFiles, flatFiles=self.flatFiles,
                                                  detector=self.detector, spacing_file=self.spacing_file, wavelength=wavelength)
            rec.outfile = os.path.splitext(fn)[0] + ".proc.edf"
            rec.interactive = self.interactive
            rec.gui = self.gui
            rec.saturation = self.saturation
            rec.mask = self.mask
            rec.filter = self.filter
            rec.cutBackground = self.cutBackground
            rec.fixed = self.fixed
            rec.max_rings = self.max_rings
            rec.weighted = self.weighted
            # Beam center in pixel -> poni in the unit of the pixel size
            if centerY:
                rec.ai.poni1 = centerY * self.detector.pixel1
            if centerX:
                rec.ai.poni2 = centerX * self.detector.pixel2
            if dist:
                rec.ai.dist = dist
            rec.preprocess()
            rec.extract_cpt()
            rec.refine()
            self.results[fn]["ai"] = rec.ai

    def regression(self):
        """
        Print the linear regression of every refined parameter versus the distance.

        The abscissa is the nominal distance of each frame converted to
        millimeter; the fitted quantities are the refined dist, poni1, poni2,
        rot1, rot2, rot3 and the values returned by getFit2D() (directDist,
        tilt, tiltPlanRotation, centerX, centerY). Nothing is returned.
        """
        print(self.results)
        geometry = ("dist", "poni1", "poni2", "rot1", "rot2", "rot3")
        # printed name -> key in the dictionary returned by getFit2D()
        fit2d = (("direct", "directDist"), ("tilt", "tilt"), ("trp", "tiltPlanRotation"),
                 ("centerX", "centerX"), ("centerY", "centerY"))
        names = geometry + tuple(name for name, _key in fit2d)
        nb = len(self.results)
        x = numpy.zeros(nb)
        values = dict((name, numpy.zeros(nb)) for name in names)
        print("")
        print("Results of linear regression for distance in mm")
        for idx, (key, dico) in enumerate(self.results.items()):
            ai = dico["ai"]
            print("%s %s" % (key, dico["dist"]))
            print(ai)
            x[idx] = dico["dist"] * 1000
            for name in geometry:
                values[name][idx] = getattr(ai, name)
            f = ai.getFit2D()
            for name, fit2d_key in fit2d:
                values[name][idx] = f[fit2d_key]
        if nb < 2:
            # a line cannot be fitted through less than two points
            logger.warning("At least two frames are needed for a linear regression, got %s", nb)
            return
        for name in names:
            slope, intercept, r, p_value, stderr = linregress(x, values[name])
            print("%s = %s * dist_mm + %s \t R= %s\t stderr= %s" % (name, slope, intercept, r, stderr))


# Module-level handle on the calibration object, kept for debugging with rconsole
c = None
if __name__ == "__main__":
    # Command-line workflow: parse the options, prompt for what is still
    # missing, recalibrate every frame, then print the regressions.
    c = MultiCalib()
    c.parse()
    c.read_pixelsSize()
    c.read_dSpacingFile()
    c.process()
    c.regression()
    try:
        wait_for_user = raw_input  # Python 2
    except NameError:
        wait_for_user = input  # Python 3: raw_input has been renamed
    wait_for_user("Press enter to quit")
