#!/usr/bin/python

# Spectral envelope preprocessor.
# Requires two streams as input: the original raw data, and sinefit data.

from nilmdb.utils.printf import *
from nilmdb.utils.time import timestamp_to_human
import nilmtools.filter
import nilmdb.client
from numpy import *
import scipy.fftpack
import scipy.signal
#from matplotlib import pyplot as p
import bisect
from nilmdb.utils.interval import Interval

def main(argv = None):
    # Set up argument parser
    f = nilmtools.filter.Filter()
    parser = f.setup_parser("Spectral Envelope Preprocessor", skip_paths = True)
    group = parser.add_argument_group("Prep options")
    group.add_argument("-c", "--column", action="store", type=int,
                       help="Column number (first data column is 1)")
    group.add_argument("-q", "--nharm", action="store", type=int, default=4,
                       help="number of odd harmonics to compute (default 4)")
    group.add_argument("-N", "--nshift", action="store", type=int, default=1,
                       help="number of shifted FFTs per period (default 1)")
    exc = group.add_mutually_exclusive_group()
    exc.add_argument("-r", "--rotate", action="store", type=float,
                     help="rotate FFT output by this many degrees (default 0)")
    exc.add_argument("-R", "--rotate-rad", action="store", type=float,
                     help="rotate FFT output by this many radians (default 0)")

    group.add_argument("srcpath", action="store",
                       help="Path of raw input, e.g. /foo/raw")
    group.add_argument("sinepath", action="store",
                       help="Path of sinefit input, e.g. /foo/sinefit")
    group.add_argument("destpath", action="store",
                       help="Path of prep output, e.g. /foo/prep")

    # Parse arguments
    try:
        args = f.parse_args(argv)
    except nilmtools.filter.MissingDestination as e:
        rec = "float32_%d" % (e.parsed_args.nharm * 2)
        print "Source is %s (%s)" % (e.src.path, e.src.layout)
        print "Destination %s doesn't exist" % (e.dest.path)
        print "You could make it with a command like:"
        print "  nilmtool -u %s create %s %s" % (e.dest.url, e.dest.path, rec)
        raise SystemExit(1)

    if f.dest.layout_count != args.nharm * 2:
        print "error: need", args.nharm*2, "columns in destination stream"
        raise SystemExit(1)

    # Check arguments
    if args.column is None or args.column < 1:
        parser.error("need a column number >= 1")

    if args.nharm < 1 or args.nharm > 32:
        parser.error("number of odd harmonics must be 1-32")

    if args.nshift < 1:
        parser.error("number of shifted FFTs must be >= 1")

    if args.rotate is not None:
        rotation = args.rotate * 2.0 * pi / 360.0
    else:
        rotation = args.rotate_rad or 0.0

    # Check the sine fit stream
    client_sinefit = nilmdb.client.Client(args.url)
    sinefit = nilmtools.filter.get_stream_info(client_sinefit, args.sinepath)
    if not sinefit:
        raise Exception("sinefit data not found")
    if sinefit.layout != "float32_3":
        raise Exception("sinefit data type is " + sinefit.layout
                        + "; expected float32_3")

    # Check and set metadata in prep stream
    f.check_dest_metadata({ "prep_raw_source": f.src.path,
                            "prep_sinefit_source": sinefit.path,
                            "prep_column": args.column,
                            "prep_rotation": repr(rotation),
                            "prep_nshift": args.nshift })

    # Find the intersection of the usual set of intervals we'd filter,
    # and the intervals actually present in sinefit data.  This is
    # what we will process.
    filter_int = f.intervals()
    sinefit_int = ( Interval(start, end) for (start, end) in
                    client_sinefit.stream_intervals(
                        args.sinepath, start = f.start, end = f.end) )
    intervals = nilmdb.utils.interval.intersection(filter_int, sinefit_int)

    # Run the process (using the helper in the filter module)
    f.process_numpy(process, args = (client_sinefit, sinefit.path, args.column,
                                     args.nharm, rotation, args.nshift),
                    intervals = intervals)


def process(data, interval, args, insert_function, final):
    """Compute prep (spectral envelope) rows for one block of raw data.

    For every sinefit row extracted for the time range of `data`, an
    FFT is taken over the raw samples of that period (and over each of
    the `nshift` shifted windows of it), and the odd harmonics are
    handed to `insert_function` as a single output row.

    Parameters:
      data            -- 2D array of raw samples; column 0 holds the
                         timestamps and `column` selects the signal
      interval        -- not used by this function
      args            -- tuple (client, sinefit_path, column, nharm,
                         rotation, nshift)
      insert_function -- callable that receives a 2D array of output rows
      final           -- not used by this function

    Returns the number of rows counted as processed: the end index of
    the last period handled, 0 if nothing could be handled (or if
    fewer than two rows were given), or half of the rows when a block
    of more than 10000 rows contained no usable period.
    """
    (client, sinefit_path, column, nharm, rotation, nshift) = args
    rows = data.shape[0]
    data_timestamps = data[:, 0]

    if rows < 2:
        return 0

    # Kept in a one-element list so the nested function can update it.
    last_inserted = [nilmdb.utils.time.min_timestamp]

    def insert_if_nonoverlapping(block):
        """Call insert_function to insert block, but only if it
        doesn't overlap with other data that we already inserted."""
        if block[0][0] <= last_inserted[0]:
            return
        last_inserted[0] = block[-1][0]
        insert_function(block)

    # NOTE(review): this single row buffer is overwritten for every
    # period, so insert_function is presumably expected to copy the
    # data it is given -- confirm.
    out = zeros((1, nharm * 2 + 1))

    def prep_period(t_min, t_max, rot):
        """
        Compute prep coefficients from time t_min to t_max, which
        are the timestamps of the start and end of one period.
        Results are rotated by an additional angle rot (radians,
        scaled by the harmonic number) before being inserted into
        the database.  Returns the maximum index processed, or None
        if the period couldn't be processed.
        """
        # Find the indices of data that correspond to (t_min, t_max)
        idx_min = bisect.bisect_left(data_timestamps, t_min)
        idx_max = bisect.bisect_left(data_timestamps, t_max)
        if idx_min >= idx_max or idx_max >= len(data_timestamps):
            # Empty window, or the period runs past the end of this block
            return None

        # Fraction of a period between t_min and the first sample used
        time_lag = (data_timestamps[idx_min] - t_min) / (t_max - t_min)
        lag_correction = e**(2 * 1j * pi * time_lag)

        # Perform FFT over those indices
        N = idx_max - idx_min
        d = data[idx_min:idx_max, column]
        F = scipy.fftpack.fft(d) * 2.0 / N

        # If we wanted more harmonics than the FFT gave us, pad with zeros
        if N < (nharm * 2):
            F = r_[F, zeros(nharm * 2 - N)]

        # Fill output data: timestamp, then a (P, Q) pair per odd harmonic.
        # NOTE(review): rot is scaled by the harmonic number (2k+1) but
        # lag_correction is applied unscaled to every harmonic --
        # confirm this is intended.
        out[0, 0] = round(t_min)
        for k in range(nharm):
            Fk = F[2 * k + 1] * e**(rot * 1j * (2*k+1)) / lag_correction
            out[0, 2 * k + 1] = -imag(Fk) # Pk
            out[0, 2 * k + 2] = -real(Fk) # Qk

        insert_if_nonoverlapping(out)
        return idx_max

    processed = 0
    # Pull out sinefit data for the entire time range of this block
    for sinefit_line in client.stream_extract(sinefit_path,
                                              data[0, 0], data[rows-1, 0]):
        # Extract sinefit data to get zero crossing timestamps.
        # t_min = beginning of period
        # t_max = end of period
        # A and C are parsed (the line must have four fields) but unused.
        (t_min, f0, A, C) = [ float(x) for x in sinefit_line.split() ]
        # presumably timestamps are microseconds and f0 is in Hz -- confirm
        t_max = t_min + 1e6 / f0

        # Compute prep over shifted windows of the period
        # (nshift is typically 1)
        for n in range(nshift):
            # Compute timestamps and rotations for shifted window
            time_shift = n * (t_max - t_min) / nshift
            shifted_min = t_min + time_shift
            shifted_max = t_max + time_shift
            angle_shift = n * 2 * pi / nshift
            shifted_rot = rotation - angle_shift

            # Run prep computation; stop shifting once a window no
            # longer fits in this block.
            idx_max = prep_period(shifted_min, shifted_max, shifted_rot)
            if idx_max is None:
                break
            processed = idx_max

    # If we processed no data but there's lots in here, pretend we
    # processed half of it.
    if processed == 0 and rows > 10000:
        # Floor division keeps the row count an integer regardless of
        # which division semantics are in effect.
        processed = rows // 2
        printf("%s: warning: no periods found; skipping %d rows\n",
               timestamp_to_human(data[0][0]), processed)
    else:
        printf("%s: processed %d of %d rows\n",
               timestamp_to_human(data[0][0]), processed, rows)
    return processed

# Run the preprocessor only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
