# cython: language_level=2

"""Interval, IntervalSet

The Interval implemented here is just like
nilmdb.utils.interval.Interval, except implemented in Cython for
speed.

Represents an interval of time, and a set of such intervals.

Intervals are half-open, i.e. they include data points with timestamps
in the range [start, end)
"""

# First implementation kept a sorted list of intervals and used
# bisect() to optimize some operations, but this was too slow.

# Second version was based on the quicksect implementation from
# python-bx, modified slightly to handle floating point intervals.
# This didn't support deletion.

# Third version is more similar to the first version, using a rb-tree
# instead of a simple sorted list to maintain O(log n) operations.

# Fourth version is an optimized rb-tree that stores interval starts
# and ends directly in the tree, like bxinterval did.

from ..utils.time import min_timestamp as nilmdb_min_timestamp
from ..utils.time import max_timestamp as nilmdb_max_timestamp
from ..utils.time import timestamp_to_string
from ..utils.iterator import imerge
from ..utils.interval import IntervalError
import itertools

cimport rbtree
from libc.stdint cimport uint64_t, int64_t

ctypedef int64_t timestamp_t

cdef class Interval:
    """Represents an interval of time.

    The interval is half-open: it covers every timestamp t with
    start <= t < end.
    """

    cdef public timestamp_t start, end

    def __init__(self, timestamp_t start, timestamp_t end):
        """
        'start' and 'end' are arbitrary numbers that represent time.

        Raises IntervalError unless 'start' is strictly less than 'end'.
        """
        if start >= end:
            # Explicitly disallow zero-width intervals (since they're half-open)
            raise IntervalError("start %s must precede end %s" % (start, end))
        self.start = start
        self.end = end

    def __repr__(self):
        s = repr(self.start) + ", " + repr(self.end)
        return self.__class__.__name__ + "(" + s + ")"

    def __str__(self):
        return ("[" + timestamp_to_string(self.start) +
                " -> " + timestamp_to_string(self.end) + ")")

    # Compare two intervals.  If non-equal, order by start then end.
    #
    # An argument typed as an extension class accepts None unless it is
    # declared 'not None', and reading the C-level start/end fields of
    # None is not checked.  The ordering comparisons therefore reject
    # None with a TypeError (as they already do for any other
    # non-Interval operand), while == and != treat None as "not equal".
    def __lt__(self, Interval other not None):
        return (self.start, self.end) < (other.start, other.end)
    def __gt__(self, Interval other not None):
        return (self.start, self.end) > (other.start, other.end)
    def __le__(self, Interval other not None):
        return (self.start, self.end) <= (other.start, other.end)
    def __ge__(self, Interval other not None):
        return (self.start, self.end) >= (other.start, other.end)
    def __eq__(self, Interval other):
        if other is None:
            return False
        return (self.start, self.end) == (other.start, other.end)
    def __ne__(self, Interval other):
        if other is None:
            return True
        return (self.start, self.end) != (other.start, other.end)

    cpdef intersects(self, Interval other):
        """Return True if two Interval objects intersect.

        Intervals that merely touch (one's end equals the other's
        start) do not intersect, because the end is excluded.

        Raises TypeError if 'other' is None.
        """
        if other is None:
            # Guard explicitly: C attribute access on None is unchecked.
            raise TypeError("other must be an Interval, not None")
        return self.start < other.end and other.start < self.end

    cpdef subset(self, timestamp_t start, timestamp_t end):
        """Return a new Interval that is a subset of this one.

        Raises IntervalError if [start, end) is not contained in this
        interval, or (from the constructor) if start >= end.
        """
        # A subclass that tracks additional data might override this.
        if start < self.start or end > self.end:
            raise IntervalError("not a subset")
        return Interval(start, end)

cdef class DBInterval(Interval):
    """
    Like Interval, but also tracks corresponding start/end times and
    positions within the database.  These are not currently modified
    when subsets are taken, but can be used later to help zero in on
    database positions.

    The actual 'start' and 'end' will always fall within the database
    start and end, e.g.:
        db_start = 100, db_startpos = 10000
        start = 123
        end = 150
        db_end = 200, db_endpos = 20000
    """

    # Attributes are declared with 'cdef' ('cpdef' is meant for
    # functions, not variables); 'public' exposes them to Python.
    cdef public timestamp_t db_start, db_end
    cdef public uint64_t db_startpos, db_endpos

    def __init__(self, start, end,
                 db_start, db_end,
                 db_startpos, db_endpos):
        """
        'db_start' and 'db_end' are arbitrary numbers that represent
        time.  Together they must span the time interval covered by
        'start' and 'end' (db_start <= start and end <= db_end; equal
        bounds are accepted).  The 'db_startpos' and 'db_endpos' are
        arbitrary database position indicators that correspond to
        those points.

        Raises IntervalError if start >= end (from Interval.__init__)
        or if the database times do not span the interval times.
        """
        Interval.__init__(self, start, end)
        self.db_start = db_start
        self.db_end = db_end
        self.db_startpos = db_startpos
        self.db_endpos = db_endpos
        if db_start > start or db_end < end:
            raise IntervalError("database times must span the interval times")

    def __repr__(self):
        fields = (self.start, self.end,
                  self.db_start, self.db_end,
                  self.db_startpos, self.db_endpos)
        s = ", ".join([repr(f) for f in fields])
        return self.__class__.__name__ + "(" + s + ")"

    cpdef subset(self, timestamp_t start, timestamp_t end):
        """
        Return a new DBInterval that is a subset of this one.

        The database times and positions are copied unchanged into
        the new object.  Raises IntervalError if [start, end) is not
        contained in this interval.
        """
        if start < self.start or end > self.end:
            raise IntervalError("not a subset")
        return DBInterval(start, end,
                          self.db_start, self.db_end,
                          self.db_startpos, self.db_endpos)

cdef class IntervalSet:
    """
    A non-intersecting set of intervals.

    Each interval is stored as the 'obj' of an rbtree.RBNode inside
    'tree'.
    """

    cdef public rbtree.RBTree tree

    def __init__(self, source=None):
        """
        'source' is an Interval or IntervalSet to add.
        """
        self.tree = rbtree.RBTree()
        if source is not None:
            self += source

    def __iter__(self):
        """Yield the interval object of every tree node that has one."""
        for node in self.tree:
            if node.obj:
                yield node.obj

    def __len__(self):
        """Return the number of intervals, counted by iterating the set."""
        return sum(1 for x in self)

    def __repr__(self):
        descs = [ repr(x) for x in self ]
        return self.__class__.__name__ + "([" + ", ".join(descs) + "])"

    def __str__(self):
        descs = [ str(x) for x in self ]
        return  "[" + ", ".join(descs) + "]"

    def __match__(self, other):
        """Test equality of two IntervalSets.

        Treats adjacent Intervals as equivalent to one long interval,
        so this function really tests whether the IntervalSets cover
        the same spans of time."""
        # This isn't particularly efficient, but it shouldn't get used in the
        # general case.

        def is_adjacent(a, b):
            """Return True if two Intervals are adjacent (same end or start)"""
            return a.end == b.start or b.end == a.start

        this = list(self)
        that = list(other)
        n_this = len(this)
        n_that = len(that)

        i = 0
        j = 0
        # 'outside' is True while we are between spans in both sets,
        # and False while we are walking through a span they share.
        outside = True

        # Running off the end of either list while the other still has
        # uncovered time raises IndexError, which means "no match".
        try:
            while True:
                if outside:
                    # To match, we need to be finished both sets
                    if i >= n_this and j >= n_that:
                        return True
                    # Or the starts need to match
                    if this[i].start != that[j].start:
                        return False
                    outside = False
                elif this[i].end == that[j].end:
                    # We can move on if the two interval ends match
                    i += 1
                    j += 1
                    outside = True
                elif this[i].end < that[j].end:
                    # Whichever ends first needs to be adjacent to the next
                    if not is_adjacent(this[i], this[i+1]):
                        return False
                    i += 1
                else:
                    if not is_adjacent(that[j], that[j+1]):
                        return False
                    j += 1
        except IndexError:
            return False

    # Use __richcmp__ instead of __eq__, __ne__ for Cython.
    def __richcmp__(self, other, int op):
        """Rich comparison; only == and != are meaningful.

        Every other operator returns False (kept for backward
        compatibility with existing callers).
        """
        if op == 2: # ==
            return self.__match__(other)
        elif op == 3: # !=
            return not self.__match__(other)
        return False

    def __iadd__(self, object other not None):
        """Inplace add -- modifies self

        'other' is either a single Interval or an iterable of them.

        This throws an exception if the regions being added intersect."""
        if isinstance(other, Interval):
            if self.intersects(other):
                raise IntervalError("Tried to add overlapping interval "
                                    "to this set")
            self.tree.insert(rbtree.RBNode(other.start, other.end, other))
        else:
            for x in other:
                self.__iadd__(x)
        return self

    def iadd_nocheck(self, Interval other not None):
        """Inplace add -- modifies self.
        'Optimized' version that doesn't check for intersection and
        only inserts the new interval into the tree."""
        self.tree.insert(rbtree.RBNode(other.start, other.end, other))

    def __isub__(self, Interval other not None):
        """Inplace subtract -- modifies self

        Removes an interval from the set.  Must exist exactly
        as provided -- cannot remove a subset of an existing interval.

        Raises IntervalError if no matching interval is found."""
        i = self.tree.find(other.start, other.end)
        if i is None:
            raise IntervalError("interval " + str(other) + " not in tree")
        self.tree.delete(i)
        return self

    def __add__(self, other not None):
        """Add -- returns a new object"""
        new = IntervalSet(self)
        new += IntervalSet(other)
        return new

    def __and__(self, other not None):
        """
        Compute a new IntervalSet from the intersection of this
        IntervalSet with one other interval.

        Output intervals are built as subsets of the intervals in the
        first argument (self).
        """
        out = IntervalSet()
        for i in self.intersection(other):
            out.tree.insert(rbtree.RBNode(i.start, i.end, i))
        return out

    def intersection(self, Interval interval not None, orig = False):
        """
        Compute a sequence of intervals that correspond to the
        intersection between `self` and the provided interval.
        Returns a generator that yields each of these intervals
        in turn.

        Output intervals are built as subsets of the intervals in the
        first argument (self).

        If orig = True, also return the original interval that was
        (potentially) subsetted to make the one that is being
        returned; each item is then a (subset, original) tuple.
        """
        for n in self.tree.intersect(interval.start, interval.end):
            i = n.obj
            # Clip the stored interval to the bounds of the query.
            subset = i.subset(max(i.start, interval.start),
                              min(i.end, interval.end))
            if orig:
                yield (subset, i)
            else:
                yield subset

    cpdef intersects(self, Interval other):
        """Return True if this IntervalSet intersects another interval.

        Raises TypeError if 'other' is None.
        """
        if other is None:
            # A typed extension-class argument accepts None by default,
            # and reading its C-level fields would be unchecked.
            raise TypeError("other must be an Interval, not None")
        for n in self.tree.intersect(other.start, other.end):
            if n.obj.intersects(other):
                return True
        return False

    def find_end(self, timestamp_t t):
        """
        Return an Interval from this tree that ends at time t, or
        None if it doesn't exist.
        """
        n = self.tree.find_left_end(t)
        if n and n.obj.end == t:
            return n.obj
        return None
