import pdb
import unittest
import asyncio
import argparse
import numpy as np
from numpy.testing import assert_array_almost_equal

from joule import LocalPipe
from nilm.filters import prep


class TestPrep(unittest.TestCase):
    """Tests for ``prep.Prep``.

    Each test synthesizes per-phase current waveforms with known harmonic
    coefficients and feeds them, together with one zero-crossing row per line
    cycle, through ``Prep.run``. It then checks that every output row
    reproduces those coefficients (rectangular or polar form, merged or one
    output per phase).
    """

    # Seed for the per-test random generator. It keeps the randomly generated
    # harmonic content reproducible when a test fails.
    SEED = 0

    def setUp(self):
        # A fresh loop per test, so no loop state leaks between tests.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        # Private generator instead of the global np.random state. Runs are
        # repeatable, and the random state of other tests is left untouched.
        self.rng = np.random.RandomState(self.SEED)

    def tearDown(self):
        self.loop.close()
        # Don't leave the closed loop installed as the current event loop.
        asyncio.set_event_loop(None)

    def test_computes_separated_prep(self):
        """With merge=False, each phase is written to its own prep-<a|b|c> pipe."""
        PHASES = 3
        NHARM = 4
        width = NHARM * 2

        phase_content = self._random_phase_content(PHASES, NHARM)
        phase_rotations = self.rng.rand(PHASES)
        test_iv_data = self._create_iv_data(phase_content, phase_rotations,
                                            f=50, fs=2000)
        test_zc_data = self._create_zc_data()
        inputs = self._make_input_pipes(test_iv_data, test_zc_data, PHASES)

        names = ["prep-%c" % 'abc'[ph] for ph in range(PHASES)]
        outputs = {name: LocalPipe("float32_%d" % width, name=name)
                   for name in names}

        args = self._make_args(phase_rotations, NHARM,
                               merge=False, polar=False)
        self._run_prep(args, inputs, outputs)

        # Each phase's output must reproduce exactly that phase's coefficients.
        for name, expected in zip(names, phase_content):
            result = outputs[name].read_nowait()
            self._assert_row_count(result, test_zc_data)
            for row in result:
                assert_array_almost_equal(np.squeeze(row['data']), expected)

    def test_computes_meged_prep(self):
        """With merge=True, all phases are concatenated into one prep pipe.

        (The "meged" typo in the method name is kept so the test ID is
        unchanged.)
        """
        PHASES = 3
        NHARM = 4
        width = NHARM * 2

        phase_content = self._random_phase_content(PHASES, NHARM)
        phase_rotations = self.rng.rand(PHASES)
        test_iv_data = self._create_iv_data(phase_content, phase_rotations,
                                            f=50, fs=2000)
        test_zc_data = self._create_zc_data()
        inputs = self._make_input_pipes(test_iv_data, test_zc_data, PHASES)
        npipe_prep = LocalPipe("float32_%d" % (PHASES * width), name="prep")

        args = self._make_args(phase_rotations, NHARM,
                               merge=True, polar=False)
        self._run_prep(args, inputs, {"prep": npipe_prep})
        result = npipe_prep.read_nowait()

        # Merged output: phase a's coefficients, then b's, then c's.
        expected_content = np.concatenate(phase_content)
        self._assert_row_count(result, test_zc_data)
        for row in result:
            assert_array_almost_equal(np.squeeze(row['data']),
                                      expected_content)

    def test_computes_mag_pf_prep(self):
        """With polar=True, each (p, q) pair becomes (magnitude, scaled angle)."""
        PHASES = 3
        NHARM = 4
        width = NHARM * 2

        phase_content = [[1, 1, 1, -1, 0.01, 3, 0.01, -3],
                         [1, 0, 3, 0, 4, 1, 3, 2],
                         [8, 9, 5, -5, 4, 8, 9, 0]]
        phase_rotations = self.rng.rand(PHASES)
        test_iv_data = self._create_iv_data(phase_content, phase_rotations,
                                            f=50, fs=2000)
        test_zc_data = self._create_zc_data()
        inputs = self._make_input_pipes(test_iv_data, test_zc_data, PHASES)
        npipe_prep = LocalPipe(layout="float32_%d" % (PHASES * width),
                               name="prep")

        args = self._make_args(phase_rotations, NHARM,
                               merge=True, polar=True)
        self._run_prep(args, inputs, {"prep": npipe_prep})
        result = npipe_prep.read_nowait()

        # Polar form of each (p, q) pair: the magnitude of p + jq, and the
        # angle of p + jq scaled by -2/pi.
        expected_content = np.empty((PHASES, width))
        for i, pq in enumerate(phase_content):
            for k in range(NHARM):
                p, q = pq[2 * k], pq[2 * k + 1]
                expected_content[i, 2 * k] = np.hypot(p, q)
                expected_content[i, 2 * k + 1] = \
                    np.angle(p + q * 1j) * -2 / np.pi

        self._assert_row_count(result, test_zc_data)
        for row in result:
            # Use reshape, not an in-place `.shape =`, so result rows are not
            # mutated.
            actual_content = np.reshape(row['data'], (PHASES, width))
            assert_array_almost_equal(actual_content, expected_content)

    # ---- helpers -----------------------------------------------------------

    def _random_phase_content(self, phases, nharm):
        """Return `phases` random integer arrays of 2*nharm values in [-10, 10)."""
        return [self.rng.randint(-10, 10, nharm * 2) for _ in range(phases)]

    @staticmethod
    def _make_args(phase_rotations, nharm, merge, polar):
        """Build the argument namespace passed to ``Prep.run``.

        All tests share these settings except ``merge`` and ``polar``.
        samp_freq and line_freq must match the ``fs`` and ``f`` used to
        synthesize the data (2000 and 50 in every test here).
        """
        phases = len(phase_rotations)
        return argparse.Namespace(
            # [2,3,4] for 3 phase
            current_indices=list(range(2, 2 + phases)),
            rotations=np.asarray(phase_rotations).tolist(),
            nshift=1,
            nharm=nharm,
            scale_factor=1.0,
            samp_freq=2000,
            line_freq=50,
            merge=merge,
            polar=polar,
            goertzel=False,
        )

    @staticmethod
    def _make_input_pipes(iv_data, zc_data, phases):
        """Create the filled, closed input pipes.

        Returns them keyed the way ``Prep.run`` expects.
        """
        npipe_iv = LocalPipe("float32_%d" % (phases + 1), name="iv")
        npipe_zc = LocalPipe("float32_3", name="zc")
        npipe_iv.write_nowait(iv_data)
        npipe_zc.write_nowait(zc_data)
        # Closed up front so the filter sees the end of its input.
        npipe_iv.close_nowait()
        npipe_zc.close_nowait()
        return {"iv": npipe_iv, "zero_crossings": npipe_zc}

    def _run_prep(self, args, inputs, outputs):
        """Run ``Prep.run`` to completion on this test's event loop."""
        self.loop.run_until_complete(prep.Prep().run(args, inputs, outputs))

    def _assert_row_count(self, result, zc_data):
        """Check that prep produced a row for every zero crossing.

        Missing one row is tolerated, presumably because the data ends before
        the last zero crossing's cycle is complete.
        """
        self.assertGreaterEqual(len(result), len(zc_data) - 1)

    def _create_iv_data(self,
                        phase_content,
                        phase_rotations,
                        f=50, t=2, fs=2000):
        """Synthesize voltage and per-phase current samples.

        Args:
            phase_content: One coefficient sequence per phase. Entry ``h``
                scales odd harmonic ``n = 2 * (h // 2) + 1``. Even ``h`` is a
                sine term; odd ``h`` is a sine shifted by -pi/2 (i.e. -cos).
            phase_rotations: One value per phase. Harmonic ``n`` of that
                phase is shifted by ``-rot * n`` radians.
            f: Line frequency (the voltage is ``sin(2*pi*f*time)``).
            t: Duration; samples cover ``[0, t)``.
            fs: Sample rate; samples are spaced ``1 / fs`` apart.

        Returns:
            A 2-D array with columns ``[time * 1e6, voltage, current_0, ...]``.
        """
        time = np.arange(0, t, step=1 / fs)
        v = np.sin(2 * np.pi * time * f)

        def build_waveform(harm_content, rot):
            wave = np.zeros(np.shape(v))
            for h, amplitude in enumerate(harm_content):
                n = 2 * (h // 2) + 1  # odd harmonics: 1, 1, 3, 3, 5, 5, ...
                offset = 0 if h % 2 == 0 else -np.pi / 2
                wave += amplitude * np.sin(
                    2 * np.pi * time * f * n + offset - rot * n)
            return wave

        currents = [build_waveform(content, phase_rotations[ph])
                    for ph, content in enumerate(phase_content)]
        return np.column_stack([time * 1e6, v] + currents)

    def _create_zc_data(self, f=50, a=1.0, c=0, t=2, fs=2000):
        """Build one zero-crossing row per line cycle.

        Each row is ``[timestamp, f, a, c]``. The timestamps are spaced
        ``1 / f`` apart over ``[0, t)`` and scaled by 1e6.

        ``fs`` is accepted for symmetry with ``_create_iv_data`` but is
        unused.
        """
        ts = np.arange(0, t, step=1 / f) * 1e6
        n = len(ts)
        return np.column_stack([ts,
                                np.full(n, f, dtype=float),
                                np.full(n, a, dtype=float),
                                np.full(n, c, dtype=float)])
