"""
Live decimation tool
-decimates prep data
"""

import nilmdb.client.numpyclient
import numpy as np
import sys
import warnings
# Silence every FutureWarning for the whole process.
# NOTE(review): presumably added to hide numpy's FutureWarning about
# comparing arrays to None with `!=`; confirm before removing.
warnings.simplefilter("ignore", FutureWarning)

class Decimator:
    """Cascading mean/min/max decimator that writes into NilmDB streams.

    Each instance writes to the stream ``"<base_name>~decim-<level>"`` and
    lazily creates a child ``Decimator`` at level ``level * factor`` that
    decimates this instance's output again, forming a cascade.

    Rows that do not fill a complete group of ``factor`` rows are held back
    in ``old_data`` and prepended to the next chunk passed to ``process``.
    """

    def __init__(self, level, base_name, factor=4, width=9, again=False,
                 url="http://localhost/nilmdb"):
        """Connect to NilmDB and make sure the output stream exists.

        Args:
            level: decimation level; used in the output stream name and
                multiplied by ``factor`` to get the child's level.
            base_name: source stream path; output goes to
                ``"<base_name>~decim-<level>"``.
            factor: number of input rows reduced into each output row.
            width: number of data columns in the undecimated source; the
                output stream is created with layout ``float32_<3 * width>``.
            again: False when ``process`` receives source rows
                (presumably ``ts + width`` columns); True when it receives
                already-decimated rows (``ts`` + means + mins + maxes).
            url: NilmDB server URL (defaults to the previously hard-coded
                local server).
        """
        self.client = nilmdb.client.numpyclient.NumpyClient(url)
        self.url = url
        self.again = again
        self.level = level
        self.factor = factor
        self.width = width
        self.base_name = base_name
        self.path = "%s~decim-%d" % (base_name, level)
        # Create the output stream only if it does not exist yet.
        existing = self.client.stream_list(path=self.path)
        if len(existing) == 0:
            self.client.stream_create(self.path, "float32_%d" % (width * 3))
        self.start_ts = 0
        self.last_ts = 0  # 0 means "no interval open yet"
        self.child = None  # next level of decimation, created lazily
        self.old_data = None  # leftover rows waiting for a full group

    def close_interval(self):
        """Close the current interval, e.g. when source data is missing.

        Drops buffered leftover rows and resets ``last_ts`` so the next
        insert starts a fresh interval, then does the same for the child.
        """
        self.last_ts = 0
        self.old_data = None
        if self.child is not None:
            self.child.close_interval()

    @staticmethod
    def _decimate(data, factor, again):
        """Reduce complete groups of ``factor`` rows to mean/min/max rows.

        Returns ``(out, leftover)``. When ``data`` has fewer than ``factor``
        rows, ``out`` is None and ``leftover`` is ``data`` itself. Otherwise
        ``out`` has one row per group laid out as ``means | mins | maxes``
        (the timestamp column is included in the means) and ``leftover`` is
        a copy of the trailing rows that did not fill a group.
        """
        n, m = data.shape

        if again:
            # Already-decimated input, e.g. with c = 3:
            # ts mean1 mean2 mean3 min1 min2 min3 max1 max2 max3
            c = (m - 1) // 3
            mean_col = slice(0, c + 1)
            min_col = slice(c + 1, 2 * c + 1)
            max_col = slice(2 * c + 1, 3 * c + 1)
        else:
            # First decimation: mean of every column (timestamp included),
            # min/max of the data columns only.
            mean_col = slice(0, m)
            min_col = slice(1, m)
            max_col = slice(1, m)

        # Only whole groups of `factor` rows are decimated now.
        n = n // factor * factor
        if n == 0:
            return None, data

        leftover = np.copy(data[n:, :])
        # 3-D view: (groups, rows per group, columns) so we reduce over axis 1.
        groups = data[:n, :].reshape(n // factor, factor, m)
        out = np.c_[np.mean(groups[:, :, mean_col], axis=1),
                    np.min(groups[:, :, min_col], axis=1),
                    np.max(groups[:, :, max_col], axis=1)]
        return out, leftover

    def process(self, data):
        """Decimate ``data``, insert the result, then feed it to the child.

        Args:
            data: 2-D array whose first column is the timestamp; see
                ``__init__`` (``again``) for the expected column layout.
        """
        # Prepend rows left over from the previous call. Must use
        # `is not None`: `ndarray != None` compares elementwise and an
        # array result cannot be used as an `if` condition.
        if self.old_data is not None:
            data = np.concatenate((self.old_data, data))

        out, self.old_data = self._decimate(data, self.factor, self.again)
        if out is None:
            # Not enough rows for a single group; they are kept for later.
            return

        # Continue the open interval, or start a new one at the first
        # decimated timestamp.
        if self.last_ts == 0:
            start_ts = int(out[0][0])
        else:
            start_ts = self.last_ts
        # Interval end is exclusive, hence +1 past the last timestamp.
        self.last_ts = int(out[-1][0]) + 1

        with self.client.stream_insert_numpy_context(
                self.path, start=start_ts, end=self.last_ts) as ctx:
            ctx.insert(out)

        # Feed the next level of the cascade.
        if self.child is None:
            self.child = Decimator(self.level * self.factor, self.base_name,
                                   width=self.width, factor=self.factor,
                                   again=True, url=self.url)
        self.child.process(out)


if __name__=="__main__":
    d = Decimator(4,"/data/adsf")
