import nilmdb
import nilmdb.fsck

from nilmdb.utils.printf import printf
from nose.tools import assert_raises

import io
import sys
import shutil
import traceback
from testutil.helpers import *

class TestFsck(object):

    def run(self, db, fix=True, skip=False, checker=None):
        """Run fsck client with the given test database.  Save the output and
        exit code.
        """
        if db is not None:
            recursive_unlink("tests/fsck-testdb")
            shutil.copytree(f"tests/fsck-data/{db}", "tests/fsck-testdb",
                            ignore=shutil.ignore_patterns(
                                "git-empty-dir-placeholder"))
        class stdio_wrapper:
            def __init__(self, stdin, stdout, stderr):
                self.io = (stdin, stdout, stderr)
            def __enter__(self):
                self.saved = ( sys.stdin, sys.stdout, sys.stderr )
                ( sys.stdin, sys.stdout, sys.stderr ) = self.io
            def __exit__(self, type, value, traceback):
                ( sys.stdin, sys.stdout, sys.stderr ) = self.saved
        # Empty input
        infile = io.TextIOWrapper(io.BytesIO(b""))
        # Capture stdout, stderr
        errfile = io.TextIOWrapper(io.BytesIO())
        outfile = errfile
        with stdio_wrapper(infile, outfile, errfile) as s:
            try:
                f = nilmdb.fsck.Fsck("tests/fsck-testdb", fix)
                if checker:
                    checker(f)
                else:
                    f.check(skip_data=skip)
                sys.exit(0)
            except SystemExit as e:
                exitcode = e.code
            except Exception as e:
                traceback.print_exc()
                exitcode = 1

        # Capture raw binary output, and also try to decode a Unicode
        # string copy.
        self.captured_binary = outfile.buffer.getvalue()
        try:
            outfile.seek(0)
            self.captured = outfile.read()
        except UnicodeDecodeError:
            self.captured = None

        self.exitcode = exitcode

    def ok(self, *args, **kwargs):
        self.run(*args, **kwargs)
        if self.exitcode != 0:
            self.dump()
            eq_(self.exitcode, 0)

    def okmsg(self, db, expect, **kwargs):
        self.ok(db, **kwargs)
        self.contain(expect)

    def fail(self, *args, exitcode=None, **kwargs):
        self.run(*args, **kwargs)
        if exitcode is not None and self.exitcode != exitcode:
            # Wrong exit code
            self.dump()
            eq_(self.exitcode, exitcode)
        if self.exitcode == 0:
            # Success, when we wanted failure
            self.dump()
            ne_(self.exitcode, 0)

    def failmsg(self, db, expect, **kwargs):
        self.fail(db, **kwargs)
        self.contain(expect)

    def contain(self, checkstring, contain=True):
        if contain:
            in_(checkstring, self.captured)
        else:
            nin_(checkstring, self.captured)

    def dump(self):
        printf("\n===dump start===\n%s===dump end===\n", self.captured)


    def test_fsck(self):
        self.okmsg("test1", "\nok")

        def check_paths_twice(f):
            f.check()
            f.check_paths()
            f.bulk.close()
        self.ok("test1", checker=check_paths_twice)

        with open("tests/fsck-testdb/data.lock", "w") as lock:
            nilmdb.utils.lock.exclusive_lock(lock)
            self.failmsg(None, "Database already locked")

        self.failmsg("test1a", "SQL database missing")
        self.failmsg("test1b", "Bulk data directory missing")
        self.failmsg("test1c", "database version 0 too old")

        self.okmsg("test2", "\nok")

        self.failmsg("test2a", "duplicated ID 1 in stream IDs")
        self.failmsg("test2b", "interval ID 2 not in streams")
        self.failmsg("test2c", "metadata ID 2 not in streams")
        self.failmsg("test2d", "duplicate metadata key")
        self.failmsg("test2e", "duplicated path")
        self.failmsg("test2f", "bad layout")
        self.failmsg("test2g", "bad count")
        self.failmsg("test2h", "missing bulkdata dir")
        self.failmsg("test2i", "bad bulkdata table")
        self.failmsg("test2j", "overlap in intervals")
        self.failmsg("test2k", "overlap in file offsets", fix=False)
        self.ok("test2k1")
        self.failmsg("test2l", "unsupported bulkdata version")
        self.failmsg("test2m", "bad rows_per_file")
        self.failmsg("test2n", "bad files_per_dir")
        self.failmsg("test2o", "layout mismatch")
        self.failmsg("test2p", "missing data files", fix=False)
        self.contain("This may be fixable")
        self.okmsg("test2p", "Removing empty subpath")
        self.failmsg("test2p1", "please manually remove the file")
        self.okmsg("test2p2", "Removing empty subpath")
        self.failmsg("test2q", "extra bytes present", fix=False)
        self.okmsg("test2q", "Truncating file")

        self.failmsg("test2r", "error accessing rows", fix=False)
        self.okmsg("test2r", "end position is past endrows")

        self.okmsg("test2r1", "actually it can't be truncated")
        self.contain("Deleting the entire interval")
        self.contain("restarting fsck")
        self.ok("test2r2")

        self.failmsg("test2s", "non-monotonic timestamp (1000000 -> 12345)")

        def check_small_maxrows(f):
            f.maxrows_override = 1
            f.check()
        self.fail("test2t", checker=check_small_maxrows)
        self.contain("first interval timestamp 0 is not greater")

        self.ok("test2t", skip=True)

        self.failmsg("test2u", "data timestamp 1234567890 at row 28 outside")
        self.failmsg("test2u1", "data timestamp 7 at row 0 outside")

        @nilmdb.fsck.fsck.retry_if_raised(Exception, max_retries=3)
        def foo():
            raise Exception("hi")
        with assert_raises(Exception):
            foo()

        self.failmsg("test2v", "can't load _format, but data is also present")
        self.failmsg("test2v1", "bad bulkdata table")
        self.failmsg("test2v2", "empty, with corrupted format file", fix=False)
        self.okmsg("test2v2", "empty, with corrupted format file")

        self.failmsg("test2w1", "out of range, and zero", fix=False)
        self.okmsg("test2w1", "Will try truncating table")
        self.contain("Deleting the entire interval")

        self.failmsg("test2w2", "non-monotonic, and zero", fix=False)
        self.okmsg("test2w2", "Will try truncating table")
        self.contain("new end: time 237000001, pos 238")

        self.failmsg("test2x1", "overlap in file offsets", fix=False)
        self.okmsg("test2x1", "truncating")

        self.failmsg("test2x2", "unfixable overlap")
        self.failmsg("test2x3", "unfixable overlap")
