# -*- coding: utf-8 -*-

import nilmdb
from nilmdb.utils.printf import *
from nilmdb.utils import datetime_tz

from nose.tools import *
from nose.tools import assert_raises
import itertools

from nilmdb.utils.interval import IntervalError
from nilmdb.server.interval import Interval, DBInterval, IntervalSet

# so we can test them separately
from nilmdb.utils.interval import Interval as UtilsInterval

from testutil.helpers import *
import unittest

# Flip to True to display renders live; False (the default) skips them.
do_live_renders = False
def render(iset, description = "", live = True):
    """Render the tree backing an IntervalSet using testutil.renderdot.

    iset        -- IntervalSet whose .tree is handed to RBTreeRenderer
    description -- caption passed through to the renderer
    live        -- request a live render; honored only when the module
                   flag do_live_renders is also true

    Returns whatever RBTreeRenderer.render() returns.
    """
    import testutil.renderdot as renderdot
    renderer = renderdot.RBTreeRenderer(iset.tree)
    show_live = live and do_live_renders
    return renderer.render(description, show_live)

def makeset(string):
    """Build an IntervalSet from a string, for testing purposes

    Each character position is one timestamp unit; the character at
    index i corresponds to timestamp 10000 + i.
    [ = interval start
    | = interval end + next start
    ) = interval end (intervals are half-open, as in nilmdb)
    . = zero-width interval (identical start and end)
    anything else is ignored

    Raises ValueError if '|' or ')' appears while no interval is open,
    instead of failing with an obscure unbound-variable error.
    """
    iset = IntervalSet()
    start = None    # timestamp of the currently open interval, if any
    for pos, char in enumerate(string):
        ts = pos + 10000
        if char == "[":
            start = ts
        elif char == "|" or char == ")":
            if start is None:
                raise ValueError("unmatched %r at position %d in %r"
                                 % (char, pos, string))
            iset += Interval(start, ts)
            # '|' closes one interval and immediately opens the next;
            # ')' just closes.
            start = ts if char == "|" else None
        elif char == ".":
            iset += Interval(ts, ts)
    return iset

class TestInterval:
    """Tests for Interval/IntervalSet from nilmdb.server.interval, plus the
    pure-Python versions and helpers in nilmdb.utils.interval."""

    def test_client_interval(self):
        """Re-run the Interval tests against nilmdb.utils.interval.Interval,
        then exercise the other helpers in nilmdb.utils.interval."""
        # test_interval and test_interval_intersect look up the module-level
        # name "Interval", so rebind it temporarily.  Restore it in a
        # "finally" so that a failing sub-test can't leak the substitution
        # into every test that runs afterwards.
        global Interval
        NilmdbInterval = Interval
        Interval = UtilsInterval
        try:
            self.test_interval()
            self.test_interval_intersect()
        finally:
            Interval = NilmdbInterval

        # Other helpers in nilmdb.utils.interval
        i = [ UtilsInterval(1,2), UtilsInterval(2,3), UtilsInterval(4,5) ]
        eq_(list(nilmdb.utils.interval.optimize(i)),
            [ UtilsInterval(1,3), UtilsInterval(4,5) ])
        # The expected "-0500" offset depends on the America/New_York TZ
        # that test_interval() set above.
        eq_(UtilsInterval(1234567890123456, 1234567890654321).human_string(),
            "[ Fri, 13 Feb 2009 18:31:30.123456 -0500 -> " +
            "Fri, 13 Feb 2009 18:31:30.654321 -0500 ]")

    def test_interval(self):
        """Test Interval construction, comparison, subset and formatting."""
        import os   # don't depend on a star import to provide "os"

        # Pin the local timezone so parsed dates are deterministic.  This is
        # deliberately not undone: test_client_interval relies on it.
        os.environ['TZ'] = "America/New_York"
        # Presumably clears datetime_tz's cached local zone so the new TZ
        # is picked up -- NOTE(review): confirm against datetime_tz.
        datetime_tz._localtz = None
        (d1, d2, d3) = [ nilmdb.utils.time.parse_time(x)
                         for x in [ "03/24/2012", "03/25/2012", "03/26/2012" ] ]

        # basic construction
        i = Interval(d1, d2)
        i = Interval(d1, d3)
        eq_(i.start, d1)
        eq_(i.end, d3)

        # assignment is allowed, but not verified (so setting end before
        # start is intentionally not checked here)
        i.start = d2
        i.start = d1
        i.end = d2

        # end before start
        with assert_raises(IntervalError):
            i = Interval(d3, d1)

        # compare
        assert(Interval(d1, d2) == Interval(d1, d2))
        assert(Interval(d1, d2) < Interval(d1, d3))
        assert(Interval(d1, d3) > Interval(d1, d2))
        assert(Interval(d1, d2) < Interval(d2, d3))
        assert(Interval(d1, d3) < Interval(d2, d3))
        assert(Interval(d2, d2+1) > Interval(d1, d3))
        assert(Interval(d3, d3+1) == Interval(d3, d3+1))
        # NOTE: comparing an Interval with a non-Interval (e.g. i == 123)
        # is intentionally not checked; its exception type is unsettled.

        # subset
        eq_(Interval(d1, d3).subset(d1, d2), Interval(d1, d2))
        with assert_raises(IntervalError):
            x = Interval(d2, d3).subset(d1, d2)

        # big integers, negative integers
        x = Interval(5000111222000000, 6000111222000000)
        eq_(str(x), "[5000111222000000 -> 6000111222000000)")
        x = Interval(-5000111222000000, -4000111222000000)
        eq_(str(x), "[-5000111222000000 -> -4000111222000000)")

        # misc
        i = Interval(d1, d2)
        eq_(repr(i), repr(eval(repr(i))))
        eq_(str(i), "[1332561600000000 -> 1332648000000000)")

    def test_interval_intersect(self):
        """Check Interval.intersects() for every ordered pair of intervals
        built from the endpoints below."""
        dates = [ 100, 200, 300, 400 ]
        perm = list(itertools.permutations(dates, 2))
        prod = list(itertools.product(perm, perm))
        # Indices into "prod" whose two intervals should (True) or should
        # not (False) intersect.  Indices in neither list involve an
        # interval with end < start, which the constructor rejects.
        should_intersect = {
            False: [4, 5, 8, 20, 48, 56, 60, 96, 97, 100],
            True: [0, 1, 2, 12, 13, 14, 16, 17, 24, 25, 26, 28, 29,
                   32, 49, 50, 52, 53, 61, 62, 64, 65, 68, 98, 101, 104]
            }
        for i,((a,b),(c,d)) in enumerate(prod):
            try:
                i1 = Interval(a, b)
                i2 = Interval(c, d)
                eq_(i1.intersects(i2), i2.intersects(i1))
                in_(i, should_intersect[i1.intersects(i2)])
            except IntervalError:
                assert(i not in should_intersect[True] and
                       i not in should_intersect[False])
        # Intersecting with a non-Interval is a type error.  Use an explicit
        # interval rather than whichever "i1" the loop happened to leave.
        with assert_raises(TypeError):
            x = Interval(dates[0], dates[1]).intersects(1234)

    def test_intervalset_construct(self):
        # Test IntervalSet construction
        dates = [ 100, 200, 300, 400 ]

        a = Interval(dates[0], dates[1])
        b = Interval(dates[1], dates[2])
        c = Interval(dates[0], dates[2])
        d = Interval(dates[2], dates[3])

        iseta = IntervalSet(a)
        isetb = IntervalSet([a, b])
        isetc = IntervalSet([a])
        ne_(iseta, isetb)
        eq_(iseta, isetc)
        with assert_raises(TypeError):
            x = iseta != 3
        ne_(IntervalSet(a), IntervalSet(b))

        # Note that assignment makes a new reference (not a copy)
        isetd = IntervalSet(isetb)
        isete = isetd
        eq_(isetd, isetb)
        eq_(isetd, isete)
        isetd -= a
        ne_(isetd, isetb)
        eq_(isetd, isete)

        # test iterator
        for interval in iseta:
            pass

        # overlap
        with assert_raises(IntervalError):
            x = IntervalSet([a, b, c])

        # bad types
        with assert_raises(Exception):
            x = IntervalSet([1, 2])

        iset = IntervalSet(isetb)   # test iterator
        eq_(iset, isetb)
        eq_(len(iset), 2)
        eq_(len(IntervalSet()), 0)

        # Test adding
        iset = IntervalSet(a)
        iset += IntervalSet(b)
        eq_(iset, IntervalSet([a, b]))

        iset = IntervalSet(a)
        iset += b
        eq_(iset, IntervalSet([a, b]))

        iset = IntervalSet(a)
        iset.iadd_nocheck(b)
        eq_(iset, IntervalSet([a, b]))

        iset = IntervalSet(a) + IntervalSet(b)
        eq_(iset, IntervalSet([a, b]))

        iset = IntervalSet(b) + a
        eq_(iset, IntervalSet([a, b]))

        # A set consisting of [0-1],[1-2] should match a set consisting of [0-2]
        eq_(IntervalSet([a,b]), IntervalSet([c]))
        # Etc
        ne_(IntervalSet([a,d]), IntervalSet([c]))
        ne_(IntervalSet([c]), IntervalSet([a,d]))
        ne_(IntervalSet([c,d]), IntervalSet([b,d]))

        # misc
        eq_(repr(iset), repr(eval(repr(iset))))
        eq_(str(iset),
            "[[100 -> 200), [200 -> 300)]")

    def test_intervalset_geniset(self):
        # Test basic iset construction
        eq_(makeset("  [----)   "),
            makeset("  [-|--)   "))

        eq_(makeset("[)  [--)   ") +
            makeset(" [)    [--)"),
            makeset("[|) [-----)"))

        eq_(makeset("  [-------)"),
            makeset("  [-|-----|"))


    def test_intervalset_intersect_difference(self):
        # Test intersection (&)
        with assert_raises(TypeError): # was AttributeError
            x = makeset("[--)") & 1234

        def do_test(a, b, c, d):
            # a & b == c (using nilmdb.server.interval)
            ab = IntervalSet()
            for x in b:
                for i in (a & x):
                    ab += i
            eq_(ab,c)

            # a & b == c (using nilmdb.utils.interval)
            eq_(IntervalSet(nilmdb.utils.interval.intersection(a,b)), c)

            # a \ b == d
            eq_(IntervalSet(nilmdb.utils.interval.set_difference(a,b)), d)

        # Intersection with intervals
        do_test(makeset("[---|---)[)"),
                makeset("  [------) "),
                makeset("  [-----)  "), # intersection
                makeset("[-)      [)")) # difference

        do_test(makeset("[---------)"),
                makeset(" [---)     "),
                makeset(" [---)     "), # intersection
                makeset("[)   [----)")) # difference

        do_test(makeset(" [---)     "),
                makeset("[---------)"),
                makeset(" [---)     "), # intersection
                makeset("           ")) # difference

        do_test(makeset("    [-----)"),
                makeset(" [-----)   "),
                makeset("    [--)   "), # intersection
                makeset("       [--)")) # difference

        do_test(makeset(" [--)  [--)"),
                makeset("  [------) "),
                makeset("  [-)  [-) "), # intersection
                makeset(" [)      [)")) # difference

        do_test(makeset("      [---)"),
                makeset(" [--)      "),
                makeset("           "), # intersection
                makeset("      [---)")) # difference

        do_test(makeset("    [-|---)"),
                makeset(" [-----|-) "),
                makeset("    [----) "), # intersection
                makeset("         [)")) # difference

        do_test(makeset("    [-|-)  "),
                makeset(" [-|--|--) "),
                makeset("    [---)  "), # intersection
                makeset("           ")) # difference

        do_test(makeset("[-)[-)[-)[)"),
                makeset(" [)  [|)[) "),
                makeset(" [)   [)   "), # intersection
                makeset("[) [-) [)[)")) # difference

        # Border cases -- will give different results if intervals are
        # half open or fully closed.  In nilmdb, they are half open.
        do_test(makeset("      [---)"),
                makeset(" [----)    "),
                makeset("           "), # intersection
                makeset("      [---)")) # difference

        do_test(makeset(" [----)[--)"),
                makeset("[-) [--) [)"),
                makeset(" [) [-)  [)"), # intersection
                makeset("  [-)  [-) ")) # difference

        # Set difference with bounds
        a = makeset(" [----)[--)")
        b = makeset("[-) [--) [)")
        c = makeset("[----)     ")
        d = makeset("  [-)      ")
        eq_(nilmdb.utils.interval.set_difference(
            a.intersection(list(c)[0]), b.intersection(list(c)[0])), d)

        # Fill out test coverage for non-subsets
        def diff2(a,b, subset):
            return nilmdb.utils.interval._interval_math_helper(
                a, b, (lambda a, b: b and not a), subset=subset)
        with assert_raises(nilmdb.utils.interval.IntervalError):
            list(diff2(a,b,True))
        list(diff2(a,b,False))

        # Empty second set
        eq_(nilmdb.utils.interval.set_difference(a, IntervalSet()), a)

        # Empty first set (previously a verbatim duplicate of the check above)
        eq_(IntervalSet(nilmdb.utils.interval.set_difference(IntervalSet(), a)),
            IntervalSet())

class TestIntervalDB:
    def test_dbinterval(self):
        """Exercise DBInterval: attribute access, constructor validation,
        subset bounds, and survival through IntervalSet intersection."""
        # Constructor arguments, in order:
        #   start, end, db_start, db_end, db_startpos, db_endpos
        dbi = DBInterval(100, 200, 100, 200, 10000, 20000)
        expected = [("start", 100), ("end", 200),
                    ("db_start", 100), ("db_end", 200),
                    ("db_startpos", 10000), ("db_endpos", 20000)]
        for attr, want in expected:
            eq_(getattr(dbi, attr), want)
        eq_(repr(dbi), repr(eval(repr(dbi))))

        # Each of these argument tuples must be rejected by the constructor
        invalid = [
            (200, 100, 100, 200, 10000, 20000),   # end before start
            (100, 200, 150, 200, 10000, 20000),   # db_start too late
            (100, 200, 100, 150, 10000, 20000),   # db_end too soon
        ]
        for args in invalid:
            with assert_raises(IntervalError):
                DBInterval(*args)

        # The actual start/end may be a sub-range of db_start..db_end
        upper = DBInterval(150, 200, 100, 200, 10000, 20000)
        lower = DBInterval(100, 150, 100, 200, 10000, 20000)
        inner = DBInterval(150, 160, 100, 200, 10000, 20000)

        # IntervalSets built from DBIntervals
        both = IntervalSet([upper, lower])
        IntervalSet(inner)   # construction from a single DBInterval
        assert both.intersects(upper)
        assert both.intersects(lower)

        # A subset reaching past the interval's end is refused
        with assert_raises(IntervalError):
            upper.subset(150, 250)

        # Intersecting a set of DBIntervals must still yield DBIntervals
        for piece in IntervalSet(both.intersection(Interval(125, 250))):
            assert isinstance(piece, DBInterval)

class TestIntervalTree:

    def test_interval_tree(self):
        """Insert and delete many unit intervals in an IntervalSet, rendering
        the underlying tree after each phase."""
        import random
        random.seed(1234)   # fixed seed: identical sequence on every run
        count = 100

        # Insert `count` unit-width intervals in a shuffled order
        iset = IntervalSet()
        for n in random.sample(xrange(count), count):
            iset += Interval(n, n + 1)
        render(iset, "Random Insertion")

        # Walk them in another shuffled order, dropping each on a coin flip
        for n in random.sample(xrange(count), count):
            if random.randint(0, 1):
                iset -= Interval(n, n + 1)

        # Removing an interval that was never inserted must fail
        with assert_raises(IntervalError):
            iset -= Interval(1234, 5678)
        render(iset, "Random Insertion, deletion")

        # Build a fresh set with the same intervals in ascending order
        iset = IntervalSet()
        for n in xrange(count):
            iset += Interval(n, n + 1)
        render(iset, "In-order insertion")

class TestIntervalSpeed:
    @unittest.skip("this is slow")
    def test_interval_speed(self):
        """Benchmark shuffled insertion into IntervalSet for sizes 2**5 up
        to 2**(limit-1), printing timings and yappi profile statistics.
        Skipped by default because it is slow."""
        import yappi
        import time
        import random
        import math

        print
        yappi.start()
        speeds = {}
        limit = 22 # was 20
        sizes = [ 2**x for x in range(5, limit) ]
        for size in sizes:
            t0 = time.time()
            iset = IntervalSet()
            for n in random.sample(xrange(size), size):
                iset += Interval(n, n + 1)
            elapsed_us = (time.time() - t0) * 1000000.0
            per_item_us = elapsed_us / size
            # Should stay roughly constant across sizes if the total
            # insertion cost grows as O(n log n).
            nlogn_ratio = elapsed_us / (size * math.log(size))
            printf("%d: %g μs (%g μs each, O(n log n) ratio %g)\n",
                   size, elapsed_us, per_item_us, nlogn_ratio)
            speeds[size] = elapsed_us
        yappi.stop()
        yappi.print_stats(sort_type=yappi.SORTTYPE_TTOT, limit=10)
