#!/usr/bin/python

import nilmdb.client
from nilmdb.client import Client
from nilmdb.client.numpyclient import NumpyClient
from nilmdb.utils.printf import *
from nilmdb.utils.time import (parse_time, timestamp_to_human,
                               timestamp_to_seconds)
from nilmdb.utils.interval import Interval

import nilmtools

import itertools
import argparse
import numpy as np
import json

#import ipdb
#use: ipdb.set_trace()

# Version of this module, stored as a float literal (not a version string).
version = 0.1


class StreamInfo(object):
    """Description of one stream plus the scaling configuration stored with it.

    Attributes set by ``__init__`` (when ``info`` and ``config`` are complete):
    ``url``, ``info``, ``path``, ``layout``, ``layout_type``, ``layout_count``,
    ``total_count`` (``layout_count`` plus one for the timestamp column),
    ``timestamp_min``, ``timestamp_max``, ``rows``, ``seconds``,
    ``scaleFactors`` and ``offsets`` (numpy arrays of length ``total_count``).
    ``transforms`` is set later by ``process_args``.
    """

    def __init__(self, url, info, config):
        """Unpack a stream description and its JSON configuration.

        url    -- server URL, stored as given
        info   -- indexable record: [path, layout, timestamp_min,
                  timestamp_max, rows, duration]; layout is split on '_'
                  into a type and a column count
        config -- JSON text whose 'streams' list holds one entry per data
                  column with 'scale_factor' and 'offset' keys

        Raises Exception when ``config`` is empty.  An IndexError or
        TypeError while unpacking is swallowed on purpose (best effort),
        which leaves the object only partially initialised.
        """
        self.url = url
        self.info = info
        try:
            self.path = info[0]
            self.layout = info[1]
            layout_parts = self.layout.split('_')
            self.layout_type = layout_parts[0]
            self.layout_count = int(layout_parts[1])
            self.total_count = self.layout_count + 1
            self.timestamp_min = info[2]
            self.timestamp_max = info[3]
            self.rows = info[4]
            self.seconds = nilmdb.utils.time.timestamp_to_seconds(info[5])
            if not config:
                raise Exception("missing NilmManager config metadata for stream %s" % self.path)

            config = json.loads(config)
            # Slot 0 belongs to the timestamp column, which must pass through
            # unchanged: scale factor 1 and offset 0.
            self.scaleFactors = [1.0]
            self.offsets = [0.0]
            for i in range(self.layout_count):
                column = config['streams'][i]
                self.scaleFactors.append(float(column['scale_factor']))
                self.offsets.append(float(column['offset']))
            self.scaleFactors = np.array(self.scaleFactors)
            self.offsets = np.array(self.offsets)
        except (IndexError, TypeError):
            # The exception types must be a tuple: the old spelling
            # "except IndexError, TypeError" caught only IndexError and
            # rebound the name TypeError to the caught exception.
            pass

    def process_args(self, args, errorFunc):
        """Parse a column mapping string into ``self.transforms``.

        args      -- comma separated "x:y" pairs; x is a zero-based data
                     column of this stream or the word "all", y is the
                     zero-based position in the filter input
        errorFunc -- called with a message when the mapping is malformed
                     (missing ':' part, non-numeric column) or names a data
                     column this stream does not have

        Each transform is a dict with 'input_col' ("all" or x+1) and
        'output_col' (y+1); the +1 skips the timestamp column.
        """
        # layout_count is absent when __init__ could not unpack the stream
        # description; skip the range check in that case.
        column_count = getattr(self, 'layout_count', None)
        try:
            self.transforms = []
            for spec in args.split(','):
                fields = spec.split(':')
                if fields[0] == "all":
                    input_col = "all"
                else:
                    index = int(fields[0])
                    if (column_count is not None
                            and not 0 <= index < column_count):
                        raise IndexError(index)
                    input_col = index + 1
                self.transforms.append({
                    'input_col': input_col,
                    'output_col': int(fields[1]) + 1,
                })
        except (IndexError, ValueError):
            errorFunc("requested columns are not available on stream %s" % self.path)

    def transform_data(self, data):
        """Scale ``data`` and return one array per entry in ``self.transforms``.

        data -- rows of [timestamp, col1, col2, ...]; anything numpy can
                broadcast against ``self.offsets`` (a 2-D array or a list
                of equal-length rows)

        Every value has its column offset subtracted and is then multiplied
        by the column scale factor.  A transform with input column "all"
        contributes the whole scaled array; any other transform contributes
        a two-column array [timestamp, selected column].
        """
        scaled = (data - self.offsets) * self.scaleFactors
        transformed = []
        for t in self.transforms:
            if t['input_col'] == "all":
                transformed.append(scaled)
            else:
                transformed.append(scaled[:, [0, t['input_col']]])
        return transformed

    def string(self, interhost):
        """Return stream info as a string.  If interhost is true,
        include the host URL."""
        if interhost:
            return "[%s] " % (self.url,) + str(self)
        return str(self)

    def __str__(self):
        """Return stream info as a string."""
        return "%s (%s), %.2fM rows, %.2f hours" % (
            self.path, self.layout, self.rows / 1e6, self.seconds / 3600.0)

def get_stream_info(client, path):
    """Look up ``path`` on ``client`` and wrap the result in a StreamInfo.

    Returns None unless the extended stream listing for ``path`` contains
    exactly one entry.  The StreamInfo receives the client URL, that entry
    and the stream's "config_key__" metadata value.
    """
    matches = client.stream_list(path, extended = True)
    if len(matches) == 1:
        metadata = client.stream_get_metadata(path, "config_key__")
        stream_config = metadata['config_key__']
        return StreamInfo(client.geturl(), matches[0], stream_config)
    return None

class Boundary(object):
    """One edge of an interval: a timestamp and a start/end marker."""

    def __init__(self, time, isEnd):
        """Store the boundary.

        time  -- timestamp of the boundary
        isEnd -- true when this boundary closes an interval

        The original signature lacked ``self``, so the class could not be
        instantiated.
        """
        self.time = time
        self.isEnd = isEnd

class Filter(object):
    """Command-line driven base for filters over one or more source streams.

    Typical use: ``setup_parser()``, ``parse_args()``, then ``process()``
    with the function that does the actual filtering.  ``parse_args`` sets
    ``sources`` (StreamInfo objects with their column transforms),
    ``destination`` (unless --no_output), ``start``, ``end``, ``inputWidth``
    and ``no_output``.
    """

    def setup_parser(self, description = "Filter data"):
        """Build, remember and return the argparse parser shared by filters."""
        parser = argparse.ArgumentParser(
            formatter_class = argparse.RawDescriptionHelpFormatter,
            description = description)

        group = parser.add_argument_group("General filter arguments")

        group.add_argument('-i', '--input',  nargs='+', action="store",
                           metavar = 'stream', help='source stream(s) formatted: str_name[x:y,x:y,...] where x is the actual column number in the stream and y is the desired column number of the formatted output')
        group.add_argument('-o', '--output', action="store",
                           metavar = 'stream', help='destination stream')
        group.add_argument("-s", "--start",
                           metavar="TIME", type=self.arg_time,
                           help="Starting timestamp "
                           "(free-form, inclusive)")
        group.add_argument("-e", "--end",
                           metavar="TIME", type=self.arg_time,
                           help="Ending timestamp "
                           "(free-form, noninclusive)")
        group.add_argument("-u", "--url", action="store",
                           default="http://localhost/nilmdb/",
                           help = "NilmDB URL (default: %(default)s)")
        group.add_argument("--no_output", help="no output stream", action="store_true")
        self._parser = parser
        return parser

    def parse_args(self, argv = None):
        """Parse and validate the command line, then look up the streams.

        argv -- argument list, or None to let argparse read sys.argv

        Every validation failure is reported through the parser's
        ``error`` method (which exits).  Returns the parsed namespace.
        """
        parser = self._parser
        args = parser.parse_args(argv)
        self._inputStreams = args.input
        self._outputStream = args.output
        if not self._inputStreams:
            parser.error("missing input streams")

        # An output stream is required unless --no_output was given, and
        # the two options are mutually exclusive.
        self.no_output = args.no_output
        if self.no_output and args.output:
            parser.error("no_output=True but an output stream was given")
        if not self.no_output and not self._outputStream:
            parser.error("missing output stream")

        # Each input is /input/stream/name[x:y,x:y,...]
        # x: a number or "all"; a number selects that column of the stream,
        #    "all" uses every column.
        # The output path is compared against the bare stream name; the
        # raw argument still carries the "[...]" mapping and would never
        # match.
        specs = []
        for inputStream in self._inputStreams:
            if '[' not in inputStream or not inputStream.endswith(']'):
                parser.error("input stream [" + inputStream +
                             "] must be formatted as name[x:y,x:y,...]")
            streamName = inputStream[0:inputStream.index('[')]
            if self._outputStream == streamName:
                parser.error("the output stream cannot be "
                             "one of the input streams")
            specs.append((streamName, inputStream.split('[')[1][:-1]))

        self._client = NumpyClient(args.url)
        self.sources = []
        for streamName, mapping in specs:
            source = get_stream_info(self._client, streamName)
            if not source:
                parser.error("source path [" + streamName + "] not found")
            source.process_args(mapping, parser.error)
            self.sources.append(source)

        # Number of filter inputs actually requested.
        self.inputWidth = sum(len(source.transforms)
                              for source in self.sources)

        # Columns 0 .. inputWidth-1 must each be claimed exactly once.
        claims = {}
        for source in self.sources:
            for transform in source.transforms:
                column = transform['output_col'] - 1
                claims[column] = claims.get(column, 0) + 1
        for column in range(self.inputWidth):
            uses = claims.get(column, 0)
            if uses > 1:
                parser.error("multiple mappings to column %d, check input stream mapping" % column)
            if uses == 0:
                parser.error("column %d has no input, check input stream mapping" % column)

        # Set up the output stream.
        if not self.no_output:
            self.destination = get_stream_info(self._client,
                                               self._outputStream)
            if not self.destination:
                parser.error("destination path [" +
                             self._outputStream + "] not found")

        # Both time bounds are mandatory.
        if args.start is None:
            parser.error("missing start time")
        if args.end is None:
            parser.error("missing end time")

        self.start = args.start
        self.end = args.end
        return args

    def initState(self, state):
        """Hook for filters that keep state; this base class keeps none."""
        pass

    # Misc helpers
    def arg_time(self, toparse):
        """Parse a time string argument"""
        try:
            return nilmdb.utils.time.parse_time(toparse)
        except ValueError as e:
            raise argparse.ArgumentTypeError("%s \"%s\"" % (str(e), toparse))

    def resetInterval(self,start,end):
        """Override the start and end bounds taken from the command line."""
        self.start = start
        self.end = end

    def sourceExtents(self):
        """Return an Interval from the latest ``timestamp_min`` to the
        earliest ``timestamp_max`` of the source streams.

        Returns None when there are no sources or when any source has no
        extent (a None timestamp): no range is then common to all sources.
        The earlier code compared None with numbers, which gave results
        that depended on the order of the sources.
        """
        starts = [source.timestamp_min for source in self.sources]
        ends = [source.timestamp_max for source in self.sources]
        if not starts or None in starts or None in ends:
            return None  # no data!
        return Interval(max(starts), min(ends))

    def calcIntervalEnd(self,source,start,maxrows,maxend):
        """Find an end timestamp such that [start, end) of ``source`` holds
        at most ``maxrows`` rows.

        Returns ``maxend`` unchanged when the row count already fits.
        Otherwise the end is pulled in repeatedly, assuming uniform
        sampling and aiming at 90% of ``maxrows``; the result is then a
        float.
        """
        while True:
            count = self._client.stream_count(source.path, start, maxend)
            if count <= maxrows:
                return maxend
            # Assume uniform sampling in order to estimate a new end.
            duration = maxend - start
            rate = float(count) / duration
            maxend = (0.9 * maxrows) / rate + start

    def intervals(self):
        """Generate the intervals during which every source has data.

        A sweep line over the interval boundaries of all sources between
        ``self.start`` and ``self.end``:
        1.) add all the interval boundaries to a list
        2.) sort the list by time
        3.) sweep through the list, adding a per-source power of two on a
            start boundary and subtracting it on an end boundary; all
            sources overlap while the counter equals the sum of all the
            powers (max_flag)
        Adjacent results are joined by ``_optimize_int``.
        """
        # Step 1
        boundaries = []
        max_flag = 0
        for flag, source in enumerate(self.sources):
            weight = 2 ** flag
            for (start, end) in self._client.stream_intervals(
                    source.path, start = self.start, end = self.end):
                boundaries.append((start, weight))
                boundaries.append((end, -weight))
            max_flag += weight

        # Step 2 (stable, so boundaries with equal times keep their order)
        boundaries.sort(key=lambda boundary: boundary[0])

        # Step 3
        _intervals = []
        start = None
        count = 0
        for (time, delta) in boundaries:
            count += delta
            if count == max_flag:
                start = time
            elif start is not None:
                # "is not None" so that a start timestamp of 0 still opens
                # an interval; zero-width overlaps hold no data.
                if time > start:
                    _intervals.append(Interval(start, time))
                start = None
        for interval in self._optimize_int(_intervals):
            yield interval

    def _optimize_int(self, it):
        """Join and yield adjacent intervals from the iterator 'it'

        The later interval of an adjacent pair is extended in place (its
        ``start`` is overwritten) and the earlier one is dropped.
        """
        pending = None
        for interval in it:
            if pending is not None:
                if pending.end == interval.start:
                    interval.start = pending.start
                else:
                    yield pending
            pending = interval
        if pending is not None:
            yield pending

    def _reorder(self,data):
        """Reorder ``data`` to the output column order requested by the user.

        ``data`` holds one entry per transform, grouped by source in
        ``self.sources`` order; entry k moves to index ``output_col - 1``
        of its transform.  A real list is allocated because ``range()`` is
        not assignable on Python 3.
        """
        ordered_data = [None] * len(data)
        position = 0
        for source in self.sources:
            for t in source.transforms:
                ordered_data[t['output_col'] - 1] = data[position]
                position += 1
        return ordered_data

    def dummy_inserter(self, path, data, start = None, end = None,
                            layout = None):
        """Generator yielding a single DummyInserter; arguments are ignored."""
        yield DummyInserter()

    def _extract_chunk(self, extractor, start, end):
        """Extract and transform [start, end) from every source.

        Returns the transformed arrays of all sources in source order, or
        None as soon as one source has no rows in the range.
        """
        gathered = []
        for source in self.sources:
            rows = []
            for chunk in extractor(source.path, start, end):
                rows.extend(chunk)
            if not rows:
                # An earlier version widened the range until data showed
                # up (kept for resampling, which needs future timestamps).
                # That could run into the *next* interval without updating
                # the chunk end and caused bad-insertion errors on sparse
                # streams, so the chunk is skipped instead.  Revisit this
                # if resampling misbehaves.
                return None
            gathered += source.transform_data(rows)
        return gathered

    def process(self, function, state, args = None, rows = 100000):
        """Call 'function' with the source data between 'start' and 'end',
        in chunks of at most 'rows' rows per source.

        'function' should be defined as:
           def function(data, interval, args, insert_func, state)

        'data': list of arrays in the requested column order

        'interval': Interval covered by this particular chunk

        'args': the 'args' namespace given here, converted to a
        dictionary (empty if None); the key "isEnd" is set to True for
        the last chunk of an interval and False otherwise

        'insert_func': function to call in order to insert array of data.
        Should be passed a 2-dimensional array of data to insert.
        Data timestamps must be within the provided interval.

        'state': an object representing any persistent state of the
        function. For stateless filters this can be ignored

        A chunk in which any source has no rows is skipped and 'function'
        is not called for it.  The return value of 'function' is ignored.

        Returns the end of the last interval processed, or -1 if there
        was none.
        """
        args = {} if args is None else vars(args)

        # Without an output stream, insertions go to a dummy context.
        if self.no_output:
            inserter = DummyClient
            self.destination_path = ""
        else:
            inserter = self._client.stream_insert_numpy_context
            self.destination_path = self.destination.path

        extractor = self._client.stream_extract_numpy
        last_timestamp_processed = -1
        for interval in self.intervals():
            last_timestamp_processed = interval.end
            curStart = interval.start
            with inserter(self.destination_path,
                          interval.start, interval.end) as insert_ctx:
                insert_function = insert_ctx.insert
                print(interval)
                while curStart < interval.end:
                    # The chunk ends where the first source reaches its
                    # row limit.
                    curEnd = min(
                        self.calcIntervalEnd(source, curStart, rows,
                                             interval.end)
                        for source in self.sources)

                    unordered_data = self._extract_chunk(
                        extractor, curStart, curEnd)
                    if unordered_data is None:
                        curStart = curEnd
                        print("no more data, skipping")
                        continue
                    ordered_data = self._reorder(unordered_data)
                    # Print the chunk start so progress can be tracked.
                    print("###%d" % curStart)

                    args["isEnd"] = (curEnd == interval.end)
                    function(ordered_data, Interval(curStart, curEnd),
                             args, insert_function, state)

                    curStart = curEnd
        return last_timestamp_processed

class DummyClient(object):
    """Stand-in insert context used when a filter has no output stream.

    It is called with the same positional arguments as the real insert
    context (path, start, end) and used in a ``with`` statement.  It is a
    new-style class like the other classes in this file.
    """

    def __init__(self,path,start,end):
        """Accept and ignore the destination path and time bounds."""
        pass

    def insert(self, array):
        """Report that nothing should be inserted; ``array`` is discarded."""
        print("Error: dummy insert function should not be called")

    def __enter__(self):
        """Return a DummyInserter as the context object."""
        return DummyInserter()

    def __exit__(self,path,start,end):
        """Leave the context without suppressing exceptions.

        The ``with`` statement passes the exception type, value and
        traceback positionally here; the parameter names are kept as they
        were for compatibility.
        """
        pass

class DummyInserter(object):
    """Insert target that only reports an error when it is used.

    It is a new-style class like the other classes in this file.
    """

    def insert(self, array):
        """Report that nothing should be inserted; ``array`` is discarded."""
        print("Error: dummy insert function should not be called")

# This module only provides helpers for filters; running it directly is
# reported as an error.
if "__main__" == __name__:
    print("Do not call this script directly, must be used by a filter")
    exit(1)


