import numpy as np
import sys
import asyncio
import json
from joule.models.pipes import EmptyPipe
from joule.utilities import (timestamp_to_human,
                              timestamp_to_seconds,
                              seconds_to_timestamp)
from . import helpers
from joule import FilterModule

# Module documentation; it is also assigned as the argument parser
# description in Sinefit.custom_args.  The output element list must match
# the row layout the filter writes: [timestamp, frequency, amplitude, offset].
ARGS_DESC = """
---
:name:
  NILM Sinefit
:author:
  John Donnal, James Paris
:license:
  Closed
:url:
  http://git.wattsworth.net/wattsworth/nilm.git
:description:
  characterize a voltage waveform
:usage:
  Compute frequency, amplitude and offset at each zero crossing of a signal.
 
  | Arguments     | Default | Description
  |---------------|---------|---------------------------------
  |``v-index``    | --  | voltage element (1 indexed)
  |``frequency``  | 60  | utility line frequency
  |``min-freq``   | 55  | reject frequencies below threshold
  |``max-freq``   | 65  | reject frequencies above threshold
  |``min-amp``    | 10  | reject signals below threshold
  |``sfit4-iters``| 7   | sine fit iterations, accuracy increases with higher values
  |``cache-size`` | 60  | output pipe cache, default holds 1 second @ 60Hz

:inputs:
 
  iv
  : voltage source (output from nilm-reconstructor)

:outputs:

  zero_crossings
  : ``float32`` with frequency, amplitude and offset elements

:module_config:
    [Main]
    name = NILM Sinefit
    exec_cmd = nilm-filter-sinefit

    [Arguments]
    v-index = 1
    frequency = 60
    min-freq = 55
    max-freq = 65
    min-amp = 10

    [Inputs]
    iv = /path/to/iv

    [Outputs]
    zero_crossings = /path/to/zero_crossings

:stream_configs:
  #zero_crossings#
     [Main]
     name = Zero Crossings
     path = /path/to/zero_crossings
     datatype = float32
     keep = 1w

     [Element1]
     name = Frequency
     [Element2]
     name = Amplitude
     [Element3]
     name = Offset

---
"""


class Sinefit(FilterModule):
    """Characterize a voltage waveform by fitting sine waves to it.

    Windows of roughly 3.5 expected line periods are fitted with a
    4-parameter sine (``helpers.sfit4``).  For each point in a window where
    the fitted sine's phase is zero, one output row
    ``[timestamp, frequency, amplitude, offset]`` is written.
    """

    def custom_args(self, parser):
        """Add the sinefit arguments to *parser* and set its description."""
        # ``type`` must be a callable.  argparse rejects the optparse-style
        # strings "int"/"float" with ValueError at add_argument time.
        parser.add_argument("--v-index", type=int,
                            help="voltage element (1 indexed)")
        parser.add_argument("--frequency", type=float, default=60.0,
                            help="utility line frequency")
        parser.add_argument("--min-freq", type=float, default=55.0,
                            help="reject frequencies below threshold")
        parser.add_argument("--max-freq", type=float, default=65.0,
                            help="reject frequencies above threshold")
        parser.add_argument("--min-amp", type=float, default=10.0,
                            help="reject signals below threshold")
        parser.add_argument("--sfit4-iters", type=int, default=7,
                            help="accuracy increases with higher values")
        parser.add_argument("--cache-size", type=int, default=60,
                            help="pipe cache, default holds 1 second @ 60Hz")
        parser.description = ARGS_DESC

    async def run(self, parsed_args, inputs, outputs, nrows=0):
        """Read the ``iv`` input and write zero crossings to the output.

        Args:
            parsed_args: namespace holding the arguments from custom_args.
            inputs: mapping that must contain the ``iv`` input pipe.
            outputs: mapping that must contain the ``zero_crossings`` pipe.
            nrows: if > 0, stop once at least this many input rows have
                been consumed; 0 runs until the input pipe is empty.

        Raises:
            helpers.FilterError: on invalid arguments or a missing stream.
            ValueError: if a chunk's timestamps do not increase.
        """
        # check configs
        if parsed_args.min_freq >= parsed_args.frequency:
            raise helpers.FilterError("Error min_freq >= frequency")
        if parsed_args.max_freq <= parsed_args.frequency:
            raise helpers.FilterError("Error max_freq <= frequency")
        # --v-index has no default.  Reject a missing value here instead of
        # failing with TypeError on ``None <= 0``.
        if parsed_args.v_index is None or parsed_args.v_index <= 0:
            raise helpers.FilterError("Error invalid v_index")

        try:
            iv = inputs['iv']
            zc = outputs['zero_crossings']
        except KeyError as e:
            raise helpers.FilterError("missing stream %s" % e) from e

        # Cache the sinefit output; the default size holds about
        # 1 second of zero crossings at 60 Hz.
        zc.enable_cache(parsed_args.cache_size)
        rows_processed = 0
        f_expected = parsed_args.frequency
        a_min = parsed_args.min_amp
        f_min = parsed_args.min_freq
        f_max = parsed_args.max_freq
        sfit4_iters = parsed_args.sfit4_iters
        # 1 indexed but ts is col 0 so this is correct
        column = parsed_args.v_index

        while True:
            try:
                # If the previous read ended an interval, discard any queued
                # data and close the output interval.
                if iv.end_of_interval:
                    iv.consume_all()
                    await zc.close_interval()

                data = await iv.read(flatten=True)
            except EmptyPipe:
                break
            # NOTE(review): pauses before every chunk, presumably so a
            # too-short chunk is not re-read in a tight loop -- confirm.
            await asyncio.sleep(0.1)
            rows = data.shape[0]

            # The sampling-rate estimate needs two samples: with 0 rows
            # data[0] raises IndexError, and with 1 row ts_min == ts_max.
            # Treat such reads like any other too-short chunk.
            if rows < 2:
                if iv.closed:
                    break
                continue

            # Estimate sampling frequency from timestamps
            ts_min = timestamp_to_seconds(data[0][0])
            ts_max = timestamp_to_seconds(data[-1][0])
            if ts_min >= ts_max:
                raise ValueError("ts_min>=ts_max")

            fs = (rows - 1) / (ts_max - ts_min)

            # Pull out about 3.5 periods of data at once;
            # we'll expect to match 3 zero crossings in each window
            N = max(int(3.5 * fs / f_expected), 10)

            # The window loop below only runs when rows > N, so exactly N
            # rows is also too short.  Otherwise a closed pipe holding N rows
            # could be re-read without anything ever being consumed.
            if rows <= N:
                # stop if the pipeline is closed
                if iv.closed:
                    break
                continue
            warn = helpers.SuppressibleWarning(3, 1000)

            # Process overlapping windows
            start = 0
            num_zc = 0
            last_inserted_timestamp = None

            # flag to track when the interval has been closed due to bad data
            interval_closed_on_error = False

            while start < (rows - N):
                this = data[start:start + N, column]
                t_min = timestamp_to_seconds(data[start, 0])

                # Do 4-parameter sine wave fit
                (A, f0, phi, C) = helpers.sfit4(this, fs, num_iters=sfit4_iters)
                # Check bounds.  If frequency is too crazy, ignore this window
                if f0 < f_min or f0 > f_max:
                    warn.warn("frequency %s outside valid range %s - %s\n" %
                              (str(f0), str(f_min), str(f_max)), t_min)
                    start += N
                    if not interval_closed_on_error:
                        interval_closed_on_error = True
                        await zc.close_interval()
                    continue

                # If amplitude is too low, results are probably just noise
                if A < a_min:
                    warn.warn("amplitude %0.2f below minimum threshold %0.2f\n" %
                              (A, a_min), t_min)
                    start += N
                    if not interval_closed_on_error:
                        interval_closed_on_error = True
                        await zc.close_interval()
                    continue
                # reset the flag
                interval_closed_on_error = False

                # Period starts when the argument of sine is 0 degrees,
                # so we're looking for sample number:
                #     n = (0 - phi) / (f0/fs * 2 * pi)
                zc_n = (0 - phi) / (f0 / fs * 2 * np.pi)
                period_n = fs / f0

                # Add periods to make N positive
                while zc_n < 0:
                    zc_n += period_n

                last_zc = None
                # Mark the zero crossings until we're a half period away
                # from the end of the window
                while zc_n < (N - period_n / 2):
                    t = t_min + zc_n / fs
                    # Overlapping windows can re-detect a crossing; only
                    # write strictly increasing timestamps.
                    if (last_inserted_timestamp is None or
                            t > last_inserted_timestamp):
                        await zc.write(
                            np.array([[seconds_to_timestamp(t), f0, A, C]]))
                        last_inserted_timestamp = t
                        warn.reset(t)
                    else:
                        warn.warn("timestamp overlap\n", t)
                    num_zc += 1
                    last_zc = zc_n
                    zc_n += period_n

                # Advance the window one quarter period past the last marked
                # zero crossing, or advance the window by half its size if we
                # didn't mark any.
                if last_zc is not None:
                    advance = min(last_zc + period_n / 4, N)
                else:
                    advance = N / 2

                start = int(round(start + advance))

            # Release the rows that were fully processed
            warn.reset(last_inserted_timestamp)
            iv.consume(start)

            rows_processed += start
            if nrows > 0 and rows_processed >= nrows:
                break
        await zc.close()


def main():
    """Create the Sinefit filter module and start it."""
    # Named ``module`` rather than ``filter`` to avoid shadowing the builtin.
    module = Sinefit()
    module.start()


if __name__ == "__main__":
    main()
