import numpy as np
import asyncio
import json

from joule import FilterModule
from joule.utilities import yesno
from nilm.filters.helpers import FilterError


ARGS_DESC = """
---
:name:
  NILM Reconstructor
:author:
  John Donnal
:license:
  Closed
:url:
  http://git.wattsworth.net/wattsworth/nilm.git
:description:
  convert raw data to current and voltage
:usage:
  Convert data collected by nilm-reader into current (I) and voltage (V)
 
  | Arguments            | Default | Description
  |----------------------|---------|-----------------------------------------------------------
  |``m-indices``         | --  | magnetic sensors *json array* (1 indexed)
  |``e-indices``         | --  | e-field sensors (1 indexed)
  |``max-gap``           | 10  | max seconds between consecutive sample timestamps
  |``current-matrix``    | --  | convert sensors to currents
  |``integrate``         | yes | integrate e-field *non-contact only*
  |``voltage-scale``     | 1.0 | e-field scale factor *non-contact only*
  |``voltage-matrix``    | --  | convert sensors to voltages *contact only*

:inputs:
 
  raw
  : output from nilm-reader module

:outputs:
  Noncontact Meter:

  output
  : ``float32`` with voltage reference and phase currents

  Contact Meter:

  output
  : ``float32`` with phase voltages and phase currents

:module_config:
    [Main]
    name = NILM Reconstructor
    exec_cmd = nilm-filter-reconstructor

    [Arguments]
    m-indices = [1,2,3]
    e-indices = [4,5,6]
    max-gap = 10
    current-matrix = [[1,0,0],[0,1,0],[0,0,1]]
    voltage-matrix = [[1,0,0],[0,1,0],[0,0,1]]

    [Inputs]
    raw = /path/to/raw

    [Outputs]
    iv = /path/to/iv

:stream_configs:
  #output (noncontact)#
     [Main]
     name = IV
     path = /path/to/output
     datatype = float32
     keep = 1w

     # for 3 phase
     [Element1]
     name = Vref
     [Element2]
     name = I A
     [Element3]
     name = I B
     [Element4]
     name = I C

  #output (contact)#
     [Main]
     name = IV
     path = /path/to/output
     datatype = float32
     keep = 1w

     # for 3 phase
     [Element1]
     name = Current A
     [Element2]
     name = Current B
     [Element3]
     name = Current C
     [Element4]
     name = Voltage A
     [Element5]
     name = Voltage B
     [Element6]
     name = Voltage C
---
"""


class Reconstructor(FilterModule):
    """
      raw-->[reconstructor]-->iv

    Convert raw sensor samples into voltage and current columns. DC
    offsets are estimated from the first chunk read and subtracted from
    every chunk that follows.
    """

    def custom_args(self, parser):
        """Register the module-specific command line arguments.

        The names carry a ``--`` prefix so argparse treats them as optional
        flags and stores them under underscore names (``--max-gap`` ->
        ``parsed_args.max_gap``). That is how :meth:`run` reads them.
        Without the prefix they would be positional, and their dests would
        keep the dashes.
        """
        parser.add_argument("--m-indices", type=json.loads,
                            help="magnetic sensors (1 indexed)")
        parser.add_argument("--e-indices", type=json.loads,
                            help="e-field sensors (1 indexed)")
        parser.add_argument("--max-gap", type=int, default=10,
                            help="max seconds between consecutive "
                                 "sample timestamps")
        parser.add_argument("--current-matrix", type=json.loads,
                            help="convert sensors to currents")
        parser.add_argument("--integrate", type=yesno, default=True,
                            help="integrate e-field (non-contact only)")
        parser.add_argument("--voltage-scale", type=float, default=1.0,
                            help="e-field scale factor (non-contact only)")
        # json.loads (not the json module itself): argparse requires a
        # callable converter and rejects a non-callable `type`.
        parser.add_argument("--voltage-matrix", type=json.loads, default=None,
                            help="convert sensors to voltages (contact only)")

        parser.description = ARGS_DESC

    async def run(self, parsed_args, inputs, outputs, nrows=0):
        """Stream data from the ``raw`` input to the ``iv`` output.

        Each pass through the loop does the following:

        1. Read the pending input and remove the DC offsets.
        2. Map the selected sensor columns to currents and voltages,
           optionally integrating the voltages.
        3. Write the result.
        4. Consume only the rows that were written. When integrating,
           the unwritten tail rows are read again on the next pass.

        Args:
            parsed_args: namespace built from :meth:`custom_args`.
            inputs: mapping of input pipes; must contain ``raw``.
            outputs: mapping of output pipes; must contain ``iv``.
            nrows: stop once at least this many rows have been written
                (0 = run until cancelled).

        Raises:
            FilterError: raised if a stream or required argument is
                missing, or if a timestamp gap exceeds ``max_gap``.
        """
        try:
            pipe_in = inputs['raw']
            pipe_out = outputs['iv']
        except KeyError as e:
            raise FilterError("missing stream [%s]" % e) from e
        for arg in ("m_indices", "e_indices", "current_matrix"):
            if getattr(parsed_args, arg, None) is None:
                raise FilterError("missing argument [%s]" % arg)

        integrate = parsed_args.integrate
        voltage_scale = getattr(parsed_args, "voltage_scale", 1.0)
        fir_integrator = self._setup_integration()
        # A 'valid' convolution drops `size` rows. The kernel has an odd
        # number of taps, so half come off each end of the chunk.
        size = len(fir_integrator) - 1
        bound = size // 2
        current_matrix = np.array(parsed_args.current_matrix)
        total_rows_processed = 0
        dc_offsets = None

        while True:
            await asyncio.sleep(0.1)
            sarray_in = await pipe_in.read()
            if len(sarray_in) == 0:
                continue
            if dc_offsets is None:
                # offsets are estimated once, from the first chunk only
                dc_offsets = np.mean(sarray_in['data'], axis=0)

            if integrate:
                output_len = len(sarray_in) - size
                if output_len <= 0:
                    # we need more data to integrate
                    continue
            else:
                output_len = len(sarray_in)
            sarray_out = np.zeros(output_len, dtype=pipe_out.dtype)

            timestamps = sarray_in['timestamp']
            self._verify_no_gaps(timestamps, parsed_args.max_gap)
            data_in = sarray_in['data'].astype('float') - dc_offsets
            m_data = data_in[:, parsed_args.m_indices]
            # NOTE(review): indices are documented as 1-indexed but are used
            # directly as 0-based numpy column indices -- confirm against
            # the raw stream layout.
            e_data = data_in[:, parsed_args.e_indices]
            currents = m_data.dot(current_matrix.T)
            voltages = self._process_voltages(e_data,
                                              fir_integrator,
                                              integrate,
                                              parsed_args.voltage_matrix)
            voltages = voltages * voltage_scale

            if integrate:
                # Trim currents and timestamps to line up with the 'valid'
                # part of the integrated voltages. Use an explicit end index
                # so a zero bound does not yield an empty slice.
                currents = currents[bound:len(currents) - bound]
                timestamps = timestamps[bound:len(timestamps) - bound]
            # NOTE(review): columns are written voltages first, then currents.
            # The documented contact-meter stream lists currents first --
            # confirm which order is intended.
            sarray_out['data'] = np.c_[voltages, currents]
            sarray_out['timestamp'] = timestamps

            await pipe_out.write(sarray_out)
            rows_processed = len(sarray_out)
            total_rows_processed += rows_processed
            # When integrating, this leaves the last `size` rows
            # unconsumed, so the next read includes them again.
            pipe_in.consume(rows_processed)
            if nrows > 0 and total_rows_processed >= nrows:
                break

    def _verify_no_gaps(self, ts, max_gap):
        """Raise FilterError if consecutive timestamps are too far apart.

        Args:
            ts: sample timestamps. They are divided by 1e6 to get seconds,
                so presumably they are in microseconds.
            max_gap: largest allowed spacing between consecutive samples,
                in seconds.
        """
        if len(ts) < 2:
            # nothing to compare; np.max of an empty diff would raise
            return
        local_max_gap = np.max(np.diff(ts)) / 1e6  # convert to seconds
        if local_max_gap > max_gap:
            msg = "%d gap in input, cannot process" % local_max_gap
            raise FilterError(msg)

    def _process_voltages(self, e_data, fir_filter,
                          integrate, matrix):
        """Map e-field sensor data to voltages, optionally integrating.

        Args:
            e_data: e-field samples, one row per sample. A 1-D array is
                treated as a single sensor column. The argument is not
                modified.
            fir_filter: kernel returned by :meth:`_setup_integration`.
            integrate: if true, convolve each voltage column with
                ``fir_filter`` in 'valid' mode, which drops
                ``len(fir_filter) - 1`` rows.
            matrix: sensor-to-voltage matrix, applied as
                ``e_data @ matrix``. ``None`` (no matrix configured) passes
                the e-field columns through unchanged.

        Returns:
            numpy array of voltages.
        """
        e_data = np.asarray(e_data)
        if e_data.ndim < 2:
            # reshape a copy rather than assigning .shape on the caller's array
            e_data = e_data.reshape(-1, 1)
        if matrix is None:
            voltage = e_data
        else:
            # NOTE(review): currents use `m_data @ current_matrix.T`, but this
            # applies `matrix` without a transpose. The two agree only for
            # symmetric matrices -- confirm the intended convention.
            voltage = np.dot(e_data, matrix)
        if integrate:
            if np.ndim(voltage) == 1:
                voltage = np.convolve(fir_filter, voltage, 'valid')
            else:
                # np.convolve only accepts 1-D input, so integrate each
                # column separately
                voltage = np.column_stack(
                    [np.convolve(fir_filter, col, 'valid')
                     for col in voltage.T])
        return voltage

    #  assume 3kHz sampling on 60Hz line (3000/60)
    def _setup_integration(self, cycle_length=50):
        """Build the FIR kernel used to integrate the e-field data.

        Returns a 1-D array with ``2 * (cycle_length // 2) - 1`` taps, one
        per offset t in [-(N-1), N-1]. Tap t equals
        ``sum(sin(pi*n*t/N) / (n*N) for n in 1..N-1)`` with
        ``N = cycle_length // 2``.
        """
        N = cycle_length // 2  # assumes cycle_length is even
        t = np.arange(1 - N, N)  # -(N-1) to +(N-1) inclusive
        n = np.arange(1, N)    # 1 to N-1 inclusive
        return np.sum(np.sin(np.pi * n * t[:, None] / N) / n / N, 1)


def main():
    """Create the Reconstructor filter module and start it."""
    module = Reconstructor()  # avoid shadowing the `filter` builtin
    module.start()


if __name__ == "__main__":
    main()
