
from joule import LocalPipe
import unittest
import asyncio
import argparse
import numpy as np
import scipy.integrate
from tests import helpers
from nilm.filters import reconstructor

#  ---testing---
# import matplotlib.pyplot as plt


class TestReconstructor(unittest.TestCase):
    """Tests for ``reconstructor.Reconstructor`` driven through LocalPipes.

    Each test writes synthetic rows into a "raw" LocalPipe, runs the filter
    to completion on this test's event loop, reads the "iv" output pipe and
    compares it against values recomputed here with NumPy/SciPy.
    """

    def setUp(self):
        # Give every test a fresh loop and make it the current loop, so code
        # that looks up the current loop (e.g. asyncio.ensure_future) uses it.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        self.loop.close()

    @staticmethod
    def _cumulative_trapezoid(y, x):
        """Return the cumulative trapezoidal integral of *y* sampled at *x*.

        ``scipy.integrate.cumtrapz`` was deprecated in SciPy 1.12 and removed
        in 1.14; ``cumulative_trapezoid`` (SciPy >= 1.6) replaces it with the
        same ``(y, x)`` semantics (output is one element shorter than *y*).
        Fall back to the old name only on SciPy versions predating the new one.
        """
        integrate = getattr(scipy.integrate, "cumulative_trapezoid", None)
        if integrate is None:
            integrate = scipy.integrate.cumtrapz
        return integrate(y, x)

    def test_integrating_noncontact_reconstruction(self):
        """Integrating mode: currents are M·C^T, voltage is an integrated sine.

        The input is written in two halves with a pause in between so the
        filter has to handle data that arrives in separate blocks.
        """
        INPUT_WIDTH = 8
        INPUT_FORMAT = "int16"
        PHASES = 3
        OUTPUT_WIDTH = PHASES + 1
        OUTPUT_FORMAT = "float32"
        E_INDICES = [4]
        VOLT_MATRIX = [1.8]
        M_INDICES = [1, 2, 3, 5, 7]
        CURRENT_MATRIX = np.random.randint(-10, 10, (PHASES, len(M_INDICES)))

        input_layout = "%s_%s" % (INPUT_FORMAT, INPUT_WIDTH)
        output_layout = "%s_%s" % (OUTPUT_FORMAT, OUTPUT_WIDTH)
        test_data = helpers.create_data(
            input_layout, length=1000, step=1e6 / 3000)
        test_data['data'][:, 1] += 30  # add an offset
        test_data['data'][:, 2] += 45  # add an offset
        test_data['data'][:, 3] += 60  # add an offset

        # replace the voltage column with a 60 Hz sine to test the integrator
        test_data['data'][:, E_INDICES] = self.create_sine(
            f=60, ts=test_data['timestamp'] / 1e6)[:, None]
        npipe_raw = LocalPipe(layout=input_layout, name="raw")
        npipe_iv = LocalPipe(layout=output_layout, name="iv")

        args = argparse.Namespace(**{
            'm_indices':          M_INDICES,
            'e_indices':          E_INDICES,
            'max_gap':            5,
            'current_matrix':     CURRENT_MATRIX.tolist(),
            'integrate':          True,
            'voltage_matrix':     VOLT_MATRIX
        })
        my_filter = reconstructor.Reconstructor()

        # test buffering with multiple write blocks
        async def write():
            await npipe_raw.write(test_data[:len(test_data) // 2])
            await asyncio.sleep(0.1)
            await npipe_raw.write(test_data[len(test_data) // 2:])

        # NOTE(review): nrows presumably bounds how many input rows the
        # filter consumes before run() returns — confirm in Reconstructor.run
        tasks = [asyncio.ensure_future(write()),
                 asyncio.ensure_future(
                     my_filter.run(args,
                                   {"raw": npipe_raw},
                                   {"iv": npipe_iv},
                                   nrows=len(test_data) - 50))]
        self.loop.run_until_complete(asyncio.gather(*tasks))
        result = npipe_iv.read_nowait()

        # compute expected value
        expected = np.zeros(len(test_data), dtype=npipe_iv.dtype)
        expected['timestamp'] = test_data['timestamp']
        # Align the expected output with the actual output: the first output
        # row need not correspond to the first input row. Fail loudly instead
        # of silently comparing against the wrong window if there is no match.
        matches = np.flatnonzero(
            expected['timestamp'] == result['timestamp'][0])
        self.assertTrue(
            len(matches) > 0,
            "first output timestamp does not match any input timestamp")
        start = int(matches[0])
        end = start + len(result)

        # compute expected currents
        expected['data'][:, 1:] = test_data['data'][
            :, M_INDICES].dot(CURRENT_MATRIX.T)

        # compute expected voltage (cumulative integral yields len-1 samples,
        # so row 0 stays zero)
        # NOTE(review): 375 looks like the filter's integrator gain — confirm
        # against the Reconstructor implementation
        expected['data'][1:, 0] = self._cumulative_trapezoid(
            np.squeeze(test_data['data'][:, E_INDICES]),
            test_data['timestamp'] / 1e6) * VOLT_MATRIX[0] * 375
        # remove mean due to trapz integration
        offset = np.mean(expected['data'][:-1, :], axis=0)
        expected['data'][:, :] -= offset

        # --visualize integration---
        """
        plt.plot(test_data['timestamp'],
                 test_data['data'][:, E_INDICES], label="input")
        plt.plot(expected['timestamp'],
                 expected['data'][:, 0], label="expected")
        plt.plot(result['timestamp'],
                 result['data'][:, 0], label="result")
        plt.legend()
        plt.show()
        """
        # --visualize reconstruction--
        """
        plt.plot(test_data['timestamp'],
                 test_data['data'][:, M_INDICES[0]],
                 label="inputA")
        plt.plot(result['timestamp'],
                 result['data'][:, 1],
                 label="resultA")
        plt.plot(test_data['timestamp'],
                 test_data['data'][:, M_INDICES[1]],
                 label="inputB")
        plt.plot(result['timestamp'],
                 result['data'][:, 2],
                 label="resultB")
        plt.plot(test_data['timestamp'],
                 test_data['data'][:, M_INDICES[2]],
                 label="inputC")
        plt.plot(result['timestamp'],
                 result['data'][:, 3],
                 label="resultC")
        plt.ylim(-200, 200)
        plt.legend()
        plt.show()
        """
        # verify correct timestamps
        np.testing.assert_array_almost_equal(
            result['timestamp'],
            expected['timestamp'][start:end])
        # verify correct currents
        np.testing.assert_array_almost_equal(
            result['data'][:, 1:],
            expected['data'][start:end, 1:],
            decimal=1)
        # verify correct voltages (scaled down so decimal=1 is a loose
        # relative tolerance on the large integrated values)
        np.testing.assert_array_almost_equal(
            result['data'][:, 0] / 100,
            expected['data'][start:end, 0] / 100,
            decimal=1)

    def create_sine(self, f, ts):
        """Return ``1000 * sin(2*pi*f*ts)`` evaluated elementwise over *ts*.

        Callers pass ``timestamp / 1e6`` for *ts*, so *f* is in cycles per
        unit of that scaled time.
        """
        v = 1000 * np.sin(2 * np.pi * ts * f)
        return v

    def test_plain_noncontact_reconstruction(self):
        """Non-integrating mode: currents are M·C^T, voltage is E * scale,
        both with the per-column mean removed."""
        INPUT_WIDTH = 8
        INPUT_FORMAT = "int16"
        PHASES = 3
        OUTPUT_WIDTH = PHASES + 1
        OUTPUT_FORMAT = "float32"
        E_INDICES = [4]
        VOLT_SCALE = 5.0
        M_INDICES = [1, 2, 3, 5, 7]
        CURRENT_MATRIX = np.random.randint(-10, 10, (PHASES, len(M_INDICES)))

        input_layout = "%s_%s" % (INPUT_FORMAT, INPUT_WIDTH)
        output_layout = "%s_%s" % (OUTPUT_FORMAT, OUTPUT_WIDTH)
        test_data = helpers.create_data(input_layout, length=500)
        # NOTE(review): column 0 is in neither M_INDICES nor E_INDICES, so
        # this sine never reaches the filter's outputs or the expected values
        # below (and integrate is False here, so no integrator is involved).
        # TODO: decide whether it was meant for E_INDICES, as in the
        # integrating test; kept unchanged to preserve current tolerances.
        test_data['data'][:, 0] = self.create_sine(
            f=60, ts=test_data['timestamp'] / 1e6)
        npipe_raw = LocalPipe(layout=input_layout, name="raw")
        npipe_iv = LocalPipe(layout=output_layout, name="iv")

        args = argparse.Namespace(**{
            'm_indices':          M_INDICES,
            'e_indices':          E_INDICES,
            'max_gap':            5,
            'current_matrix':     CURRENT_MATRIX.tolist(),
            'integrate':          False,
            'voltage_matrix':     [VOLT_SCALE]
        })
        my_filter = reconstructor.Reconstructor()

        npipe_raw.write_nowait(test_data)
        self.loop.run_until_complete(my_filter.run(args,
                                                   {'raw': npipe_raw},
                                                   {'iv': npipe_iv},
                                                   nrows=len(test_data) - 50))
        result = npipe_iv.read_nowait()

        # compute expected value
        expected = np.zeros(len(test_data), dtype=npipe_iv.dtype)
        expected['timestamp'] = test_data['timestamp']

        # Remove DC offsets from test_data. Take the mean in float64 and then
        # truncate to int16: np.mean(..., dtype="int16") would also
        # accumulate in int16 and silently wrap once a column sum exceeds
        # 32767.
        test_data['data'] -= np.mean(
            test_data['data'], axis=0).astype("int16")

        # compute expected currents
        expected['data'][:, 1:] = test_data['data'][
            :, M_INDICES].dot(CURRENT_MATRIX.T)
        # compute expected voltage
        expected['data'][:, 0] = np.squeeze(
            test_data['data'][:, E_INDICES] * VOLT_SCALE)
        expected['data'] -= np.mean(expected['data'], axis=0)
        # verify correct timestamps
        np.testing.assert_array_almost_equal(
            result['timestamp'], expected['timestamp'])

        # verify correct currents and voltage
        np.testing.assert_array_almost_equal(result['data'], expected['data'],
                                             decimal=3)

    def test_contact_reconstruction(self):
        """Contact mode: one voltage and one current output per phase, each a
        matrix product of the corresponding input columns, mean removed."""
        INPUT_WIDTH = 6
        INPUT_FORMAT = "uint16"
        PHASES = 3
        OUTPUT_WIDTH = PHASES * 2
        OUTPUT_FORMAT = "float32"
        E_INDICES = [3, 4, 5][:PHASES]
        M_INDICES = [0, 1, 2][:PHASES]

        input_layout = "%s_%s" % (INPUT_FORMAT, INPUT_WIDTH)
        output_layout = "%s_%s" % (OUTPUT_FORMAT, OUTPUT_WIDTH)
        test_data = helpers.create_data(input_layout, length=500)
        npipe_raw = LocalPipe(layout=input_layout, name="raw")
        npipe_iv = LocalPipe(layout=output_layout, name="iv")

        CURRENT_MATRIX = np.eye(PHASES) * 2
        VOLTAGE_MATRIX = np.eye(PHASES) * 3
        """
        CURRENT_MATRIX = np.array([[2.0, 0.0, 0.0],
                                   [0.0, 3.0, 0.0],
                                   [0.0, 0.0, 4.0]])

        VOLTAGE_MATRIX = np.array([[0.0, 2.0, 0.0],
                                   [4.0, 0.0, 0.0],
                                   [0.0, 0.0, 3.0]])
        """

        args = argparse.Namespace(**{
            'm_indices':          M_INDICES,
            'e_indices':          E_INDICES,
            'max_gap':            5,
            'current_matrix':     CURRENT_MATRIX.tolist(),
            'integrate':          False,
            'voltage_matrix':     VOLTAGE_MATRIX.tolist(),
            'voltage_scale':      1.0
        })
        my_filter = reconstructor.Reconstructor()
        npipe_raw.write_nowait(test_data)
        self.loop.run_until_complete(my_filter.run(args,
                                                   {'raw': npipe_raw},
                                                   {'iv': npipe_iv},
                                                   nrows=len(test_data) - 50))
        result = npipe_iv.read_nowait()

        # compute expected value: voltages in the first PHASES columns,
        # currents in the remaining ones
        x = np.empty(len(test_data), dtype=npipe_iv.dtype)
        expected = x['data']
        datain = test_data['data']
        # compute VOLTAGE
        vdata = np.dot(datain[:, E_INDICES], VOLTAGE_MATRIX)
        if vdata.ndim == 1:
            vdata = vdata.reshape(-1, 1)
        expected[:, :PHASES] = vdata
        # compute CURRENT
        idata = np.dot(datain[:, M_INDICES], CURRENT_MATRIX)
        if idata.ndim == 1:
            idata = idata.reshape(-1, 1)
        expected[:, PHASES:] = idata
        expected = expected - np.mean(expected, axis=0)
        np.testing.assert_array_almost_equal(result['data'], expected,
                                             decimal=2)
