import threading
from nilmdb.utils.printf import sprintf


def verify_proxy(obj_or_type, check_thread=True,
                 check_concurrent=True):
    """Wrap the given object or type in a VerifyObjectProxy.

    Returns a VerifyObjectProxy that proxies all method calls to the
    given object, as well as attribute retrievals.  If a type is
    given, calling the returned proxy instantiates that type; the
    proxy then wraps the new instance and returns itself.

    When calling methods, the following checks are performed.  On
    failure, an AssertionError is raised.

    check_thread = True     # Fail if two different threads call methods.
    check_concurrent = True # Fail if two functions are concurrently
                            # run through this proxy

    The thread making the first call through the proxy is recorded as
    the owner that check_thread compares later callers against.  The
    concurrency check uses a non-reentrant lock, so a proxied call
    that re-enters the same proxy also counts as concurrent.
    """
    class Namespace():
        """Plain attribute bag shared by a proxy and its call wrappers."""
        pass

    class VerifyCallProxy():
        """Callable wrapper that runs the checks around one function."""

        def __init__(self, func, parent_namespace):
            self.func = func
            self.parent_namespace = parent_namespace

        def __call__(self, *args, **kwargs):
            p = self.parent_namespace
            this = threading.current_thread()
            # Not every callable carries a __name__; fall back to a
            # placeholder for use in the error messages.
            callee = getattr(self.func, "__name__", "???")

            # Record the first caller as the owner.  This is done under
            # a lock so that two threads racing on the very first call
            # cannot both register themselves and both pass the check.
            with p.thread_lock:
                if p.thread is None:
                    p.thread = this
                    p.thread_callee = callee
                owner = p.thread
                owner_callee = p.thread_callee

            if check_thread and owner is not this:
                err = sprintf("unsafe threading: %s called %s.%s,"
                              " but %s called %s.%s",
                              owner.name, p.classname, owner_callee,
                              this.name, p.classname, callee)
                raise AssertionError(err)

            if not check_concurrent:
                return self.func(*args, **kwargs)

            # Non-blocking acquire: if the lock is already held, some
            # other call through this proxy is still in progress.
            if not p.concur_lock.acquire(False):
                err = sprintf("unsafe concurrency: %s called %s.%s "
                              "while %s is still in %s.%s",
                              this.name, p.classname, callee,
                              p.concur_tname, p.classname, p.concur_callee)
                raise AssertionError(err)
            try:
                # Remember who holds the lock for the message above.
                p.concur_tname = this.name
                p.concur_callee = callee
                return self.func(*args, **kwargs)
            finally:
                # Always release, even if the wrapped call raised.
                p.concur_lock.release()

    class VerifyObjectProxy():
        """Proxy that routes attribute access and calls through checks."""

        def __init__(self, obj_or_type, *args, **kwargs):
            p = Namespace()
            self.__ns = p
            p.thread = None
            p.thread_callee = None
            p.thread_lock = threading.Lock()
            p.concur_lock = threading.Lock()
            p.concur_tname = None
            p.concur_callee = None
            self.__obj = obj_or_type
            if isinstance(obj_or_type, type):
                p.classname = self.__obj.__name__
            else:
                p.classname = self.__obj.__class__.__name__

        def __getattr__(self, key):
            # __getattr__ only runs when normal lookup fails.  If that
            # happens for one of this class's own name-mangled
            # attributes, __init__ has not run on this instance (e.g.
            # it was created through __new__ by copy.copy).  Raise
            # AttributeError instead of recursing via self.__obj below.
            if key.startswith("_VerifyObjectProxy__"):
                raise AttributeError(key)

            attr = getattr(self.__obj, key)
            if callable(attr):
                # Checks are performed when the returned wrapper is
                # called, not when it is retrieved.
                return VerifyCallProxy(attr, self.__ns)
            # Non-callable: fetch the attribute again, this time as a
            # checked call to getattr, so the retrieval is verified too.
            return VerifyCallProxy(getattr, self.__ns)(self.__obj, key)

        def __call__(self, *args, **kwargs):
            """Call this to instantiate the type, if a type was passed
            to verify_proxy.  Otherwise, pass the call through."""
            ret = VerifyCallProxy(self.__obj, self.__ns)(*args, **kwargs)
            if isinstance(self.__obj, type):
                # Instantiation: from now on proxy the new instance.
                self.__obj = ret
                return self
            return ret

    return VerifyObjectProxy(obj_or_type)
