import joule.utils.time
import numpy as np
import bisect
import json
from joule.utils.time import timestamp_to_human
from joule.client import FilterModule

import logging
import scipy.fftpack  # fftpack fft??
import asyncio
import sys

# Help text for the module's arguments; assigned to Prep.arg_description.
# Keep the documented column layout in sync with Prep._prep_period.
ARGS_DESC = """
Operation
------------

    iv --------------,
    zero_crossings --+-->[prep]--> prep (A/B/C)

Pipes
-----------
Inputs:
    'iv'             float32 <V, I_A, I_B, ...>
    'zero_crossings' float32 <freq, amp, phase>

Outputs (eg for 3 phase, nharm = 4):
  merge = True:
    'prep'           float32 <P1A, Q1A, ... P7A, Q7A, P1B, Q1B, ... P7C, Q7C>
  merge = False:
    'prep-a'         float32 <P1, Q1, ... P7, Q7>
    'prep-b'         float32 <P1, Q1, ... P7, Q7>
    'prep-c'         float32 <P1, Q1, ... P7, Q7>

  With polar = True each P/Q pair is replaced by a (mag, phase) pair.

Arguments
-------------
config:
  a JSON encoded dict with the following keys
  (the # annotations below are explanatory only, they are not valid JSON)

 {
    "current_indices": [N, ...],  # current columns (1 indexed, 0 is the timestamp)
    "rotations": [N.0, ...],      # voltage angle (radians) for each phase
    "nshift": 1,                  # number of shifted windows computed per period
    "nharm": 4,                   # number of odd harmonics to compute
    "scale_factor": 1.0,          # scale result to watts
    "merge": true,                # set to use a single prep output stream
    "polar": false                # compute polar data (mag/phase) instead of PQ
 }
"""


class Prep(FilterModule):
    """Joule filter that extracts harmonic envelopes from I/V data.

    For each zero crossing read from the ``zero_crossings`` input, one period
    of every configured current column of the ``iv`` input is FFT'd.  The odd
    harmonics 1, 3, ..., 2*nharm-1 are rotated by the per-phase ``rotations``
    and written as (P, Q) pairs, or as (magnitude, angle) pairs when
    ``polar`` is set.  Output goes either to one merged ``prep`` stream or to
    one ``prep-<letter>`` stream per phase.
    """

    def __init__(self):
        super(Prep, self).__init__("Prep")
        self.description = "run prep alogrithm"
        self.help = '''\
          Extract harmonic envelopes from I,V data
        '''
        self.arg_description = ARGS_DESC

    def custom_args(self, parser):
        """Register the single positional JSON ``config`` argument."""
        parser.add_argument("config", help="JSON encoded dict")

    async def run(self, parsed_args, inputs, outputs, nrows=0):
        """Main filter loop.

        Args:
            parsed_args: parsed arguments; ``parsed_args.config`` is the JSON
                config string (see ``_parse_configs`` for the keys).
            inputs: input pipes; must contain ``iv`` and ``zero_crossings``.
            outputs: output pipes; ``prep`` when ``merge`` is set, otherwise
                ``prep-a``, ``prep-b``, ... (one per phase).
            nrows: stop once at least this many ``iv`` rows have been
                consumed; 0 (the default) loops forever.

        Exits the process with status 1 on a bad config or missing pipes.
        """
        try:
            self._parse_configs(json.loads(parsed_args.config))
        except KeyError as e:
            print("ERROR: invalid config string missing [%s]" % e)
            sys.exit(1)
        except ValueError as e:
            # json.loads raises json.JSONDecodeError, a ValueError subclass
            print("ERROR: config is not valid JSON: %s" % e)
            sys.exit(1)

        output = None
        ph_outputs = None
        try:
            iv = inputs['iv']
            zc = inputs['zero_crossings']
            if self.merge:
                output = outputs['prep']
            else:
                # expect prep-a, prep-b, prep-c output streams
                ph_outputs = [outputs['prep-%c' % ('abc'[ph])]
                              for ph in range(self.nphases)]
        except KeyError as e:
            print("ERROR: required inputs [iv, zero_crossings], outputs [prep] (or [prep-a,...])")
            print("missing %s" % e)
            sys.exit(1)

        nshift = self.nshift
        # Timestamp of the last row written.  It is kept across reads (not
        # reset per chunk) so a period is never written twice or out of order.
        last_inserted = joule.utils.time.min_timestamp
        rows_processed = 0
        while True:
            iv_data = await iv.read(flatten=True)
            zc_data = await zc.read(flatten=True)
            # NOTE(review): presumably a throttle so each pass sees more data
            await asyncio.sleep(0.2)

            rows = iv_data.shape[0]
            if rows < 2:
                continue  # go back and get more data before we run
            data_timestamps = iv_data[:, 0]

            processed = 0  # iv rows up to the end of the last computed period
            zc_done = 0    # leading zero_crossings rows that can be consumed
            for zc_idx, sinefit_line in enumerate(zc_data):
                # t_min = beginning of period, t_max = end of period.
                # 1e6 / f0 turns the frequency into a period in timestamp
                # units (presumably microseconds).
                t_min = sinefit_line[0]
                f0 = sinefit_line[1]
                t_max = t_min + 1e6 / f0

                # Compute prep over shifted windows of the period
                # (nshift is typically 1)
                for n in range(nshift):
                    time_shift = n * (t_max - t_min) / nshift
                    angle_shift = n * 2 * np.pi / nshift
                    result = self._prep_period(iv_data, data_timestamps,
                                               t_min + time_shift,
                                               t_max + time_shift,
                                               angle_shift)
                    if result is None:
                        break
                    row, idx_max = result
                    if row[0, 0] > last_inserted:
                        last_inserted = row[0, 0]
                        await self._write_row(row, output, ph_outputs)
                    processed = idx_max
                    # Consume each zero crossing exactly once, regardless of
                    # nshift; earlier rows that could not be processed (no
                    # samples in their period) are dropped along with it.
                    zc_done = zc_idx + 1

            # If we processed no data but there's lots in here, pretend we
            # processed half of it.
            if processed == 0 and rows > 10000:
                processed = rows // 2  # consume() needs an integer row count
                logging.warning("%s: warning: no periods found; skipping %d rows\n" %
                                (timestamp_to_human(iv_data[0][0]), processed))
            else:
                logging.info("%s: processed %d of %d rows" %
                             (timestamp_to_human(iv_data[0][0]), processed, rows))
            iv.consume(processed)
            zc.consume(zc_done)
            rows_processed += processed
            if nrows > 0 and rows_processed >= nrows:
                break

    def _prep_period(self, iv_data, data_timestamps, t_min, t_max, angle_shift):
        """Compute one output row of prep coefficients for a single period.

        Args:
            iv_data: flattened iv chunk; column 0 is the timestamp.
            data_timestamps: ``iv_data[:, 0]`` (must be sorted ascending).
            t_min: timestamp of the start of the period.
            t_max: timestamp of the end of the period.
            angle_shift: rotation (radians) subtracted from each phase's
                configured rotation; non-zero for shifted windows.

        Returns:
            ``(row, idx_max)``.  ``row`` is a 1 x (1 + nphases*2*nharm) array
            holding the timestamp followed by each phase's 2*nharm values,
            and ``idx_max`` is the first iv index at or after ``t_max``.
            Returns ``None`` if no samples fall in the period or the chunk
            ends before ``t_max``.
        """
        # Find the indices of data that correspond to (t_min, t_max)
        idx_min = bisect.bisect_left(data_timestamps, t_min)
        idx_max = bisect.bisect_left(data_timestamps, t_max)
        if idx_min >= idx_max or idx_max >= len(data_timestamps):
            return None

        nharm = self.nharm
        # Fractional offset of the first sample from the period start,
        # removed below as a phase correction.
        time_lag = (data_timestamps[idx_min] - t_min) / (t_max - t_min)
        lag_correction = np.e**(2 * 1j * np.pi * time_lag)

        n_samples = idx_max - idx_min
        # Fresh row per period: written rows never alias a reused buffer.
        row = np.zeros((1, self.nphases * (nharm * 2) + 1))
        row[0, 0] = round(t_min)
        for ph in range(self.nphases):
            col = self.current_indices[ph]
            rot = self.rotations[ph] - angle_shift
            d = iv_data[idx_min:idx_max, col]
            F = scipy.fftpack.fft(d) * 2.0 / n_samples

            # If we wanted more harmonics than the FFT gave us, pad with zeros
            if n_samples < (nharm * 2):
                F = np.r_[F, np.zeros(nharm * 2 - n_samples)]

            base = ph * nharm * 2  # this phase's columns are contiguous
            for k in range(nharm):
                harmonic = 2 * k + 1  # odd harmonics only
                Fk = F[harmonic] * \
                    np.e**(rot * 1j * harmonic) / lag_correction
                if self.polar:
                    out_a = np.abs(Fk) * self.scale_factor  # VA
                    out_b = (np.angle(Fk)+np.pi/2)*2/np.pi  # pf
                else:
                    out_a = -np.imag(Fk) * self.scale_factor  # Pk
                    out_b = -np.real(Fk) * self.scale_factor  # Qk
                row[0, base + 2 * k + 1] = out_a
                row[0, base + 2 * k + 2] = out_b
        return row, idx_max

    async def _write_row(self, row, output, ph_outputs):
        """Write ``row`` to the merged output, or split it per phase."""
        if self.merge:
            await output.write(row)
            return
        width = self.nharm * 2
        for ph, ph_output in enumerate(ph_outputs):
            start_i = 1 + ph * width
            # timestamp column followed by this phase's 2*nharm values
            cols = [0] + list(range(start_i, start_i + width))
            await ph_output.write(row[:, cols])

    def _parse_configs(self, config):
        """Copy the settings from the decoded JSON ``config`` dict onto self.

        Raises KeyError if any required key is missing.
        """
        # 1 indexed b/c ts is column 0 of the flattened iv data
        self.current_indices = config['current_indices']
        self.nphases = len(self.current_indices)
        self.rotations = config['rotations']
        self.nshift = config['nshift']
        self.nharm = config['nharm']
        self.scale_factor = config['scale_factor']
        self.merge = config['merge']
        self.polar = config['polar']
        

def main():
    """Entry point: build the Prep filter and run it via its start() method."""
    prep_filter = Prep()  # avoid shadowing the builtin ``filter``
    prep_filter.start()


if __name__ == "__main__":
    main()
