#!/usr/bin/python

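"""
Detect the calibration load and compute per-phase current-sensor
calibration coefficients. Also verify that the calibration load has been
moved to a different phase (wire) than the phases already calibrated.
"""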
from numpy.fft import fft
import numpy as np
from . import read_meter

# Calibrator presence: if the norm of the projected sensor coefficients is
# below this threshold, the calibration load is reported as not detected.
MIN_CAL_SIGNAL = 5000.0

# Same-phase check: minimum distance required between the unit coefficient
# vector of the phase being calibrated and that of every earlier phase.
MIN_CAL_DIFFERENCE = 0.1

# Compensates for FFT scaling so the resulting coefficients are in watts.
CAL_SCALE = 246

# Directory where raw calibration captures are saved (and replayed from).
CALIB_DATA = "/opt/configs/cal_data"

# *******NOTE: cur_phase is 0 indexed**********
def run(cur_phase, meter, replay=False, debug=False):
    """
    Calibrate one phase (or line-to-line pair) of a meter.

    Parameters:
       cur_phase           [0,1,2]: the current phase under calibration,
                           0-indexed (only [0,1] when the meter has no neutral)
       meter               meter dictionary object; results are written into
                           meter['calibration']["sensor_matrix"],
                           ["sinefit_rotations"] and ["pq_coeffs"]
       replay=False        re-run the last calibration from the CSV saved
                           under CALIB_DATA instead of reading the meter
       debug=False         display plots showing calibration process

    Returns a (code, msg) tuple. Codes:
       0  Success
       1  Hardware/argument error (returned here for a negative cur_phase)
       2  No cal load detected
       3  Redundant phase (load is on a previously calibrated phase)

    Raises SystemExit when a meter without a neutral is asked for a phase
    other than 0 or 1, when the meter returns no data, or when the debug
    plots cannot be set up.
    """
    if cur_phase < 0:
        print("Error, phase must be >= 0")
        return (1, "Argument error")

    calibration_config = meter['calibration']
    meter_name = meter['name']
    num_sensors = len(meter['sensors']['current']['sensor_indices'])

    # Run for at least 9 seconds, rounded up to a whole number of 3 s
    # calibrator periods.
    duration = max(calibration_config['duration'], 9)
    if duration % 3 != 0:
        duration += 3 - duration % 3

    if not calibration_config['has_neutral']:
        if cur_phase not in (0, 1):
            print("ERROR: can't calibrate this phase")
            raise SystemExit(1)
        print("\t Calibrating [%s] LINE TO LINE pair %d"
              % (meter_name, cur_phase + 1))
    else:
        print("\t Calibrating [%s] phase %d" % (meter_name, cur_phase + 1))
    print("\t Duration: %d seconds" % duration)

    prep_buffer = _load_samples(meter, cur_phase, duration, replay)

    # Columns are interleaved per sensor: 2*i holds P and 2*i+1 holds Q.
    pq_coeffs = np.array([
        [detect_calibrator(prep_buffer[:, 2 * i]),
         detect_calibrator(prep_buffer[:, 2 * i + 1])]
        for i in range(num_sensors)
    ])

    plt = None
    if debug:
        plt = _import_pyplot()
        _plot_raw_and_vectors(plt, prep_buffer, pq_coeffs, cur_phase)

    # now we have p and q coeffs for each sensor
    # 1.) find the largest magnitude sensor
    # 2.) calculate its phase rotation ph = atan(Q/P)
    # 3.) find unit vector in direction of max vector
    # 4.) final_coeff: u*v where u is sensor and v is unit vector
    # 5.) scale coeffs to amps using settings in meters.yml
    magnitudes = np.hypot(pq_coeffs[:, 0], pq_coeffs[:, 1])
    ref = np.argmax(magnitudes)                              # (1)
    ph = np.arctan2(pq_coeffs[ref, 1], pq_coeffs[ref, 0])    # (2)
    if magnitudes[ref] > 0:
        unit = pq_coeffs[ref] / magnitudes[ref]              # (3)
        final_coeffs = pq_coeffs.dot(unit)                   # (4)
    else:
        # Every sensor read exactly zero: avoid 0/0 -> NaN coefficients.
        final_coeffs = np.zeros(num_sensors)

    cal_watts = calibration_config["watts"]                  # (5)
    rms_voltage = float(meter["sensors"]["voltage"]["nominal_rms_voltage"])
    scale_factor = cal_watts * np.sqrt(2) / rms_voltage
    scaled_coeffs = final_coeffs / CAL_SCALE / scale_factor

    # see if we found the calibrator; written as "not >=" so that a NaN norm
    # (e.g. NaNs in replayed data) is also treated as "not detected"
    if not np.linalg.norm(final_coeffs) >= MIN_CAL_SIGNAL:
        print("[ERROR] can't detect calibration load")
        return (2, "can't detect calibration load")

    # see if the calibrator is on a different phase than earlier phases
    if cur_phase > 0:
        u = scaled_coeffs / np.linalg.norm(scaled_coeffs)
        for phase in range(cur_phase):
            prev_coeffs = calibration_config["sensor_matrix"][:, phase]
            v = prev_coeffs / np.linalg.norm(prev_coeffs)
            if np.linalg.norm(v - u) < MIN_CAL_DIFFERENCE:
                print("[ERROR] calibration on same phase as before")
                # name the phase being calibrated (B for cur_phase 1, C for 2)
                return (3, "move plug to phase %s" % chr(ord('A') + cur_phase))

    calibration_config["sensor_matrix"][:, cur_phase] = scaled_coeffs
    calibration_config["sinefit_rotations"][cur_phase] = ph
    # NOTE(review): the raw (un-normalized) P/Q coefficients are stored, as
    # in the non-debug path; debug plotting no longer changes this value.
    calibration_config["pq_coeffs"][cur_phase, :, :] = pq_coeffs

    if debug:
        _plot_final_coeffs(plt, scaled_coeffs, cur_phase)
        input("Press [enter] to continue...")
    return (0, "OK")


def _load_samples(meter, cur_phase, duration, replay):
    """Return the raw capture for this phase: read the meter and archive it
    as CSV under CALIB_DATA, or, when replaying, load that archived CSV."""
    csv_path = CALIB_DATA + "/%s_ph%d.csv" % (meter['name'], cur_phase + 1)
    if replay:
        # np.loadtxt opens and closes the file itself
        return np.loadtxt(csv_path, delimiter=",")
    samples = read_meter.run(meter, duration)
    if len(samples) == 0:
        print("Meter is not available, make sure it is connected or wait and try again")
        raise SystemExit(0)
    # save the data for debugging / replay
    np.savetxt(csv_path, samples, delimiter=",")
    return samples


def _import_pyplot():
    """Select the TkAgg backend and return pyplot; exit if that fails."""
    try:
        import matplotlib
        # the backend must be chosen before pyplot is first imported
        matplotlib.use('tkagg')
        from matplotlib import pyplot as plt
    except Exception:
        print("Error setting up visualization, do you have an X11 server?")
        raise SystemExit(1)
    return plt


def _plot_raw_and_vectors(plt, prep_buffer, pq_coeffs, cur_phase):
    """Debug: plot each sensor's mean-removed P/Q traces on a shared y-range,
    then every sensor's (P, Q) vector scaled to fit the unit circle."""
    raw_figs = []
    global_max = -np.inf
    global_min = np.inf
    for i in range(len(pq_coeffs)):
        raw_figs.append(plt.figure())
        # build new arrays so prep_buffer itself is left untouched
        p_col = prep_buffer[:, 2 * i]
        q_col = prep_buffer[:, 2 * i + 1]
        p_data = p_col - np.mean(p_col)
        q_data = q_col - np.mean(q_col)
        global_max = np.max((global_max, np.max(p_data), np.max(q_data)))
        global_min = np.min((global_min, np.min(p_data), np.min(q_data)))
        # rough check on sign of coeff (only good for clean calib data)
        plt.plot(p_data, label="P ++" if np.median(p_data) < 0 else "P --")
        plt.plot(q_data, label="Q ++" if np.median(q_data) < 0 else "Q --")
        plt.title('Sensor %d' % i)
        plt.legend()
    for f in raw_figs:
        f.gca().set_ylim(global_min - 0.1 * np.abs(global_min),
                         global_max + 0.1 * np.abs(global_max))

    # ---plot normalized sensors on unit circle---
    fig = plt.figure(facecolor='white')
    fig.add_subplot(111, aspect='equal')
    plt.plot([0, 1], [0, 0], 'k--', label="E-sensor")
    # scale a copy by the largest vector so the stored coefficients are
    # unaffected; guard against an all-zero set of sensors
    max_sensor = np.max(np.hypot(pq_coeffs[:, 0], pq_coeffs[:, 1]))
    unit_coeffs = pq_coeffs / max_sensor if max_sensor > 0 else pq_coeffs.copy()
    for i, (p, q) in enumerate(unit_coeffs):
        plt.plot([0, p], [0, q], label="Sensor %d" % i)

    plt.gcf().gca().add_artist(plt.Circle((0, 0), 1, fill=False))
    plt.ylim([-1, 1])
    plt.xlim([-1, 1])
    plt.legend()
    plt.title('Normalized Sensor vectors for Phase %d' % (cur_phase + 1))
    plt.axis("off")

    plt.ion()
    plt.show()


def _plot_final_coeffs(plt, sensor_coeffs, cur_phase):
    """Debug: plot the scaled per-sensor coefficients on a number line,
    highlighting the reference (largest) sensor."""
    plt.figure(facecolor="white")
    # make the number line 10% greater than the largest val
    max_coeff = np.max(np.abs(sensor_coeffs))
    arrowprops = dict(arrowstyle='simple', facecolor="grey")

    plt.annotate("", xy=(-1.1 * max_coeff, 0), xytext=(0.3, 0),
                 arrowprops=arrowprops)
    plt.annotate("", xy=(1.1 * max_coeff, 0), xytext=(-0.3, 0),
                 arrowprops=arrowprops)

    plt.hlines(0, -1.1 * max_coeff, max_coeff * 1.1)
    for i, x in enumerate(sensor_coeffs):
        if x == max_coeff:
            plt.vlines(x, -2, 2, linewidth=4, color='blue',
                       label="Reference Sensor")
        else:
            plt.vlines(x, -2, 2)
        plt.text(x - 1, 5, "S%d" % i)
        plt.text(x - 1, -5, "%0.2f" % x, rotation=-45)

    plt.vlines(0, -3, 3, color='green', label="Zero")
    plt.vlines(0, 0, 0, label="Other Sensors")
    plt.xlim(-1.2 * max_coeff, max_coeff * 1.2)
    plt.ylim(-40, 40)
    plt.legend()
    plt.axis('off')
    plt.title("Sensor Coeffecients for Phase %d" % (cur_phase + 1))
    plt.show()


def detect_calibrator(x_input, period=180, periods_per_block=3):
    """
    Return the signed amplitude of the calibrator's periodic signal in x_input.

    The calibrator repeats every ``period`` samples (180 by default; with the
    3 s period assumed elsewhere in this module that is presumably a 1/3 Hz
    fundamental at 60 samples/s -- TODO confirm). The input waveform can have
    additional content besides the calibration signal.

    The input is divided into consecutive blocks of
    ``period * periods_per_block`` samples; trailing samples that do not
    fill a whole block are ignored. For each block:

    * magnitude: |DFT bin ``periods_per_block``|, the calibrator fundamental;
    * sign: the block is time-shifted (in the frequency domain) so the
      fundamental's phase corresponds to a half-period shift, then the phase
      of the second harmonic (bin ``2 * periods_per_block``) is checked.
      Near 0 the term is kept positive; near pi (|phase| > pi/2) it is
      negated. For a pulse that is "on" for less than half the period this
      yields the sign of the pulse amplitude.

    The signed magnitudes are averaged over blocks and multiplied by
    pi/4 * 2. The result is in unnormalized-FFT units; callers rescale it
    (see CAL_SCALE).

    Raises ValueError if x_input holds fewer samples than one block.
    """
    x = np.asarray(x_input)
    block_size = period * periods_per_block
    fund_bin = periods_per_block        # depends on nfft (block_size)
    harm_bin = 2 * periods_per_block
    n_blocks = len(x) // block_size
    if n_blocks == 0:
        raise ValueError(
            "need at least %d samples to detect the calibrator, got %d"
            % (block_size, len(x)))

    # One FFT per block, computed together along axis 1.
    blocks = x[:n_blocks * block_size].reshape(n_blocks, block_size)
    spectra = fft(blocks, axis=1)
    xk_fund = spectra[:, fund_bin]
    xk_harm = spectra[:, harm_bin]

    # Time shift (in samples) that moves the fundamental's phase to a
    # half-period offset; apply it to the second harmonic.
    shift = period / 2 - period * np.angle(xk_fund) / (2 * np.pi)
    harm_phase = np.angle(
        xk_harm * np.exp(1j * 2 * np.pi * harm_bin * shift / block_size))

    magnitude = np.abs(xk_fund)
    # second harmonic inverted -> treat this block's term as negative
    signed = np.where(np.abs(harm_phase) > np.pi / 2, -magnitude, magnitude)
    return np.mean(signed) * np.pi / 4 * 2
