import nilmdb.server

from nose.tools import *
from nose.tools import assert_raises
import distutils.version
import json
import itertools
import os
import sys
import threading
import urllib.request, urllib.error, urllib.parse
from urllib.request import urlopen
from urllib.error import HTTPError
import io
import time
import requests
import socket
import sqlite3
import cherrypy

from nilmdb.utils import serializer_proxy
from nilmdb.server.interval import Interval

# Scratch database path shared by the tests in this module.  It is handed to
# nilmdb.server.NilmDB() and wiped with recursive_unlink(); the tests also
# open "data.sql" inside it, so it is used as a directory, not a single file.
testdb = "tests/testdb"

# NOTE(review): disabled cleanup hook, kept for reference.  Re-enabling it
# would presumably need an `import atexit`, and os.unlink() does not remove
# directories -- confirm before turning this back on.
#@atexit.register
#def cleanup():
#    os.unlink(testdb)

from testutil.helpers import *

def setup_module():
    """Fail fast unless TCP port 32180 on 127.0.0.1 can be bound.

    The server tests in this module start a web app on that port, so a
    port that is already taken is reported up front instead of surfacing
    later as an obscure server start-up failure.

    Raises:
        AssertionError: if binding the port fails with OSError.  The
            original OSError is chained as the cause.
    """
    # The context manager closes the probe socket on both the success and
    # the failure path, so a busy port no longer leaks a descriptor.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            sock.bind(("127.0.0.1", 32180))
        except OSError as err:
            raise AssertionError("port 32180 must be free for tests") from err

class Test00Nilmdb(object):  # named 00 so it runs first
    """Tests that drive nilmdb.server.NilmDB directly, without the HTTP server.

    test_stream leaves three streams (with metadata) behind in `testdb`;
    the TestServer cases later in this module query that data.
    """

    @staticmethod
    def _set_user_version(version):
        """Write `version` into the user_version pragma of testdb's data.sql.

        The connection is closed even if the PRAGMA statement fails.
        """
        con = sqlite3.connect(os.path.join(testdb, "data.sql"))
        try:
            con.execute("PRAGMA user_version = %d" % version)
        finally:
            con.close()

    def test_NilmDB(self):
        """Create a fresh database, reopen it, and close it more than once."""
        recursive_unlink(testdb)

        db = nilmdb.server.NilmDB(testdb)
        db.close()
        db = nilmdb.server.NilmDB(testdb)
        db.close()
        # Closing an already-closed database must not raise
        db.close()

    def test_error_cases(self):
        """Cover miscellaneous NilmDB error paths for better code coverage."""
        # Unusable database location
        with assert_raises(OSError) as e:
            nilmdb.server.NilmDB("/dev/null/bogus")
        in_("can't create tree", str(e.exception))

        # Version upgrades: a schema version the code cannot use ...
        self._set_user_version(2)
        with assert_raises(Exception) as e:
            nilmdb.server.NilmDB(testdb)
        in_("can't use database version 2", str(e.exception))

        # ... and one it does not recognize at all
        self._set_user_version(-1234)
        with assert_raises(Exception) as e:
            nilmdb.server.NilmDB(testdb)
        in_("unknown database version -1234", str(e.exception))

        recursive_unlink(testdb)

        # Verbose mode prints the schema update while creating the database.
        # Restore stdout and the verbose flag in a finally block so that a
        # failure here cannot leave them modified for the remaining tests.
        nilmdb.server.NilmDB.verbose = 1
        (saved_stdout, sys.stdout) = (sys.stdout, io.StringIO())
        try:
            db = nilmdb.server.NilmDB(testdb)
            output = sys.stdout.getvalue()
        finally:
            sys.stdout = saved_stdout
            nilmdb.server.NilmDB.verbose = 0
        db.close()
        in_("Database schema updated to 1", output)

        # Corrupted database (bad ranges)
        recursive_unlink(testdb)
        db = nilmdb.server.NilmDB(testdb)
        try:
            db.con.executescript("""
        INSERT INTO streams VALUES (1, "/test", "int32_1");
        INSERT INTO ranges VALUES (1, 100, 200, 100, 200);
        INSERT INTO ranges VALUES (1, 150, 250, 150, 250);
        """)
        finally:
            db.close()
        db = nilmdb.server.NilmDB(testdb)
        try:
            with assert_raises(nilmdb.server.NilmDBError):
                db.stream_intervals("/test")
        finally:
            db.close()
        recursive_unlink(testdb)

    def test_stream(self):
        """Create streams, set and read metadata, and hit a few error paths.

        The streams and metadata created here are left in `testdb` when
        the test finishes.
        """
        db = nilmdb.server.NilmDB(testdb)
        # Always close the database, even when an assertion below fails,
        # so later test classes can reopen `testdb`.
        try:
            eq_(db.stream_list(), [])

            # Bad path
            with assert_raises(ValueError):
                db.stream_create("foo/bar/baz", "float32_8")
            with assert_raises(ValueError):
                db.stream_create("/foo", "float32_8")
            # Bad layout type
            with assert_raises(ValueError):
                db.stream_create("/newton/prep", "NoSuchLayout")
            db.stream_create("/newton/prep", "float32_8")
            db.stream_create("/newton/raw", "uint16_6")
            db.stream_create("/newton/zzz/rawnotch", "uint16_9")

            # Verify we got 3 streams
            eq_(db.stream_list(), [ ["/newton/prep", "float32_8"],
                                    ["/newton/raw", "uint16_6"],
                                    ["/newton/zzz/rawnotch", "uint16_9"]
                                    ])
            # Match just one type or one path
            eq_(db.stream_list(layout="uint16_6"),
                [ ["/newton/raw", "uint16_6"] ])
            eq_(db.stream_list(path="/newton/raw"),
                [ ["/newton/raw", "uint16_6"] ])

            # Set / get metadata
            eq_(db.stream_get_metadata("/newton/prep"), {})
            eq_(db.stream_get_metadata("/newton/raw"), {})
            meta1 = { "description": "The Data",
                      "v_scale": "1.234" }
            meta2 = { "description": "The Data" }
            meta3 = { "v_scale": "1.234" }
            db.stream_set_metadata("/newton/prep", meta1)
            db.stream_update_metadata("/newton/prep", {})
            # Two partial updates are expected to add up to meta1
            db.stream_update_metadata("/newton/raw", meta2)
            db.stream_update_metadata("/newton/raw", meta3)
            eq_(db.stream_get_metadata("/newton/prep"), meta1)
            eq_(db.stream_get_metadata("/newton/raw"), meta1)

            # fill in some misc. test coverage
            with assert_raises(nilmdb.server.NilmDBError):
                db.stream_remove("/newton/prep", 0, 0)
            with assert_raises(nilmdb.server.NilmDBError):
                db.stream_remove("/newton/prep", 1, 0)
            db.stream_remove("/newton/prep", 0, 1)

            with assert_raises(nilmdb.server.NilmDBError):
                db.stream_extract("/newton/prep", count = True, binary = True)
        finally:
            db.close()

class TestBlockingServer(object):
    """Exercise nilmdb.server.Server when started in blocking mode."""

    def setUp(self):
        # Open the shared test database through the serializer proxy
        self.db = serializer_proxy(nilmdb.server.NilmDB)(testdb)

    def tearDown(self):
        self.db.close()

    def test_blocking_server(self):
        """Start a blocking server and stop it via /exit/ and via signals."""
        # Server should fail if the database doesn't have a "_thread_safe"
        # property.
        with assert_raises(KeyError):
            nilmdb.server.Server(object())

        # Start web app on a custom port
        self.server = nilmdb.server.Server(self.db, host = "127.0.0.1",
                                           port = 32180, stoppable = True)

        def start_server():
            """Run the server in a new thread and wait until it signals."""
            event = threading.Event()
            def run_server():
                self.server.start(blocking = True, event = event)
            thread = threading.Thread(target = run_server)
            thread.start()
            if not event.wait(timeout = 10):
                raise AssertionError("server didn't start in 10 seconds")
            return thread

        # Start server and request for it to exit.  Only the request itself
        # matters, so close the response right away instead of leaking it.
        thread = start_server()
        urlopen("http://127.0.0.1:32180/exit/", timeout = 1).close()
        thread.join()

        # Mock some signals that should kill the server
        def try_signal(sig):
            """Make cherrypy.engine.wait raise `sig`, then run the server."""
            old = cherrypy.engine.wait
            def raise_sig(*args, **kwargs):
                raise sig()
            cherrypy.engine.wait = raise_sig
            try:
                thread = start_server()
                thread.join()
            finally:
                # Undo the monkeypatch even if the server failed to start,
                # so later tests see the real cherrypy.engine.wait.
                cherrypy.engine.wait = old
        try_signal(SystemExit)
        try_signal(KeyboardInterrupt)

def geturl(path):
    """Fetch `path` from the test server on 127.0.0.1:32180 as text.

    Args:
        path: URL path (and optional query string) appended to the
            server's base URL.

    Returns:
        The response body decoded with the charset named in the
        response's Content-Type header, or UTF-8 when none is given.

    Raises:
        urllib.error.HTTPError: propagated from urlopen() for HTTP error
            responses; callers in this module assert on its `code`.
    """
    # The with-block closes the HTTP response (and its socket) once the
    # body has been read, instead of leaving that to garbage collection.
    with urlopen("http://127.0.0.1:32180" + path, timeout = 10) as resp:
        body = resp.read()
        charset = resp.headers.get_content_charset() or 'utf-8'
    return body.decode(charset)

def getjson(path):
    """Fetch `path` with geturl() and return the body parsed as JSON."""
    text = geturl(path)
    return json.loads(text)

class TestServer(object):
    """HTTP-level tests against a non-blocking server on 127.0.0.1:32180.

    A fresh server is started around every test.  Several tests expect the
    streams and metadata left in `testdb` by Test00Nilmdb.test_stream.
    """

    def setUp(self):
        # Start web app on a custom port
        self.db = serializer_proxy(nilmdb.server.NilmDB)(testdb)
        try:
            self.server = nilmdb.server.Server(self.db, host = "127.0.0.1",
                                               port = 32180, stoppable = False)
            self.server.start(blocking = False)
        except Exception:
            # tearDown is not run when setUp fails, so release the
            # database here rather than leaving it open.
            self.db.close()
            raise

    def tearDown(self):
        # Close web app
        self.server.stop()
        self.db.close()

    @staticmethod
    def _require_cors_origin(response):
        """Raise AssertionError unless `response` has the CORS origin header."""
        if "access-control-allow-origin" not in response.headers:
            # Format the headers into the message; passing them as a second
            # argument would make the error render as a tuple.
            raise AssertionError("No Access-Control-Allow-Origin (CORS) "
                                 "header in response:\n%s"
                                 % (response.headers,))

    def test_server(self):
        """Check 404 handling, the root page, and the reported version."""
        # Make sure we can't force an exit, and test other 404 errors
        for url in [ "/exit", "/favicon.ico" ]:
            with assert_raises(HTTPError) as e:
                geturl(url)
            eq_(e.exception.code, 404)

        # Root page
        in_("This is NilmDB", geturl("/"))

        # Check version
        eq_(distutils.version.LooseVersion(getjson("/version")),
            distutils.version.LooseVersion(nilmdb.__version__))

    def test_stream_list(self):
        """List streams, with and without a layout filter."""
        # Known streams that got populated by an earlier test
        # (Test00Nilmdb.test_stream)
        streams = getjson("/stream/list")

        eq_(streams, [
            ['/newton/prep', 'float32_8'],
            ['/newton/raw', 'uint16_6'],
            ['/newton/zzz/rawnotch', 'uint16_9'],
            ])

        streams = getjson("/stream/list?layout=uint16_6")
        eq_(streams, [['/newton/raw', 'uint16_6']])

        streams = getjson("/stream/list?layout=NoSuchLayout")
        eq_(streams, [])

    def test_stream_metadata(self):
        """Read metadata for a stream, optionally restricted by key."""
        # Unknown stream path
        with assert_raises(HTTPError) as e:
            getjson("/stream/get_metadata?path=foo")
        eq_(e.exception.code, 404)

        data = getjson("/stream/get_metadata?path=/newton/prep")
        eq_(data, {'description': 'The Data', 'v_scale': '1.234'})

        data = getjson("/stream/get_metadata?path=/newton/prep"
                       "&key=v_scale")
        eq_(data, {'v_scale': '1.234'})

        data = getjson("/stream/get_metadata?path=/newton/prep"
                       "&key=v_scale&key=description")
        eq_(data, {'description': 'The Data', 'v_scale': '1.234'})

        # Keys that are not set come back as None
        data = getjson("/stream/get_metadata?path=/newton/prep"
                       "&key=v_scale&key=foo")
        eq_(data, {'foo': None, 'v_scale': '1.234'})

        data = getjson("/stream/get_metadata?path=/newton/prep"
                       "&key=foo")
        eq_(data, {'foo': None})

    def test_cors_headers(self):
        """Test that CORS headers are being set correctly."""
        url = "http://127.0.0.1:32180/stream/list"
        origin = "http://google.com/"

        # Normal GET should send simple response
        r = requests.get(url, headers = { "Origin": origin })
        eq_(r.status_code, 200)
        self._require_cors_origin(r)
        eq_(r.headers["access-control-allow-origin"], origin)

        # OPTIONS without CORS preflight headers should result in 405
        r = requests.options(url, headers = {
            "Origin": origin,
            })
        eq_(r.status_code, 405)

        # OPTIONS with preflight headers should give preflight response
        r = requests.options(url, headers = {
            "Origin": origin,
            "Access-Control-Request-Method": "POST",
            "Access-Control-Request-Headers": "X-Custom",
            })
        eq_(r.status_code, 200)
        self._require_cors_origin(r)
        eq_(r.headers["access-control-allow-methods"], "GET, HEAD")
        eq_(r.headers["access-control-allow-headers"], "X-Custom")

    def test_post_bodies(self):
        """Test JSON post bodies."""
        url = "http://127.0.0.1:32180/stream/set_metadata"
        cases = [
            ('{"hello": 1}', 404),  # wrong parameters
            ('["hello"]', 415),     # not a dict
            ('[hello]', 400),       # badly formatted JSON
        ]
        for (body, expected_status) in cases:
            r = requests.post(url,
                              headers = { "Content-Type": "application/json" },
                              data = body)
            eq_(r.status_code, expected_status)
