import asyncio
import json
import sys

import numpy as np

from joule.client import FilterModule
from nilm.filters.helpers import FilterError


# Help text shown for this filter's arguments; keep in sync with the keys
# read by Reconstructor._parse_configs.
ARGS_DESC = """
Pipes
-----------
Inputs:
  'raw' float32 <ch1, ch2, ... chM> for M ADC channels
Outputs (noncontact):
  'iv' float32 <V, I_A, I_B, I_C>  eg for 3 phase
Outputs (contact):
  'iv' float32 <V_A, I_A, V_B, I_B> eg for 2 phase

Arguments
-------------
config:
  this should be a JSON encoded dict with the
  following properties (note that some fields depend on meter type):

  {
          noncontact: bool (selects the meter type below)
          m_indices: array
          e_indices: array
          max_gap:   seconds
          current_matrix: matrix
          nominal_frequency: Hz
          sampling_frequency: Hz
          ---noncontact only--
          integrate: bool
          voltage_matrix: array (first element is the voltage scale)
          ---contact only--
          voltage_matrix: matrix
  }
"""


class Reconstructor(FilterModule):
    """
      raw-->[reconstructor]-->iv

    Subtracts a per-channel DC offset (estimated once, from the first
    non-empty chunk) from the raw ADC samples, then applies the configured
    conversion matrices to recover current and voltage.
    """

    def __init__(self):
        super().__init__("Reconstructor")
        self.description = "convert ADC counts to I,V"
        self.help = '''\
          apply conversion matrix to 
          ADC samples to recover current and voltage
        '''
        self.arg_description = ARGS_DESC

    def custom_args(self, parser):
        """Register the positional JSON ``config`` argument."""
        parser.add_argument("config", help="JSON encoded dict")

    def _parse_configs(self, config):
        """Load operating parameters from the decoded *config* dict.

        Raises:
            KeyError: if a required field for the selected meter type is
                missing.
        """
        self.noncontact = config["noncontact"]
        self.m_indices = config["m_indices"]
        self.e_indices = config["e_indices"]
        self.max_gap = config["max_gap"]
        self.current_matrix = np.array(config["current_matrix"])
        if self.noncontact:
            self.integrate = config["integrate"]
            # the first element of voltage_matrix is the noncontact scale
            self.voltage_scale = config["voltage_matrix"][0]
        else:
            self.voltage_matrix = np.array(config["voltage_matrix"])
            self.integrate = False
        f = config["nominal_frequency"]
        fs = config["sampling_frequency"]
        # kernel length is derived from samples per nominal line cycle
        self.filter = self._setup_integration(fs / f)

    async def run(self, parsed_args, inputs, outputs, nrows=0):
        """Stream ``raw`` -> ``iv`` until at least *nrows* rows are written.

        Runs indefinitely when *nrows* is 0 (the default). Exits with
        status 1 if the pipes are missing or the config is invalid.

        Raises:
            FilterError: if an input chunk contains a timestamp gap larger
                than the configured ``max_gap``.
        """
        # Set up operating variables: pipes
        try:
            pipe_in = inputs['raw']
            pipe_out = outputs['iv']
        except KeyError as e:
            print("ERROR: required inputs [raw], outputs ['iv']")
            print("missing %s" % e)
            sys.exit(1)
        # Set up operating variables: configs
        try:
            config = json.loads(parsed_args.config)
        except ValueError as e:  # json.JSONDecodeError subclasses ValueError
            print("ERROR: config is not valid JSON: %s" % e)
            sys.exit(1)
        try:
            self._parse_configs(config)
        except KeyError as e:
            print("ERROR: invalid config string missing [%s]" % e)
            sys.exit(1)
        # Setup is good, now process the data
        total_rows_processed = 0
        dc_offsets = None
        while True:
            await asyncio.sleep(0.1)
            sarray_in = await pipe_in.read()
            if len(sarray_in) == 0:
                # nothing to do; an empty chunk would also make the
                # dc_offsets estimate below NaN for the rest of the stream
                continue
            if dc_offsets is None:
                dc_offsets = np.mean(sarray_in['data'], axis=0)

            sarray_out = self._reconstruct(sarray_in, dc_offsets,
                                           pipe_out.dtype)
            if sarray_out is None:
                # we need more data to integrate
                continue

            await pipe_out.write(sarray_out)
            rows_processed = len(sarray_out)
            total_rows_processed += rows_processed
            # When integrating, the last len(self.filter) - 1 rows are left
            # unconsumed so the next read overlaps the convolution window.
            pipe_in.consume(rows_processed)
            if nrows > 0 and total_rows_processed >= nrows:
                break

    def _reconstruct(self, sarray_in, dc_offsets, dtype):
        """Convert one chunk of raw samples into an ``iv`` structured array.

        Returns None when integrating and the chunk is not longer than
        len(self.filter) - 1 rows, so no output can be produced yet.

        Raises:
            FilterError: if the chunk's timestamps contain a gap.
        """
        n_in = len(sarray_in)
        # 'valid' convolution drops len(self.filter) - 1 samples in total
        size = len(self.filter) - 1 if self.integrate else 0
        output_len = n_in - size
        if output_len <= 0:
            return None

        data_in = sarray_in['data'].astype('float')
        self._verify_no_gaps(sarray_in['timestamp'])
        data_in -= dc_offsets
        m_data = data_in[:, self.m_indices]
        e_data = data_in[:, self.e_indices]
        if e_data.ndim == 2 and e_data.shape[1] == 1:
            # drop only the channel axis; np.squeeze would also drop the
            # row axis of a one-row chunk
            e_data = e_data[:, 0]
        currents = m_data.dot(self.current_matrix.T)
        voltages = self._process_voltages(e_data)

        # Trim size // 2 rows from each end so currents and timestamps line
        # up with the 'valid' convolution. An explicit stop index (rather
        # than -bound) keeps every row when bound == 0.
        bound = size // 2
        keep = slice(bound, n_in - bound)
        sarray_out = np.zeros(output_len, dtype=dtype)
        sarray_out['data'] = np.c_[voltages, currents[keep]]
        sarray_out['timestamp'] = sarray_in['timestamp'][keep]
        return sarray_out

    def _verify_no_gaps(self, ts):
        """Raise FilterError if consecutive timestamps are too far apart.

        *ts* is divided by 1e6 to convert to seconds before comparing with
        ``self.max_gap``. Fewer than two timestamps cannot contain a gap.
        """
        if len(ts) < 2:
            return
        max_gap = np.max(np.diff(ts)) / 1e6  # convert to seconds
        if max_gap > self.max_gap:
            msg = "%g second gap in input, cannot process" % max_gap
            raise FilterError(msg)

    def _process_voltages(self, e_data):
        """Convert the voltage-sense channel data into voltage.

        noncontact: optionally convolve with ``self.filter`` ('valid' mode,
        which shortens the result by len(self.filter) - 1), then multiply by
        ``self.voltage_scale``.
        contact: matrix-multiply the (rows x channels) data by
        ``self.voltage_matrix``; 1-D input is treated as a single channel.
        """
        if self.noncontact:
            voltage = e_data
            if self.integrate:
                voltage = np.convolve(self.filter, e_data, 'valid')
            return voltage * self.voltage_scale
        # contact sensor
        if np.ndim(e_data) < 2:
            # reshape returns a new view; the caller's array is untouched
            e_data = np.reshape(e_data, (-1, 1))
        return np.dot(e_data, self.voltage_matrix)

    def _setup_integration(self, cycle_length=50):
        """Build the FIR kernel convolved with the noncontact voltage.

        With N = cycle_length // 2 the kernel has 2*N - 1 taps:
            h[t] = sum_{n=1}^{N-1} sin(pi*n*t/N) / (n*N),  t = -(N-1)..N-1
        *cycle_length* is samples per nominal line cycle and may be a float
        (e.g. fs / f); it is truncated to an integer first. The default of
        50 corresponds to 3kHz sampling on a 60Hz line.
        """
        N = int(cycle_length) // 2  # assumes cycle_length is even
        t = np.arange(1 - N, N)  # -(N-1) to +(N-1) inclusive
        n = np.arange(1, N)  # 1 to N-1 inclusive
        return np.sum(np.sin(np.pi * n * t[:, None] / N) / n / N, 1)



def main():
    """Construct the Reconstructor filter and start it."""
    reconstructor = Reconstructor()  # avoid shadowing builtin `filter`
    reconstructor.start()


if __name__ == "__main__":
    main()
