import nilmdb
from nilmdb.utils.printf import *

import nose
from nose.tools import *
from nose.tools import assert_raises
import threading
import time

from testutil.helpers import *

class Foo(object):
    """Test fixture with a deliberately racy counter.

    Records which thread constructed it (``init_thread``) and which thread
    last ran ``tester()`` (``test_thread``), so tests can check whether
    calls were funneled onto a single thread.
    """
    # Class-level default; tester() assigns an instance attribute, so the
    # class attribute itself stays 0.
    val = 0

    def __init__(self, asdf = "asdf"):
        # `asdf` is ignored; it only lets callers pass a constructor
        # argument such as Foo("x") or Foo("qwer").
        self.init_thread = threading.current_thread().name

    @classmethod
    def foo(cls):
        """No-op class method."""
        pass

    def fail(self):
        """Always raise, so tests can check that exceptions propagate."""
        raise Exception("you asked me to do this")

    def test(self, debug = False):
        """Increment ``val`` once by delegating to tester()."""
        self.tester(debug)

    def t(self):
        """No-op method; calling it just checks the call goes through."""
        pass

    def tester(self, debug = False):
        """Non-atomically increment ``self.val`` by one.

        The read and the write are separated by a short sleep. Concurrent,
        unserialized callers are therefore likely to lose updates. If
        ``debug`` is true, print the old and new values.
        """
        # purposely not thread-safe
        self.test_thread = threading.current_thread().name
        oldval = self.val
        newval = oldval + 1
        time.sleep(0.05)
        self.val = newval
        if debug:
            printf("[%s] value changed: %d -> %d\n",
                   threading.current_thread().name, oldval, newval)

class Base(object):
    """Shared test cases, run against the ``self.foo`` that each
    subclass's setUp() provides."""

    def test_wrapping(self):
        target = self.foo
        target.test()
        with assert_raises(Exception):
            target.fail()

    def test_threaded(self):
        def worker(obj):
            obj.test()

        workers = [threading.Thread(target = worker, args = (self.foo,))
                   for _ in range(20)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        self.verify_result()

    def verify_result(self):
        # All 20 increments must land, and tester() must have run on the
        # same thread that constructed the object.
        expected_count = 20
        eq_(self.foo.val, expected_count)
        eq_(self.foo.init_thread, self.foo.test_thread)

class ListLike(object):
    """Iterable/indexable fixture that asserts every access happens on the
    thread that constructed it.

    Iterating yields 1 through 5; ``obj[key]`` returns ``key`` unchanged.
    """

    def __init__(self):
        self.thread = threading.current_thread().name
        # Iteration cursor; reset by __iter__.
        self.foo = 0

    def __iter__(self):
        eq_(threading.current_thread().name, self.thread)
        self.foo = 0
        return self

    def __getitem__(self, key):
        eq_(threading.current_thread().name, self.thread)
        return key

    def next(self):
        """Return the next value in 1..5, then raise StopIteration."""
        eq_(threading.current_thread().name, self.thread)
        if self.foo < 5:
            self.foo += 1
            return self.foo
        else:
            raise StopIteration

    # Python 3 names the iterator-protocol method __next__; alias it so
    # iteration works on both Python 2 and 3.
    __next__ = next

class TestUnserialized(Base):
    """Runs the Base tests against a plain, unproxied Foo. Here the
    threaded test is expected to expose the race in Foo.tester()."""

    def setUp(self):
        self.foo = Foo()

    def verify_result(self):
        obj = self.foo
        # Concurrent tester() calls race, so some increments are lost.
        ne_(obj.val, 20)
        # tester() ran on worker threads, not on the constructing thread.
        ne_(obj.init_thread, obj.test_thread)

class TestSerializer(Base):
    """Runs the Base tests against a serializer_proxy-wrapped Foo, plus
    proxy-specific checks."""

    def setUp(self):
        self.foo = nilmdb.utils.serializer_proxy(Foo)("qwer")

    def test_multi(self):
        # Wrap instances and classes, singly and nested. Each result must
        # still allow a method call. (Duplicate copies of the class-wrapping
        # cases were removed; they exercised nothing new.)
        sp = nilmdb.utils.serializer_proxy
        sp(Foo("x")).t()
        sp(sp(Foo)("x")).t()
        sp(sp(Foo))("x").t()
        sp(sp(Foo("x"))).t()

    def test_iter(self):
        # ListLike asserts that iteration and indexing run on its
        # constructing thread.
        sp = nilmdb.utils.serializer_proxy
        i = sp(ListLike)()
        eq_(list(i), [1,2,3,4,5])
        eq_(i[3], 3)
