import nilmdb
from nilmdb.utils.printf import *

import nose
from nose.tools import *
from nose.tools import assert_raises

from testutil.helpers import *
import threading

class Thread(threading.Thread):
    """Thread that runs a callable and remembers any AssertionError it raised.

    After the thread has finished, ``error`` holds the AssertionError raised
    by the callable, or None if the callable returned normally.  ``error``
    is not assigned before ``run`` completes, and it is never assigned if
    the callable raises anything other than AssertionError (that exception
    propagates out of ``run`` instead).
    """

    def __init__(self, target):
        # Stored under our own attribute name; run() below is overridden
        # and calls it directly.
        self.target = target
        threading.Thread.__init__(self)

    def run(self):
        """Invoke the stored callable, recording the outcome in ``error``."""
        caught = None
        try:
            self.target()
        except AssertionError as exc:
            caught = exc
        # Only reached on normal return or AssertionError; other exception
        # types have already propagated past this point.
        self.error = caught

class Test():
    """Small fixture object whose methods call each other in various ways.

    It provides a plain attribute, a classmethod, a method that can raise,
    a pair of methods that re-enter one another, and two methods that
    invoke a caller-supplied callable (directly, or from a new thread).
    """

    def __init__(self):
        self.test = 1234

    @classmethod
    def asdf(cls):
        """Do nothing; exists only so the class has a classmethod."""

    def foo(self, exception = False, reenter = False):
        """Raise Exception if requested, otherwise call bar(reenter).

        Always returns None; the value returned by bar is discarded.
        """
        if not exception:
            self.bar(reenter)
            return
        raise Exception()

    def bar(self, reenter):
        """Return 123, first calling foo() once if reenter is true."""
        result = 123
        if reenter:
            self.foo()
        return result

    def baz_threaded(self, target):
        """Run target in a new Thread, wait for it, and return the Thread."""
        worker = Thread(target)
        worker.start()
        worker.join()
        return worker

    def baz(self, target):
        """Call target directly, while this method is still executing."""
        target()

class TestThreadSafety(object):
    """Exercise nilmdb.utils.threadsafety.verify_proxy against Test objects."""

    def _expect(self, error, ok):
        """Raise unless the presence of `error` matches the expectation.

        If `ok` is true, a truthy `error` is a failure; if `ok` is false,
        a missing (falsy) `error` is a failure.
        """
        if error:
            if ok:
                raise Exception("got unexpected error: " + str(error))
        elif not ok:
            raise Exception("failed to get expected error")

    def tryit(self, c, threading_ok, concurrent_ok):
        """Drive `c` through four call patterns and check which ones fail.

        c            -- a Test instance, possibly wrapped in a proxy
        threading_ok -- whether calling c.foo from another thread is
                        expected to succeed without an AssertionError
        concurrent_ok -- whether calling c.foo while c.baz is still
                         running is expected to succeed without an
                         AssertionError
        """
        # Plain attribute access and a plain call from this thread
        eq_(c.test, 1234)
        c.foo()

        # Same call, but issued from a separate thread
        worker = Thread(c.foo)
        worker.start()
        worker.join()
        self._expect(worker.error, threading_ok)

        # foo invoked from inside baz, in this thread
        try:
            c.baz(c.foo)
        except AssertionError as err:
            if concurrent_ok:
                raise Exception("got unexpected error: " + str(err))
        else:
            if not concurrent_ok:
                raise Exception("failed to get expected error")

        # foo invoked from a new thread while baz_threaded is running;
        # this only passes cleanly if both kinds of access are allowed
        worker = c.baz_threaded(c.foo)
        self._expect(worker.error, concurrent_ok and threading_ok)

    def test(self):
        """Check every proxy flag combination, wrapping instances and the class."""
        proxy = nilmdb.utils.threadsafety.verify_proxy

        # Unwrapped object: nothing should be flagged
        self.tryit(Test(), True, True)

        # Each row is (positional flags passed to proxy after the wrapped
        # object, (threading_ok, concurrent_ok) expected by tryit).
        # NOTE(review): judging by the expectations, the second flag
        # enables the thread check and the third the concurrency check --
        # confirm against verify_proxy's signature.
        matrix = (
            ((True, True, True), (False, False)),
            ((True, True, False), (False, True)),
            ((True, False, True), (True, False)),
            ((True, False, False), (True, True)),
        )

        # Proxy wrapped around an existing instance
        for flags, expected in matrix:
            self.tryit(proxy(Test(), *flags), *expected)

        # Proxy wrapped around the class, then instantiated
        for flags, expected in matrix:
            self.tryit(proxy(Test, *flags)(), *expected)

        # Nested proxies around both the class and the instance
        proxy(proxy(proxy(Test))()).foo()

        # A call that raises must not prevent later calls from working
        wrapped = proxy(Test())
        wrapped.foo()
        try:
            wrapped.foo(exception = True)
        except Exception:
            pass
        wrapped.foo()
