import numpy as np
import bisect
import json
from joule import FilterModule, utilities

import scipy.fftpack  # fftpack fft??
import asyncio
from . import helpers
from joule.models.pipes import EmptyPipe

ARGS_DESC = """
---
:name:
  NILM Prep
:author:
  John Donnal, James Paris
:license:
  Closed
:url:
  http://git.wattsworth.net/wattsworth/nilm.git
:description:
  compute power spectral envelopes
:usage:
  Compute real and reactive envelopes of a current waveform
 
    'current_indices':  [N,] #current columns (1 indexed)
    'rotations': [N.0,] #voltage angles for each phase (radians)
    'nshift': 1         #amount to shift window
    'nharm': 4          #number of odd harmonics to compute
    'scale_factor': 1.0 #scale result to watts
    'merge': True       #set to use a single prep output stream
    'polar': False      #compute polar data (mag/phase) instead of PQ


  | Arguments         | Default | Description
  |-------------------|---------|---------------------------------
  |``current-indices``| --  | up to three current elements (1 indexed)
  |``rotations``      | --  | voltage angle for each current (radians)
  |``nshift``         | 1   | amount to shift window
  |``nharm``          | 4   | number of odd harmonics to compute
  |``scale-factor``   | 1.0 | scale result to watts (RMS voltage)
  |``merge``          | yes | [yes|no] use a single output stream for all phases
  |``polar``          | no  | [yes|no] use polar coordinates (mag, angle) or cartesian (P,Q)
  |``samp-freq``      | --  | sampling frequency (Hz)
  |``goertzel``       | no  | [yes|no] use the Goertzel algorithm instead of an FFT

:inputs:
 
  iv
  : current data (nilm-reconstructor)

  zero_crossings
  : phase, amplitude and offset data (nilm-sinefit)

:outputs:

  for merged 3 phase

  prep
  :  ``float32`` [P1A, Q1A, ..., P1B, Q1B, ... P1C, Q1C, ...]

  for unmerged 3 phase
  
  prep-a
  :   ``float32`` [P1, Q1, ... P7, Q7]

  prep-b
  :   ``float32`` [P1, Q1, ... P7, Q7]

  prep-c
  :   ``float32`` [P1, Q1, ... P7, Q7]

:module_config:
  [Main]
  name = NILM Prep
  exec_cmd = nilm-filter-prep

  [Arguments]
  current-indices = [1,2,3]
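  # rotations are in radians (0, 2*pi/3, 4*pi/3)
  rotations = [0, 2.0944, 4.1888]
  scale-factor = 120
  merge = yes
  polar = no
  # required: sampling rate of the iv stream in Hz (example value)
  samp-freq = 8000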

  [Inputs]
  iv = /path/to/iv
  zero_crossings = /path/to/zero_crossings

  [Outputs]
  prep = /path/to/prep


:stream_configs:
  #prep (merged)#
     [Main]
     name = Prep
     path = /path/to/prep
     datatype = float32
     keep = 1w

     # Phase A
     [Element1]
     name = P1A
     [Element2]
     name = Q1A
     [Element3]
     name = P3A
     [Element4]
     name = Q3A
     [Element5]
     name = P5A
     [Element6]
     name = Q5A
     [Element7]
     name = P7A
     [Element8]
     name = Q7A

     # Phase B
     [Element9]
     name = P1B
     [Element10]
     name = Q1B
     [Element11]
     name = P3B
     [Element12]
     name = Q3B
     [Element13]
     name = P5B
     [Element14]
     name = Q5B
     [Element15]
     name = P7B
     [Element16]
     name = Q7B

     # Phase C
     [Element17]
     name = P1C
     [Element18]
     name = Q1C
     [Element19]
     name = P3C
     [Element20]
     name = Q3C
     [Element21]
     name = P5C
     [Element22]
     name = Q5C
     [Element23]
     name = P7C
     [Element24]
     name = Q7C

  #prep-a (unmerged)#
     [Main]
     name = Prep A
     path = /path/to/prep-a
     datatype = float32
     keep = 1w

     [Element1]
     name = P1
     [Element2]
     name = Q1
     [Element3]
     name = P3
     [Element4]
     name = Q3
     [Element5]
     name = P5
     [Element6]
     name = Q5
     [Element7]
     name = P7
     [Element8]
     name = Q7

---
"""


class Prep(FilterModule):
    """Joule filter that computes spectral envelopes ("prep") of currents.

    Each row of the ``zero_crossings`` input marks the start of one line
    period.  The matching slice of each selected ``iv`` current column is
    transformed (FFT, or Goertzel when ``--goertzel`` is set) and the first
    ``nharm`` odd harmonics are written out as (P, Q) pairs, or as
    (magnitude, angle) pairs when ``--polar`` is set.
    """

    def custom_args(self, parser):
        """Register this filter's command line arguments on *parser*."""
        parser.add_argument("--current-indices", required=True, type=json.loads,
                            help="up to 3 current elements (1 indexed)")
        parser.add_argument("--rotations", required=True, type=json.loads,
                            help="voltage angle for each current (radians)")
        parser.add_argument("--nshift", type=int, default=1,
                            help="amount to shift each window")
        parser.add_argument("--nharm", type=int, default=4,
                            help="number of odd harmonics to compute")
        parser.add_argument("--scale-factor", type=float, default=1.0,
                            help="scale to watts (RMS voltage)")
        parser.add_argument("--merge", type=utilities.yesno, default=True,
                            help="use a single output stream for all phases")
        parser.add_argument("--polar", type=utilities.yesno, default=False,
                            help="use polar coordinates (mag, angle) or cartesian (P,Q)")
        parser.add_argument("--samp-freq", type=int, required=True,
                            help="sampling frequency in Hz")
        parser.add_argument("--goertzel", type=utilities.yesno, default=False,
                            help="use goertzel algorithm or fft")
        parser.description = ARGS_DESC

    # NOTE(review): not referenced anywhere in this class; kept in case
    # external code reads it.
    n_output = 0

    async def run(self, parsed_args, inputs, outputs):
        """Stream prep rows computed from the ``iv`` and ``zero_crossings`` pipes.

        Args:
            parsed_args: namespace built from the arguments in custom_args.
            inputs: input pipes; must contain ``iv`` and ``zero_crossings``.
            outputs: output pipes; ``prep`` when merging, otherwise
                ``prep-a``, ``prep-b``, ... (one per phase).

        Raises:
            helpers.FilterError: if a required input or output is missing.
        """
        nphases = len(parsed_args.current_indices)
        nharm = parsed_args.nharm
        nshift = parsed_args.nshift
        scale_factor = parsed_args.scale_factor
        samp_freq = parsed_args.samp_freq
        # output columns per phase: one (a, b) pair per harmonic
        width = nharm * 2
        # odd harmonic numbers 1, 3, 5, ... (also the FFT bins to read)
        harmonics = 2 * np.arange(nharm) + 1

        try:
            iv = inputs['iv']
            zc = inputs['zero_crossings']
            if parsed_args.merge:
                output = outputs['prep']
                # NOTE(review): renames the output pipe; purpose not visible here
                output.name = 'prep_watched'
                output.enable_cache(60)  # 1 out per second
            else:
                # expect prep-a, prep-b, prep-c output streams
                ph_outputs = []
                for ph in range(nphases):
                    output_pipe = outputs['prep-%c' % ('abc'[ph])]
                    output_pipe.enable_cache(60)
                    ph_outputs.append(output_pipe)
        except KeyError as e:
            raise helpers.FilterError(
                "required inputs [iv, zero_crossings], outputs [prep] (or [prep-a,...])") from e

        # Single-row output buffer.  Every column is overwritten for each
        # period, so it can be reused for the whole run.
        out = np.zeros((1, nphases * width + 1))

        # Timestamp of the last row written.  Kept across read iterations so a
        # period recomputed on a later pass (because its zero crossing was not
        # consumed) is not written a second time.
        last_inserted = None

        async def insert_if_nonoverlapping(data):
            """Write *data* unless it starts at or before the last written row."""
            nonlocal last_inserted
            if last_inserted is not None and data[0][0] <= last_inserted:
                return
            last_inserted = data[-1][0]
            if parsed_args.merge:
                await output.write(data)
            else:
                for ph in range(nphases):
                    start_i = 1 + ph * width
                    cols = [0] + list(range(start_i, start_i + width))
                    await ph_outputs[ph].write(data[:, cols])

        while True:
            try:
                iv_data = await iv.read(flatten=True)
                zc_data = await zc.read(flatten=True)
            except EmptyPipe:
                break

            # NOTE(review): throttles every pass; the reason is not visible here
            await asyncio.sleep(0.2)

            rows = iv_data.shape[0]
            if rows < 2 or len(zc_data) == 0:
                # Not enough data to compute a period yet.  Stop instead of
                # looping forever if the starved source has finished.
                if (rows < 2 and iv.closed) or (len(zc_data) == 0 and zc.closed):
                    break
                continue  # go back and get more data before we run

            data_timestamps = iv_data[:, 0]
            zc_timestamps = zc_data[:, 0]

            # Check that the zc data overlaps with the iv data.  If we lose the
            # voltage signal, iv stalls looking for zc data, so flush the old iv
            # data because there are no zero crossings for it.
            if zc_timestamps.min() > data_timestamps.max():
                print("no zero crossings, dumping iv data")
                iv.consume(len(iv_data))
                continue

            async def prep_period(t_min, t_max, angle_shift, f_expected, f_samp):
                """Compute and write one prep row for the period [t_min, t_max).

                t_min and t_max are the timestamps of the start and end of one
                period.  Each phase is rotated by its configured rotation minus
                *angle_shift*.  Returns the index of the first iv row at or
                after t_max, or None if the period is not fully contained in
                the current iv chunk.
                """
                idx_min = bisect.bisect_left(data_timestamps, t_min)
                idx_max = bisect.bisect_left(data_timestamps, t_max)
                if idx_min >= idx_max or idx_max >= len(data_timestamps):
                    return None
                # fraction of a period between t_min and the first sample used
                time_lag = (data_timestamps[idx_min] - t_min) / (t_max - t_min)
                lag_correction = np.exp(2j * np.pi * time_lag)

                N = idx_max - idx_min
                freq_vect = harmonics * f_expected
                out[0, 0] = round(t_min)
                for ph in range(nphases):
                    col = parsed_args.current_indices[ph]
                    rot = parsed_args.rotations[ph] - angle_shift
                    d = iv_data[idx_min:idx_max, col]

                    if parsed_args.goertzel:
                        F = helpers.goertzel(d, f_samp, freq_vect)
                    else:
                        F = scipy.fftpack.fft(d) * 2.0 / N

                    # If we wanted more harmonics than the FFT gave us, pad
                    # with zeros
                    if N < width:
                        F = np.r_[F, np.zeros(width - N)]

                    # Goertzel output is indexed per requested frequency; over
                    # a one-period window, FFT bin h holds harmonic h.
                    Fh = F[:nharm] if parsed_args.goertzel else F[harmonics]
                    Fk = Fh * np.exp(1j * rot * harmonics) / lag_correction

                    if parsed_args.polar:
                        out_a = np.abs(Fk) * scale_factor  # VA
                        out_b = (np.angle(Fk) + np.pi / 2) * 2 / np.pi  # pf
                    else:
                        out_a = -np.imag(Fk) * scale_factor  # Pk
                        out_b = -np.real(Fk) * scale_factor  # Qk

                    # interleave (a, b) pairs into this phase's columns
                    base = 1 + ph * width
                    out[0, base:base + width:2] = out_a
                    out[0, base + 1:base + width:2] = out_b

                await insert_if_nonoverlapping(out)
                return idx_max

            processed = 0
            for sinefit_line in zc_data:
                # t_min = beginning of period.  Column 1 is used as the line
                # frequency f0; t_max = t_min + 1e6 / f0 implies microsecond
                # timestamps with f0 in Hz.
                t_min = sinefit_line[0]
                f0 = sinefit_line[1]
                t_max = t_min + 1e6 / f0

                # Compute prep over shifted windows of the period
                # (nshift is typically 1)
                for n in range(nshift):
                    time_shift = n * (t_max - t_min) / nshift
                    angle_shift = n * 2 * np.pi / nshift
                    idx_max = await prep_period(t_min + time_shift,
                                                t_max + time_shift,
                                                angle_shift, f0, samp_freq)
                    if idx_max is None:
                        break
                    processed = idx_max

                    if n == (nshift - 1):
                        # every shifted window of this period was written
                        zc.consume(1)

            iv.consume(processed)

            # if production is stalled check to see if the sources have stopped
            if processed == 0 and (iv.closed or zc.closed):
                break

        if parsed_args.merge:
            await output.flush_cache()
        else:
            for p in ph_outputs:
                await p.flush_cache()

def main():
    """Launch the NILM prep filter as a standalone joule module."""
    Prep().start()


if __name__ == "__main__":
    main()
