#!/usr/bin/python

"""
Detect the calibration load on each phase and derive the per-sensor
calibration coefficients. Also make sure that successive phases are not
measured on the same wire.
"""
from numpy.fft import fft
import numpy as np
from . import read_meter
import pandas as pd

# Detection threshold for the calibration load: if the norm of the projected
# sensor coefficients is below this value (|S| < N) the load is reported as
# not detected.
MIN_CAL_SIGNAL = 10
# Minimum distance between the unit coefficient vector of the current phase
# and that of any previously calibrated phase; closer vectors mean the
# calibration load is on the same phase as before.
MIN_CAL_DIFFERENCE = 0.1
# Directory where raw calibration captures are stored (and re-read on replay).
CALIB_DATA = "/opt/configs/cal_data"


def run(cur_phase, meter, replay=False, debug=False):
    """
    Calibration routine: measure the calibration load on one phase and store
    the resulting per-sensor coefficients in meter['calibration'].

    *******NOTE: cur_phase is 0 indexed**********

    Parameters:
       cur_phase           [0,1,2]: the current phase under calibration
                           (only 0 and 1 -- the two line-to-line pairs --
                           when the meter has no neutral)
       meter               meter dictionary object
       replay=False        re-run last calibration from the capture saved
                           under CALIB_DATA instead of reading the meter
       debug=False         display plots showing calibration process

       returns (code,msg) tuple
         codes: 0  Success
                1  Hardware error (meter unavailable) or invalid
                   argument/configuration
                2  No cal load detected
                3  Redundant phase

    On success these entries of meter['calibration'] are updated:
       "sensor_matrix"[:, cur_phase]    scaled coefficient of each sensor
       "sinefit_rotations"[cur_phase]   rotation of the strongest sensor
       "pq_coeffs"[cur_phase, :, :]     raw (P, Q) coefficients per sensor
    """
    if cur_phase < 0:
        print("Error, phase must be >= 0")
        return (1, "Argument error")

    calibration_config = meter['calibration']
    meter_name = meter['name']
    num_sensors = len(meter['sensors']['current']['sensor_indices'])
    if num_sensors == 0:
        print("Error, no current sensors configured")
        return (1, "Argument error")

    # capture for at least 9 seconds, rounded up to a multiple of 3 s
    duration = max(calibration_config['duration'], 9)
    if duration % 3 != 0:
        duration += 3 - duration % 3

    if not calibration_config['has_neutral']:
        # without a neutral only the two line-to-line pairs can be calibrated
        if cur_phase not in (0, 1):
            print("ERROR: can't calibrate this phase")
            return (1, "Argument error")
        print("\t Calibrating [%s] LINE TO LINE pair %d"
              % (meter_name, cur_phase + 1))
    else:
        print("\t Calibrating [%s] phase %d" % (meter_name, cur_phase + 1))
    print("\t Duration: %d seconds" % duration)

    cal_on = calibration_config['on_duration']    # load ON time in seconds
    cal_off = calibration_config['off_duration']  # load OFF time in seconds
    if cal_on <= 0 or cal_off <= 0:
        # ratio = cal_on / cal_off and the amplitude scaling in
        # detect_calibrator are undefined otherwise
        print("Error, calibration on/off durations must be > 0")
        return (1, "Argument error")

    # Read the meter, or if replay is specified, read in from CALIB_DATA.
    # Columns 2*i and 2*i+1 of the capture are the P and Q data of sensor i.
    csv_path = CALIB_DATA + "/%s_ph%d.csv" % (meter_name, cur_phase + 1)
    if not replay:
        prep_buffer = read_meter.run(meter, duration)
        if len(prep_buffer) == 0:
            print("Meter is not available, make sure it is connected or wait and try again")
            return (1, "Hardware error")
        # save the data for debugging / later replay
        np.savetxt(csv_path, prep_buffer, delimiter=",")
    else:
        # loadtxt opens and closes the file itself
        prep_buffer = np.loadtxt(csv_path, delimiter=",")

    Fs = 60                     # samples per second
    cal_T = cal_on + cal_off    # calibration load period in seconds
    period = Fs * cal_T         # calibration load period in samples
    f_fund = 1 / cal_T          # fundamental frequency
    k = 1 if cal_on <= 1 else int(cal_on)  # harmonic of the load analysed
    d = cal_on / cal_T          # duty cycle
    ratio = cal_on / cal_off
    Bsize = int(period * 3)     # block size (DFT length): three load periods
    # DFT bin of the harmonic of interest; round instead of truncating so
    # float error just below an integer does not pick the neighbouring bin
    index_1 = int(round(k * f_fund * Bsize / Fs))

    sensor_coeffs = []
    for i in range(num_sensors):
        p_coeff = detect_calibrator(prep_buffer[:, 2 * i], period, index_1,
                                    Bsize, d, k, ratio)
        q_coeff = detect_calibrator(prep_buffer[:, 2 * i + 1], period,
                                    index_1, Bsize, d, k, ratio)
        sensor_coeffs.append([p_coeff, q_coeff])
    # raw coefficients, one (P, Q) row per sensor; stored unmodified below
    pq_coeffs = np.array(sensor_coeffs)

    if debug:
        plt = _load_pyplot()
        _plot_raw_and_vectors(plt, prep_buffer, pq_coeffs, cur_phase)

    # 1.) find the largest magnitude sensor (reference sensor)
    magnitudes = np.linalg.norm(pq_coeffs, axis=1)
    ref = int(np.argmax(magnitudes))
    ref_mag = magnitudes[ref]
    # argmax selects a NaN/inf magnitude if one exists, so this also rejects
    # non-finite coefficients; a zero reference would make step 3 divide 0/0
    if not np.isfinite(ref_mag) or ref_mag == 0:
        print("[ERROR] can't detect calibration load")
        return (2, "can't detect calibration load")
    # 2.) calculate its phase rotation ph = atan(Q/P)
    ph = np.arctan2(pq_coeffs[ref, 1], pq_coeffs[ref, 0])
    # 3.) unit vector in the direction of the reference sensor
    unit = pq_coeffs[ref] / ref_mag
    # 4.) final_coeff: projection of every sensor onto that unit vector
    final_coeffs = pq_coeffs.dot(unit)
    # 5.) scale coeffs to amps using settings in meters.yml
    cal_watts = calibration_config["watts"]
    rms_voltage = float(meter["sensors"]["voltage"]["nominal_rms_voltage"])
    scale_factor = cal_watts * np.sqrt(2) / rms_voltage
    sensor_coeffs = final_coeffs / scale_factor

    # see if we found the calibrator
    if np.linalg.norm(final_coeffs) < MIN_CAL_SIGNAL:
        print("[ERROR] can't detect calibration load")
        return (2, "can't detect calibration load")

    # see if the calibrator is on the same phase as a previous calibration
    if cur_phase > 0:
        u = sensor_coeffs / np.linalg.norm(sensor_coeffs)
        for phase in range(cur_phase):
            prev_coeffs = calibration_config["sensor_matrix"][:, phase]
            v = prev_coeffs / np.linalg.norm(prev_coeffs)
            if np.linalg.norm(v - u) < MIN_CAL_DIFFERENCE:
                print("[ERROR] calibration on same phase as before")
                # NOTE(review): hint names phase B for every cur_phase; confirm
                return (3, "move plug to phase B")

    calibration_config["sensor_matrix"][:, cur_phase] = sensor_coeffs
    calibration_config["sinefit_rotations"][cur_phase] = ph
    calibration_config["pq_coeffs"][cur_phase, :, :] = pq_coeffs

    if debug:
        _plot_final_coeffs(plt, sensor_coeffs, cur_phase)
        input("Press [enter] to continue...")
    return (0, "OK")


def _load_pyplot():
    """Select the Tk backend and return pyplot; exit with status 1 on failure."""
    import sys
    try:
        import matplotlib
        # the backend has to be chosen before pyplot is imported
        matplotlib.use('tkagg')
        from matplotlib import pyplot as plt
    except Exception:
        print("Error setting up visualization, do you have an X11 server?")
        sys.exit(1)
    return plt


def _plot_raw_and_vectors(plt, prep_buffer, pq_coeffs, cur_phase):
    """Plot each sensor's mean-removed P/Q capture, then the (P, Q) vectors
    scaled by the largest sensor magnitude on a unit circle."""
    # --plot raw P,Q outputs for each sensor--
    raw_figs = []
    global_max = -np.inf
    global_min = np.inf
    for i in range(len(pq_coeffs)):
        raw_figs.append(plt.figure())
        # build new arrays so the capture buffer is not modified in place
        p_data = prep_buffer[:, 2 * i] - np.mean(prep_buffer[:, 2 * i])
        q_data = prep_buffer[:, 2 * i + 1] - np.mean(prep_buffer[:, 2 * i + 1])
        global_max = np.max((global_max, np.max(p_data), np.max(q_data)))
        global_min = np.min((global_min, np.min(p_data), np.min(q_data)))
        # rough check on sign of coeff (only good for clean calib data)
        plt.plot(p_data, label="P ++" if np.median(p_data) < 0 else "P --")
        plt.plot(q_data, label="Q ++" if np.median(q_data) < 0 else "Q --")
        plt.title('Sensor %d' % i)
        plt.legend()
    for fig in raw_figs:
        fig.gca().set_ylim(global_min - 0.1 * np.abs(global_min),
                           global_max + 0.1 * np.abs(global_max))

    # ---plot normalized sensors on unit circle---
    # normalize a copy so the caller's raw coefficients stay untouched
    max_sensor = np.max(np.linalg.norm(pq_coeffs, axis=1))
    unit_coeffs = pq_coeffs / max_sensor

    fig = plt.figure(facecolor='white')
    fig.add_subplot(111, aspect='equal')
    plt.plot([0, 1], [0, 0], 'k--', label="E-sensor")
    for i, (p, q) in enumerate(unit_coeffs):
        plt.plot([0, p], [0, q], label="Sensor %d" % i)
    fig.gca().add_artist(plt.Circle((0, 0), 1, fill=False))
    plt.ylim([-1, 1])
    plt.xlim([-1, 1])
    plt.legend()
    plt.title('Normalized Sensor vectors for Phase %d' % (cur_phase + 1))
    plt.axis("off")

    plt.ion()
    plt.show()


def _plot_final_coeffs(plt, sensor_coeffs, cur_phase):
    """Plot the final sensor coefficients on a number line; a coefficient
    equal to the largest absolute value is drawn as the reference sensor."""
    plt.figure(facecolor="white")
    # make the number line 10% greater than the largest val
    max_coeff = np.max(np.abs(sensor_coeffs))
    arrowprops = dict(arrowstyle='simple', facecolor="grey")

    plt.annotate("", xy=(-1.1 * max_coeff, 0), xytext=(0.3, 0),
                 arrowprops=arrowprops)
    plt.annotate("", xy=(1.1 * max_coeff, 0), xytext=(-0.3, 0),
                 arrowprops=arrowprops)

    plt.hlines(0, -1.1 * max_coeff, max_coeff * 1.1)
    for i, x in enumerate(sensor_coeffs):
        if x == max_coeff:
            plt.vlines(x, -2, 2, linewidth=4, color='blue',
                       label="Reference Sensor")
        else:
            plt.vlines(x, -2, 2)
        plt.text(x - 1, 5, "S%d" % i)
        plt.text(x - 1, -5, "%0.2f" % x, rotation=-45)

    plt.vlines(0, -3, 3, color='green', label="Zero")
    plt.vlines(0, 0, 0, label="Other Sensors")
    plt.xlim(-1.2 * max_coeff, max_coeff * 1.2)
    plt.ylim(-40, 40)
    plt.legend()
    plt.axis('off')
    plt.title("Sensor Coefficients for Phase %d" % (cur_phase + 1))
    plt.show()

def rollingmedian(data, window_size):
    """
    Return a centred rolling median of ``data`` as a NumPy array.

    Windows that run past either end of the input are shortened rather than
    producing NaN (at least one sample is required per window).
    """
    series = pd.Series(data)
    windows = series.rolling(window=window_size, center=True, min_periods=1)
    smoothed = windows.median()
    return smoothed.values


def detect_calibrator(x_input, PERIOD, XK1_INDEX, Bsize, d, k, ratio):
    """
    Estimate the signed amplitude of the square-wave calibration signal in
    the waveform x_input. The waveform can have additional content besides
    the calibration signal.

    The input is divided into consecutive Bsize-sample blocks (a trailing
    partial block is ignored). Each block is smoothed with a centred rolling
    median and transformed with an FFT. The single-sided Discrete Fourier
    Series coefficient at bin XK1_INDEX (the harmonic of interest) is taken,
    and the signed values are averaged over all blocks.

    Sign: the phase of the bin-XK1_INDEX coefficient gives the time shift of
    the waveform. That shift is removed from the coefficient at bin
    2*XK1_INDEX. If the corrected second-harmonic phase is close to pi
    (|phase| > pi/2), the term is negated. If it is close to zero, the term
    is kept. A further block-independent flip applies for even k (k > 1), or
    for k == 1 with ratio > 1 and round(ratio) even.

    The averaged coefficient is converted to the square-wave amplitude A
    using a_k = (2*A/(k*pi)) * sin(k*pi*d).

    Parameters:
       x_input     1-D sampled waveform
       PERIOD      calibration load period in samples
       XK1_INDEX   DFT bin of the harmonic of interest (must be > 0)
       Bsize       block (DFT) length in samples
       d           duty cycle of the load (ON time / period)
       k           harmonic number analysed
       ratio       ON time / OFF time

    Returns the estimated amplitude A.
    Raises ValueError if x_input holds fewer than Bsize samples.
    """
    XK2_INDEX = 2 * XK1_INDEX
    n_samples = np.shape(x_input)[0]
    Nb = int(n_samples // Bsize)  # number of complete blocks
    if Nb == 0:
        raise ValueError("detect_calibrator needs at least %d samples "
                         "(one DFT block), got %d" % (Bsize, n_samples))

    # Rolling-median window: the ON samples per period divided by k. When
    # k == 1 and ON > OFF, it is divided by ratio instead, which gives the
    # OFF samples. The window is forced to an odd integer so the centred
    # window is symmetric; pandas requires an integer.
    window = d * PERIOD / k
    if k == 1 and ratio > 1:
        window = window / ratio
    window = int(round(window))
    if window % 2 == 0:
        window += 1

    # Sign correction that is the same for every block
    if k > 1:
        flip = k % 2 == 0
    else:
        flip = ratio > 1 and round(ratio) % 2 == 0
    parity_sign = -1 if flip else 1

    xk_sum = 0
    for n in range(Nb):
        x = rollingmedian(x_input[n * Bsize:(n + 1) * Bsize], window)
        Y = fft(x)
        P1 = (2 / Bsize) * (Y[0:(Bsize // 2)])  # single-sided spectrum

        XK1 = P1[XK1_INDEX]
        XK2 = P1[XK2_INDEX]

        xk1 = np.abs(XK1)

        # correct the time shift (R, in samples) implied by the phase of XK1
        R = Bsize * np.angle(XK1) / (2 * np.pi * XK1_INDEX)
        xk2_ph = np.angle(XK2 * np.exp(-1j * 2 * np.pi * XK2_INDEX * R / Bsize))

        if np.abs(xk2_ph) > np.pi / 2:
            xk1 = -xk1
        if parity_sign < 0:
            xk1 = -xk1

        xk_sum = xk_sum + xk1

    xk_avg = xk_sum / Nb

    ####################################
    # a_k = (2*A/(k*pi))*sin(k*pi*Tp/T)#
    # where Tp is the length of pulse  #
    # and T is the period.             #
    # The duty cycle, d = Tp/T         #
    ####################################
    # NOTE(review): sin(k*pi*d) is ~0 when k*d is an integer (e.g. k=2,
    # d=0.5), which makes A_scale blow up -- confirm such configs are unused.
    A_scale = k * (np.pi / 2) * (1 / np.sin(k * np.pi * d))
    A = xk_avg * A_scale

    return A
