import numpy as np
import sys
import asyncio
import json
from joule.utils.time import (timestamp_to_human,
                              timestamp_to_seconds,
                              seconds_to_timestamp)
from . import helpers
from joule.client import FilterModule

ARGS_DESC = """
Pipes
------------
Inputs:
    'iv'             float32 <V, I_A, I_B, ... >

Outputs:
    'zero_crossings' float32 <freq, amp, offset>

Arguments
-------------
config:
  this should be a JSON encoded dict with the
  following properties (min_freq < frequency < max_freq)

  {
    "v_index":    N   #voltage data column (1 indexed)
    "frequency":  N.0 #frequency estimate
    "min_freq":   N.0 #minimum frequency bound
    "max_freq":   N.0 #maximum frequency bound
    "min_amp":    N.0 #minimum fitted sine amplitude
  }
"""


class Sinefit(FilterModule):
    """Joule filter that locates zero crossings of a voltage signal.

    Overlapping windows of roughly 3.5 expected periods of the configured
    voltage column are fitted with a 4-parameter sine model
    ``A * sin(2*pi*f0/fs*n + phi) + C`` (``helpers.sfit4``).  At the start
    of every fitted period (where the sine argument is 0) a row
    ``[timestamp, f0, A, C]`` is written to the ``zero_crossings`` output.
    """

    def __init__(self):
        super(Sinefit, self).__init__("Sinefit")
        self.description = "run sinefit algorithm"
        self.help = '''\
          Compute zero crossings from voltage data
        '''
        self.arg_description = ARGS_DESC

    def custom_args(self, parser):
        """Register the positional ``config`` argument (a JSON encoded dict)."""
        parser.add_argument("config", help="JSON encoded dict")

    async def run(self, parsed_args, inputs, outputs, nrows=0):
        """Read ``iv`` data and write fitted zero crossings until stopped.

        Args:
            parsed_args: parsed command line; ``parsed_args.config`` holds the
                JSON encoded configuration described in ``ARGS_DESC``.
            inputs: mapping that must contain the ``'iv'`` input pipe.
            outputs: mapping that must contain the ``'zero_crossings'`` pipe.
            nrows: if positive, return once at least this many input rows
                have been consumed; 0 means run forever.

        Calls ``sys.exit(1)`` if the config is not valid JSON, is missing a
        key, or fails validation, or if a required pipe is missing.

        Raises:
            ValueError: a chunk's first timestamp is not before its last.
        """
        try:
            self._parse_configs(json.loads(parsed_args.config))
        except KeyError as e:
            print("ERROR: invalid config string missing [%s]" % e)
            sys.exit(1)
        except (ValueError, TypeError) as e:
            # ValueError also covers json.JSONDecodeError (malformed JSON);
            # TypeError covers a config that is not a JSON object.
            print("ERROR: invalid config string: %s" % e)
            sys.exit(1)

        try:
            iv = inputs['iv']
            zc = outputs['zero_crossings']
        except KeyError as e:
            print("ERROR: required input [iv], output [zero_crossings]")
            print("missing %s" % e)
            sys.exit(1)

        rows_processed = 0
        while True:
            data = await iv.read(flatten=True)
            await asyncio.sleep(0.1)
            sys.stdout.flush()
            rows = data.shape[0]

            # At least two samples are needed to estimate the sampling rate;
            # leave a short chunk unconsumed and wait for more data.
            if rows < 2:
                continue

            # Estimate sampling frequency from timestamps (column 0)
            ts_min = timestamp_to_seconds(data[0][0])
            ts_max = timestamp_to_seconds(data[-1][0])
            if ts_min >= ts_max:
                raise ValueError("ts_min>=ts_max")
            fs = (rows - 1) / (ts_max - ts_min)

            # Pull out about 3.5 periods of data at once;
            # we'll expect to match 3 zero crossings in each window
            window_len = max(int(3.5 * fs / self.frequency), 10)

            # Not enough data for one window yet: wait for more
            if rows < window_len:
                continue

            consumed = await self._process_chunk(data, fs, window_len, zc)
            iv.consume(consumed)
            rows_processed += consumed
            if nrows > 0 and rows_processed >= nrows:
                break

    async def _process_chunk(self, data, fs, window_len, zc):
        """Fit overlapping windows of one chunk and write its zero crossings.

        Args:
            data: flattened chunk; column 0 holds timestamps and column
                ``self.v_index`` holds the voltage samples.
            fs: sampling rate of ``data`` estimated from its timestamps.
            window_len: number of samples in each fit window.
            zc: output pipe receiving ``[timestamp, f0, A, C]`` rows.

        Returns:
            The number of leading rows of ``data`` that were processed and
            may be consumed from the input pipe (never more than its length).
        """
        column = self.v_index
        f_min, f_max = self.min_freq, self.max_freq
        a_min = self.min_amp
        warn = helpers.SuppressibleWarning(3, 1000)

        rows = data.shape[0]
        start = 0
        last_inserted_timestamp = None
        while start < (rows - window_len):
            this = data[start:start + window_len, column]
            t_min = timestamp_to_seconds(data[start, 0])

            # Do 4-parameter sine wave fit: A * sin(2*pi*f0/fs*n + phi) + C
            (A, f0, phi, C) = helpers.sfit4(this, fs)

            # Check bounds.  If frequency is too crazy, ignore this window.
            # Written as a negated range test so a NaN fit is rejected too.
            if not (f_min <= f0 <= f_max):
                warn.warn("frequency %s outside valid range %s - %s\n" %
                          (str(f0), str(f_min), str(f_max)), t_min)
                start += window_len
                continue

            # If amplitude is too low, results are probably just noise
            if not (A >= a_min):
                warn.warn("amplitude %s below minimum threshold %s\n" %
                          (str(A), str(a_min)), t_min)
                start += window_len
                continue

            # A period starts when the argument of sine is 0, i.e. at sample
            #     n = (0 - phi) / (f0/fs * 2 * pi)
            zc_n = (0 - phi) / (f0 / fs * 2 * np.pi)
            period_n = fs / f0

            # Add whole periods until the crossing index is non-negative
            while zc_n < 0:
                zc_n += period_n

            last_zc = None
            # Mark the zero crossings until we're a half period away
            # from the end of the window
            while zc_n < (window_len - period_n / 2):
                t = t_min + zc_n / fs
                # Overlapping windows can refit a crossing that was already
                # written; only emit strictly increasing timestamps.
                if (last_inserted_timestamp is None or
                        t > last_inserted_timestamp):
                    await zc.write(
                        np.array([[seconds_to_timestamp(t), f0, A, C]]))
                    last_inserted_timestamp = t
                    warn.reset(t)
                else:
                    warn.warn("timestamp overlap\n", t)
                last_zc = zc_n
                zc_n += period_n

            # Advance the window one quarter period past the last marked
            # zero crossing, or by half its size if none were marked.
            if last_zc is not None:
                advance = min(last_zc + period_n / 4, window_len)
            else:
                advance = window_len / 2
            # Always move forward at least one sample so a degenerate fit
            # (tiny period with a crossing at sample 0) cannot stall here.
            start = max(start + 1, int(round(start + advance)))

        warn.reset(last_inserted_timestamp)
        return start

    def _parse_configs(self, config):
        """Load and validate the decoded JSON ``config`` dict.

        Raises:
            KeyError: a required key is missing.
            ValueError: ``frequency`` is not positive, is not strictly between
                ``min_freq`` and ``max_freq``, or ``v_index`` is not a data
                column.
        """
        self.v_index = config['v_index']
        self.min_amp = config['min_amp']
        self.min_freq = config['min_freq']
        self.max_freq = config['max_freq']
        self.frequency = config['frequency']
        # run() divides by the frequency estimate to size its windows
        if self.frequency <= 0:
            raise ValueError("frequency must be positive")
        if self.min_freq >= self.frequency:
            raise ValueError("min_freq >= frequency")
        if self.max_freq <= self.frequency:
            raise ValueError("max_freq <= frequency")
        # Column 0 of the flattened data holds timestamps, so voltage
        # data columns are numbered from 1.
        if self.v_index < 1:
            raise ValueError("invalid v_index")


def main():
    """Construct the Sinefit filter module and hand control to ``start()``."""
    module = Sinefit()
    module.start()


if __name__ == "__main__":
    main()
