from numpy import diff, std, zeros
from numpy.linalg import matrix_power, pinv
from scipy.signal import bessel, lfilter, tf2ss
import nilmtools.filter

"""
Transient detection parameters:

* Transient start detection
    k1 = certainty necessary to declare that a transient has started.
         larger values decrease false positives.
         smaller values decrease detection delay.
    k2 = minimum detectable step size, relative to the standard deviation.
         larger values improve noise immunity.
         smaller values improve sensitivity.
* Transient end detection
    k3 = buffer size, in samples. This is also the minimum
           "quiet length" that must separate two transients, and the
           number of samples used to initialize the mean filter.
         larger values spuriously combine transients.
         smaller values spuriously split transients.
    k4 = transient end threshold. The condition takes the form
           std(window) <= k4 * std(diff(window)),
           i.e. std(window) / std(diff(window)) <= k4.
         For additive white noise, 1/sqrt(2) is the optimal value.
* Mean filter (Bessel low-pass)
    k5 = order of filter (integer)
    k6 = cutoff frequency divided by Nyquist frequency
"""

def main(argv=None):
    """Command-line entry point for the transient finder.

    Builds a nilmtools filter, registers the k1..k6 tuning options in their
    own argument group, then runs a TransientFinder over the input stream
    through ``process_numpy``.

    argv: argument list handed to ``Filter.parse_args``; None presumably
          means "use the process command line" -- confirm in nilmtools.
    """
    nilm_filter = nilmtools.filter.Filter()
    parser = nilm_filter.setup_parser("Transient Finder")
    tuning = parser.add_argument_group("Transient parameters")

    # (flag, type, default, help) for every tunable, in registration order.
    option_specs = (
        ("--k1", float, 33.0, "transient start certainty threshhold"),
        ("--k2", float, 1.0, "minimum detectable transient step size"),
        ("--k3", int, 30, "buffer size"),
        ("--k4", float, 0.75, "transient end threshhold"),
        ("--k5", int, 3, "order of mean filter"),
        ("--k6", float, 0.017, "mean filter cutoff frequency"),
    )
    for flag, kind, fallback, text in option_specs:
        tuning.add_argument(flag, type=kind, default=fallback, help=text)

    parsed = nilm_filter.parse_args(argv)
    finder = TransientFinder(parsed.k1, parsed.k2, parsed.k3,
                             parsed.k4, parsed.k5, parsed.k6)
    nilm_filter.process_numpy(finder.process)


# Bessel low-pass filter whose initial state is fitted to data (Chornoboy 1990)
class LowPass:
    """Bessel low-pass filter with a least-squares initial state.

    Coefficients come from ``scipy.signal.bessel(filt_order, filt_cutoff)``.
    Instead of starting from a zero state, ``reset`` chooses the state as a
    fixed linear map ``M`` of the first ``init_length`` inputs. That map
    minimizes, in the least-squares sense, the difference between the filter
    output and the input over the initialization block.
    """

    def __init__(self, init_length=30, filt_order=3, filt_cutoff=1./60):
        self.b, self.a = bessel(filt_order, filt_cutoff)
        assert self.a[0] == 1.0

        # tf2ss describes q[i+1] = A q[i] + B x[i], y[i] = C q[i] + D x[i].
        # Transposing into observer canonical form makes the state vector
        # line up with lfilter's `zi`.
        ctrl_A, ctrl_B, ctrl_C, D = tf2ss(self.b, self.a)
        A, B, C = ctrl_A.T, ctrl_C.T, ctrl_B.T

        # A**0 .. A**(init_length-1), shared by F and the impulse response.
        powers = [matrix_power(A, k) for k in range(init_length)]

        # With X = x[0..L-1] and Y = y[0..L-1] we have Y - X = F q[0] + G X.
        # Row k of F maps the initial state to the state-driven part of y[k].
        F = zeros((init_length, filt_order))
        for k, A_k in enumerate(powers):
            F[k] = C.dot(A_k)

        # Impulse response of (y - x): h[0] = D - 1, h[k] = C A**(k-1) B.
        h = [(D - 1).item()]
        h.extend(C.dot(powers[k - 1]).dot(B).item()
                 for k in range(1, init_length))

        # G is lower-triangular Toeplitz: G[r, c] = h[r - c] for c <= r.
        G = zeros((init_length, init_length))
        for r in range(init_length):
            for c in range(r + 1):
                G[r, c] = h[r - c]

        # Picking q[0] = M X, the residual (F M + G) X is minimized by the
        # projection M = -pinv(F) G.
        self.M = -pinv(F).dot(G)

    def reset(self, init_data):
        """Fit the state to ``init_data``, then run the filter over it.

        The outputs for the initialization block are discarded; afterwards
        the stored state reflects having consumed all of ``init_data``.
        """
        self.z = self.M.dot(init_data)
        self.update(init_data)

    def update(self, in_data):
        """Filter a sequence of samples, carrying state across calls.

        Returns the filtered samples as an array of the same length.
        """
        filtered, next_state = lfilter(self.b, self.a, in_data, zi=self.z)
        self.z = next_state
        return filtered

# CUSUM step detector (Granjon 2012)
class StepWatcher:
    """One-sided CUSUM detector for an upward step.

    Every sample contributes ``sample - offset`` to two running statistics:
    the plain cumulative sum ``s`` and the clipped sum
    ``g = max(g + sample - offset, 0)``. A step is declared while ``g``
    exceeds ``threshold``. The reported start is the index at which ``s``
    reached its minimum. Feed negated samples to watch for downward steps.
    """

    def __init__(self, threshold, offset):
        self.threshold, self.offset = threshold, offset
        # Start in a usable state; callers may still reset() explicitly.
        self.reset()

    def reset(self):
        """Forget all accumulated evidence."""
        self.s_argmin, self.s_min = None, 0
        self.s = self.g = 0

    def update(self, index, sample):
        """Feed one sample; return the estimated step start, or None.

        index:  label of this sample (e.g. a timestamp); returned verbatim
                as the step start when a step is detected.
        sample: the (normalized) observation.
        """
        if self.s_argmin is None:
            # The cumulative sum begins at its minimum (0) before any sample.
            # If it never drops below 0 the step started at the very first
            # sample, so use that as the estimate. Without this fallback the
            # detector would return None even while g > threshold. Such a
            # step is then never reported, because g is back to 0 by the time
            # s goes negative.
            self.s_argmin = index
        self.s += sample - self.offset
        if self.s < self.s_min:
            self.s_argmin, self.s_min = index, self.s
        self.g = max(self.g + sample - self.offset, 0)
        return self.s_argmin if self.g > self.threshold else None

# Now we're ready to actually find some transients
class TransientFinder:
    """Streaming transient detector built from a mean filter and two CUSUMs.

    Incoming (timestamp, value) samples are held in a FIFO buffer of at most
    ``k3`` entries before being emitted as (timestamp, in_transient) pairs.
    That delay lets a transient start reported by a StepWatcher, whose
    estimate may point back in time, still be applied to samples that have
    not been emitted yet.

    k1..k6 are the tuning parameters described at the top of this module.
    """
    def __init__(self, k1, k2, k3, k4, k5, k6):
        # k1 is the CUSUM threshold and k2 its per-sample offset. One watcher
        # looks for upward steps; the other is fed negated residuals.
        self.up_watcher = StepWatcher(k1, k2)
        self.down_watcher = StepWatcher(k1, k2)
        # The mean filter is initialized from a window of k3 samples.
        self.mean_filter = LowPass(k3, k5, k6)
        self.k3, self.k4 = k3, k4
        self.reset()

    def reset(self):
        """Drop buffered samples and wait for a quiet window before detecting."""
        self.buffer = []
        self.in_transient = False
        self.starting = True
        # transient_over will handle the remainder of member initialization.

    def transient_start(self, timestamp):
        """Enter the transient state, beginning at ``timestamp``.

        Buffered samples older than ``timestamp`` are emitted as
        non-transient. The sample then at the head of the buffer is emitted
        as the first transient sample.
        """
        while self.buffer[0][0] < timestamp:
            self.pop()
        self.in_transient = True
        # Emit one sample of the transient, so that it lasts at least 1 sample
        if self.pop()[0] > timestamp:
            # The requested start is older than anything left in the buffer.
            print("Warning: transient start not buffered. Input is pathological.")

    def transient_over(self):
        """Leave the transient (or start-up) state using the quiet window.

        Resets both watchers, fits the mean filter to ``self.window`` and
        records the window's standard deviation (floored at 0.001) as the
        scale used to normalize filter residuals.
        """
        self.starting = False
        self.in_transient = False
        self.up_watcher.reset()
        self.down_watcher.reset()
        self.mean_filter.reset(self.window)
        self.std = max(std(self.window), 0.001)  # TODO: make minimum non-arbitrary

    def pop(self):
        """Remove the oldest buffered sample and record it in this step's output."""
        timestamp, value = self.buffer.pop(0)
        self.result.append((timestamp, self.in_transient))
        return (timestamp, value)

    def update(self, timestamp, value):
        """Add one sample; return the samples emitted by this call.

        The return value is a list of (timestamp, in_transient) pairs,
        oldest first, holding between 0 and k3-1 entries.
        """
        self.result = []
        # Pop the oldest sample if the buffer is full.
        if len(self.buffer) >= self.k3:
            self.pop()
        # Add new sample to buffer.
        self.buffer.append((timestamp, value))
        # If we're in a transient (or still starting up), check whether the
        # buffer has become quiet enough to end it.
        if self.in_transient or self.starting:
            # Only check for transient end if buffer is full.
            if len(self.buffer) == self.k3:
                self.window = [b[1] for b in self.buffer]
                # `<=` (not `<`) so a perfectly constant window also ends it.
                if std(self.window) <= std(diff(self.window)) * self.k4:
                    self.transient_over()
        # If we're not in a transient, check whether one has started.
        else:
            err = value - self.mean_filter.update([value])[0]
            u_step = self.up_watcher.update(timestamp, err / self.std)
            d_step = self.down_watcher.update(timestamp, -err / self.std)
            if u_step is not None:
                self.transient_start(u_step)
            elif d_step is not None:
                self.transient_start(d_step)
        return self.result

    def flush(self):
        """Emit every buffered sample, labelled with the current state.

        Returns a list of (timestamp, in_transient) pairs, oldest first, and
        leaves the buffer empty.
        """
        self.result = []
        while self.buffer:
            self.pop()
        return self.result

    def process(self, data, interval, args, insert_func, final):
        """Wrapper around update() for the nilmtools filter interface.

        data: 2-D array whose first two columns are used as (timestamp,
              power); any further columns are ignored.
        insert_func: called with each non-empty list of emitted
              (timestamp, in_transient) pairs.
        final: true when this is the last chunk of an interval.
        Returns the number of rows consumed, which is always all of them.
        """
        for timestamp, power in data[:, :2]:  # Only want T and P1
            output = self.update(timestamp, power)
            if output:
                insert_func(output)
        if final:
            # Up to k3 samples are still buffered at the end of an interval.
            # Emit them before resetting; otherwise they would never be output.
            remaining = self.flush()
            if remaining:
                insert_func(remaining)
            self.reset()
        return len(data)  # we always process all data in an interval

# Helper function to run transient detector over a 1d list
def transients(arr, *args, **kwargs):
    """Label every element of a 1-D sequence as inside/outside a transient.

    Extra positional and keyword arguments are forwarded to TransientFinder
    (k1..k6). Each element's list index is used as its timestamp.

    Returns a list of booleans, one per element of ``arr``, in order.
    """
    t = TransientFinder(*args, **kwargs)
    result = []
    for i, d in enumerate(arr):
        result.extend(in_transient for _, in_transient in t.update(i, d))
    # Samples still held in the finder's buffer have not been emitted yet.
    # Label them with its current state so the output lines up with `arr`.
    result.extend(t.in_transient for _ in t.buffer)
    return result


if __name__ == "__main__":
    # Executed as a script (not imported): run the command-line filter.
    main()
