# Fixed record size bulk data storage

import os
import re
import sys
import pickle
import tempfile

from nilmdb.utils.printf import sprintf
from nilmdb.utils.time import timestamp_to_string
import nilmdb.utils

import nilmdb.utils.lock
from . import rocket

# Sizes of the LRU caches declared further down.  They are module-level
# globals so that they can be referenced in the decorator arguments.
#   table_cache_size: Table objects kept open by BulkData.getnode
#   fd_cache_size:    data files kept open by Table.file_open
# NOTE(review): an earlier comment claimed "up to 256 open file
# descriptors", i.e. table_cache_size * fd_cache_size.  That bound only
# holds if nilmdb.utils.lru_cache keeps a separate cache per Table
# instance -- confirm against its implementation.
table_cache_size, fd_cache_size = 32, 8


@nilmdb.utils.must_close(wrap_verify=False)
class BulkData():
    """Fixed-record-size bulk data storage for a whole database.

    All tables live below '<basepath>/data'; a NilmDB path such as
    '/newton/prep' maps to the OS directory '<basepath>/data/newton/prep'.
    An exclusive lock on '<basepath>/data.lock' is taken in the
    constructor and released by close().  Paths are handled internally
    as bytes.
    """

    def __init__(self, basepath, **kwargs):
        """Open (creating if needed) the data directory and lock it.

        basepath: database directory, as str or bytes.

        Optional keyword arguments; a missing key or a value of None
        selects the default:
          file_size:     approximate size of each data file in bytes
                         (default 128 MiB)
          files_per_dir: number of data files per subdirectory
                         (default 32768)
          initial_nrows: row number that new, empty tables start at
                         (default 0)

        Raises IOError if the database is locked by another process.
        """
        if isinstance(basepath, str):
            self.basepath = self._encode_filename(basepath)
        else:
            self.basepath = basepath
        self.root = os.path.join(self.basepath, b"data")
        self.lock = self.root + b".lock"
        self.lockfile = None

        # Tuneables
        self.file_size = kwargs.get("file_size")
        if self.file_size is None:
            # Default to approximately 128 MiB per file
            self.file_size = 128 * 1024 * 1024

        self.files_per_dir = kwargs.get("files_per_dir")
        if self.files_per_dir is None:
            # 32768 files per dir should work even on FAT32
            self.files_per_dir = 32768

        self.initial_nrows = kwargs.get("initial_nrows")
        if self.initial_nrows is None:
            # First row is 0
            self.initial_nrows = 0

        # Make root path
        if not os.path.isdir(self.root):
            os.mkdir(self.root)

        # Create the lock.  Only keep the file object once the lock is
        # actually held: if locking fails we close it here instead of
        # leaking the descriptor, and close() will never unlock or unlink
        # a lock file that belongs to another process.
        lockfile = open(self.lock, "w")
        try:
            if not nilmdb.utils.lock.exclusive_lock(lockfile):
                raise IOError('database at "' +
                              self._decode_filename(self.basepath) +
                              '" is already locked by another process')
        except BaseException:
            lockfile.close()
            raise
        self.lockfile = lockfile

    def close(self):
        """Drop all cached Table objects and release the database lock."""
        self.getnode.cache_remove_all()
        if self.lockfile:
            nilmdb.utils.lock.exclusive_unlock(self.lockfile)
            self.lockfile.close()
            try:
                os.unlink(self.lock)
            except OSError:
                pass
            self.lockfile = None

    def _encode_filename(self, path):
        # Translate unicode strings to raw bytes.  We always manipulate
        # paths internally as bytes.
        return path.encode('utf-8')

    def _decode_filename(self, path):
        # Translate raw bytes to unicode strings, escaping invalid UTF-8
        return path.decode('utf-8', errors='backslashreplace')

    def _create_check_ospath(self, ospath):
        """Raise ValueError if a new table can't be created at OS path
        'ospath': it ends in '/', a table already exists there, or
        files exist anywhere below it."""
        if ospath[-1:] == b'/':
            raise ValueError("invalid path; should not end with a /")
        if Table.exists(ospath):
            raise ValueError("stream already exists at this path")
        if os.path.isdir(ospath):
            # Look for any files in subdirectories.  Fully empty subdirectories
            # are OK; they might be there during a rename
            for (root, dirs, files) in os.walk(ospath):
                if files:
                    raise ValueError(
                        "non-empty subdirs of this path already exist")

    def _create_parents(self, unicodepath):
        """Verify the path name, and create parent directories if they
        don't exist.  Returns the list of path elements (as bytes).

        Raises ValueError for a malformed path, or if any parent is
        itself a table; any directories created here are removed again
        before the exception propagates."""
        path = self._encode_filename(unicodepath)

        if path[0:1] != b'/':
            raise ValueError("paths must start with / ")
        [group, node] = path.rsplit(b"/", 1)
        if group == b'':
            raise ValueError("invalid path; path must contain at least one "
                             "folder")
        if node == b'':
            raise ValueError("invalid path; should not end with a /")
        if not Table.valid_path(path):
            raise ValueError("path name is invalid or contains reserved words")

        # Create the table's base dir.  Note that we make a
        # distinction here between NilmDB paths (always Unix style,
        # split apart manually) and OS paths (built up with
        # os.path.join)

        # Make directories leading up to this one (but not the final
        # element itself; the caller creates that)
        elements = path.lstrip(b'/').split(b'/')
        made_dirs = []
        try:
            # Make parent elements
            for i in range(len(elements)):
                ospath = os.path.join(self.root, *elements[0:i])
                if Table.exists(ospath):
                    raise ValueError("path is subdir of existing node")
                if not os.path.isdir(ospath):
                    os.mkdir(ospath)
                    made_dirs.append(ospath)
        except Exception:
            # Remove paths that we created
            for ospath in reversed(made_dirs):
                os.rmdir(ospath)
            raise

        return elements

    def create(self, unicodepath, layout_name):
        """
        unicodepath: path to the data (e.g. u'/newton/prep').
        Paths must contain at least two elements, e.g.:
           /newton/prep
           /newton/raw
           /newton/upstairs/prep
           /newton/upstairs/raw

        layout_name: string for nilmdb.layout.get_named(), e.g. 'float32_8'

        Raises ValueError if the path is invalid or already in use.
        """
        elements = self._create_parents(unicodepath)

        # Make the final dir
        ospath = os.path.join(self.root, *elements)
        self._create_check_ospath(ospath)
        os.mkdir(ospath)

        try:
            # Write format string to file
            Table.create(ospath, layout_name, self.file_size,
                         self.files_per_dir)

            # Open and cache it
            self.getnode(unicodepath)
        except Exception:
            # Undo what we did.  The '_format' file may already have been
            # written (e.g. if opening the table failed); remove it so
            # the directory can be removed and no half-created stream is
            # left behind for Table.exists() to find.
            try:
                os.remove(os.path.join(ospath, b"_format"))
            except OSError:
                pass
            try:
                os.rmdir(ospath)
            except OSError:
                pass
            # Re-raise the original exception with its traceback
            raise

        # Success
        return

    def _remove_leaves(self, unicodepath):
        """Remove empty directories starting at the leaves of unicodepath"""
        path = self._encode_filename(unicodepath)
        elements = path.lstrip(b'/').split(b'/')
        # Deepest directory first, then each parent in turn; non-empty
        # directories are left alone.
        for depth in range(len(elements), 0, -1):
            ospath = os.path.join(self.root, *elements[0:depth])
            try:
                os.rmdir(ospath)
            except OSError:
                pass

    def rename(self, oldunicodepath, newunicodepath):
        """Move entire tree from 'oldunicodepath' to
        'newunicodepath'.  On failure, the tree is moved back to the
        old path before the exception propagates."""
        oldpath = self._encode_filename(oldunicodepath)
        newpath = self._encode_filename(newunicodepath)

        # Get OS paths
        oldelements = oldpath.lstrip(b'/').split(b'/')
        oldospath = os.path.join(self.root, *oldelements)
        newelements = newpath.lstrip(b'/').split(b'/')
        newospath = os.path.join(self.root, *newelements)

        # Basic checks
        if oldospath == newospath:
            raise ValueError("old and new paths are the same")

        # Remove Table object at old path from cache
        self.getnode.cache_remove(self, oldunicodepath)

        # Move the table to a temporary location.  Moving it out of the
        # way first lets the new path be a subdirectory of the old one.
        tmpdir = tempfile.mkdtemp(prefix=b"rename-", dir=self.root)
        tmppath = os.path.join(tmpdir, b"table")
        try:
            os.rename(oldospath, tmppath)
        except Exception:
            # Nothing was moved (e.g. the old path doesn't exist); don't
            # leave the empty temporary directory in the data root.
            os.rmdir(tmpdir)
            raise

        try:
            # Check destination path
            self._create_check_ospath(newospath)

            # Create parent dirs for new location
            self._create_parents(newunicodepath)

            # Move table into new location
            os.rename(tmppath, newospath)
        except Exception:
            # On failure, move the table back to original path
            os.rename(tmppath, oldospath)
            os.rmdir(tmpdir)
            raise

        # Prune old dirs
        self._remove_leaves(oldunicodepath)
        os.rmdir(tmpdir)

    def destroy(self, unicodepath):
        """Fully remove all data at a particular path.  No way to undo
        it!  The group/path structure is removed, too.

        Raises ValueError if there is no table at that path."""
        path = self._encode_filename(unicodepath)

        # Get OS path
        elements = path.lstrip(b'/').split(b'/')
        ospath = os.path.join(self.root, *elements)

        # Remove Table object from cache
        self.getnode.cache_remove(self, unicodepath)

        # Remove the contents of the target directory
        if not Table.exists(ospath):
            raise ValueError("nothing at that path")
        for (root, dirs, files) in os.walk(ospath, topdown=False):
            for name in files:
                os.remove(os.path.join(root, name))
            for name in dirs:
                os.rmdir(os.path.join(root, name))

        # Remove leftover empty directories
        self._remove_leaves(unicodepath)

    # Cache open tables
    @nilmdb.utils.lru_cache(size=table_cache_size,
                            onremove=lambda x: x.close())
    def getnode(self, unicodepath):
        """Return a Table object corresponding to the given database
        path, which must exist."""
        path = self._encode_filename(unicodepath)
        elements = path.lstrip(b'/').split(b'/')
        ospath = os.path.join(self.root, *elements)
        return Table(ospath, self.initial_nrows)


@nilmdb.utils.must_close(wrap_verify=False)
class Table():
    """Tools to help access a single table (data at a specific OS path).

    On disk, a table is a directory holding a pickled '_format' file
    plus data files named '<subdir>/<file>' (lowercase hex).  Row N is
    stored in file number N // rows_per_file, and consecutive files are
    grouped files_per_dir to a subdirectory.
    """
    # See design.md for design details

    # Class methods, to help keep format details in this class.
    @classmethod
    def valid_path(cls, root):
        """Return True if a root path is a valid name"""
        return b"_format" not in root.split(b"/")

    @classmethod
    def exists(cls, root):
        """Return True if a table appears to exist at this OS path"""
        return os.path.isfile(os.path.join(root, b"_format"))

    @classmethod
    def create(cls, root, layout, file_size, files_per_dir):
        """Initialize a table at the given OS path with the
        given layout string.  The directory must already exist."""

        # Calculate rows per file so that each file is approximately
        # file_size bytes (but always at least one row).  Release the
        # rocket even if querying its size fails.
        rkt = rocket.Rocket(layout, None)
        try:
            rows_per_file = max(file_size // rkt.binary_size, 1)
        finally:
            rkt.close()

        fmt = {
            "rows_per_file": rows_per_file,
            "files_per_dir": files_per_dir,
            "layout": layout,
            "version": 3
        }
        with open(os.path.join(root, b"_format"), "wb") as f:
            pickle.dump(fmt, f, 2)

    # Normal methods
    def __init__(self, root, initial_nrows=0):
        """'root' is the full OS path to the directory of this table.

        initial_nrows is only used when the table has no data files
        yet (see _get_nrows).  Raises NotImplementedError if the
        '_format' file is not version 3."""
        self.root = root
        self.initial_nrows = initial_nrows

        # Load the format
        with open(os.path.join(self.root, b"_format"), "rb") as f:
            fmt = pickle.load(f)

        if fmt["version"] != 3:
            # Old versions used floating point timestamps, which aren't
            # valid anymore.
            raise NotImplementedError("old version " + str(fmt["version"]) +
                                      " bulk data store is not supported")

        self.rows_per_file = fmt["rows_per_file"]
        self.files_per_dir = fmt["files_per_dir"]
        self.layout = fmt["layout"]

        # Use rocket to get row size; release it even on failure
        rkt = rocket.Rocket(self.layout, None)
        try:
            self.row_size = rkt.binary_size
        finally:
            rkt.close()
        self.file_size = self.row_size * self.rows_per_file

        # Find nrows
        self.nrows = self._get_nrows()

    def close(self):
        """Close any data files held open by the file_open cache."""
        self.file_open.cache_remove_all()

    # Internal helpers
    def _get_nrows(self):
        """Find nrows by locating the lexicographically last filename
        and using its size"""
        # Note that this just finds a 'nrows' that is guaranteed to be
        # greater than the row number of any piece of data that
        # currently exists, not necessarily all data that _ever_
        # existed.
        regex = re.compile(b"^[0-9a-f]{4,}$")

        # Find the last directory.  We sort and loop through all of them,
        # starting with the numerically greatest, because the dirs could be
        # empty if something was deleted but the directory was unexpectedly
        # not deleted.
        subdirs = sorted(filter(regex.search, os.listdir(self.root)),
                         key=lambda x: int(x, 16), reverse=True)

        for subdir in subdirs:
            # Now find the last file in that dir
            path = os.path.join(self.root, subdir)
            files = list(filter(regex.search, os.listdir(path)))
            if not files:
                # Empty dir: try the next one
                continue

            # Find the numerical max
            filename = max(files, key=lambda x: int(x, 16))
            offset = os.path.getsize(os.path.join(self.root, subdir, filename))

            # Convert to row number
            return self._row_from_offset(subdir, filename, offset)

        # No files, so no data.  We typically start at row 0 in this
        # case, although initial_nrows is specified during some tests
        # to exercise other parts of the code better.  Since we have
        # no files yet, round initial_nrows up so it points to a row
        # that would begin a new file.
        nrows = ((self.initial_nrows + (self.rows_per_file - 1)) //
                 self.rows_per_file) * self.rows_per_file
        return nrows

    def _offset_from_row(self, row):
        """Return a (subdir, filename, offset, count) tuple:

          subdir: subdirectory for the file
        filename: the filename that contains the specified row
          offset: byte offset of the specified row within the file
           count: number of rows (starting at offset) that fit in the file
        """
        filenum = row // self.rows_per_file
        # It's OK if these format specifiers are too short; the filenames
        # will just get longer but will still sort correctly.
        dirname = sprintf(b"%04x", filenum // self.files_per_dir)
        filename = sprintf(b"%04x", filenum % self.files_per_dir)
        offset = (row % self.rows_per_file) * self.row_size
        count = self.rows_per_file - (row % self.rows_per_file)
        return (dirname, filename, offset, count)

    def _row_from_offset(self, subdir, filename, offset):
        """Return the row number that corresponds to the given
        'subdir/filename' and byte-offset within that file."""
        if (offset % self.row_size) != 0:
            # this shouldn't occur, unless there is some corruption somewhere
            raise ValueError("file offset is not a multiple of data size")
        filenum = int(subdir, 16) * self.files_per_dir + int(filename, 16)
        row = (filenum * self.rows_per_file) + (offset // self.row_size)
        return row

    def _remove_or_truncate_file(self, subdir, filename, offset=0):
        """Remove the given file, and remove the subdirectory too
        if it's empty.  If offset is nonzero, truncate the file
        to that size instead.

        Removing a file that does not exist is not an error, so that a
        rollback in append_data can't mask the original failure."""
        # Close potentially open file in file_open LRU cache
        self.file_open.cache_remove(self, subdir, filename)
        path = os.path.join(self.root, subdir, filename)
        if offset:
            # Truncate it
            with open(path, "r+b") as f:
                f.truncate(offset)
            return

        # Remove file.  It may already be gone, e.g. if append_data is
        # rolling back after file_open failed before the file existed.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        # Try deleting subdir, too; it's fine if it's not empty yet.
        try:
            os.rmdir(os.path.join(self.root, subdir))
        except OSError:
            pass

    # Cache open files
    @nilmdb.utils.lru_cache(size=fd_cache_size,
                            onremove=lambda f: f.close())
    def file_open(self, subdir, filename):
        """Open and map a given 'subdir/filename' (relative to self.root).
        Will be automatically closed when evicted from the cache."""
        # Create path if it doesn't exist
        try:
            os.mkdir(os.path.join(self.root, subdir))
        except OSError:
            pass
        # Return a rocket.Rocket object, which contains the open file
        return rocket.Rocket(self.layout,
                             os.path.join(self.root, subdir, filename))

    def append_data(self, data, start, end, binary=False):
        """Parse the formatted string in 'data', according to the
        current layout, and append it to the table.  If any timestamps
        are non-monotonic, or don't fall between 'start' and 'end',
        a ValueError is raised.

        Note that data is always of 'bytes' type.

        If 'binary' is True, the data should be in raw binary format
        instead: little-endian, matching the current table's layout,
        including the int64 timestamp.

        If this function succeeds, it returns normally.  Otherwise,
        the table is reverted back to its original state by truncating
        or deleting files as necessary."""
        data_offset = 0
        last_timestamp = nilmdb.utils.time.min_timestamp
        tot_rows = self.nrows
        count = 0
        linenum = 0
        try:
            while data_offset < len(data):
                # See how many rows we can fit into the current file,
                # and open it
                (subdir, fname, offs, count) = self._offset_from_row(tot_rows)
                f = self.file_open(subdir, fname)

                # Ask the rocket object to parse and append up to "count"
                # rows of data, verifying things along the way.
                try:
                    if binary:
                        appender = f.append_binary
                    else:
                        appender = f.append_string
                    (added_rows, data_offset, last_timestamp, linenum
                     ) = appender(count, data, data_offset, linenum,
                                  start, end, last_timestamp)
                except rocket.ParseError as e:
                    (linenum, colnum, errtype, obj) = e.args
                    if binary:
                        where = "byte %d: " % (linenum)
                    else:
                        where = "line %d, column %d: " % (linenum, colnum)
                    # Extract out the error line, add column marker.
                    # Line numbers are 1-based; a value < 1 would index
                    # from the end of the list and show the wrong line.
                    try:
                        if binary or linenum < 1:
                            raise IndexError
                        bad = data.splitlines()[linenum-1]
                        bad += b'\n' + b' ' * (colnum - 1) + b'^'
                    except IndexError:
                        bad = b""
                    if errtype == rocket.ERR_NON_MONOTONIC:
                        err = "timestamp is not monotonically increasing"
                    elif errtype == rocket.ERR_OUT_OF_INTERVAL:
                        if obj < start:
                            err = sprintf("Data timestamp %s < start time %s",
                                          timestamp_to_string(obj),
                                          timestamp_to_string(start))
                        else:
                            err = sprintf("Data timestamp %s >= end time %s",
                                          timestamp_to_string(obj),
                                          timestamp_to_string(end))
                    else:
                        err = str(obj)
                    bad_str = bad.decode('utf-8', errors='backslashreplace')
                    raise ValueError("error parsing input data: " +
                                     where + err + "\n" + bad_str)
                tot_rows += added_rows
        except Exception:
            # Some failure, so try to roll things back by truncating or
            # deleting files that we may have appended data to.  The file
            # at tot_rows is included because a failed append may have
            # written rows into it that weren't counted yet.
            cleanpos = self.nrows
            while cleanpos <= tot_rows:
                (subdir, fname, offs, count) = self._offset_from_row(cleanpos)
                self._remove_or_truncate_file(subdir, fname, offs)
                cleanpos += count
            # Re-raise original exception
            raise
        else:
            # Success, so update self.nrows accordingly
            self.nrows = tot_rows

    def get_data(self, start, stop, binary=False):
        """Extract rows [start, stop) (like Python range [n:m]) and
        return them as bytes: formatted text, or raw binary rows if
        'binary' is True.  Raises IndexError if the range is invalid
        or extends past self.nrows."""
        if (start is None or stop is None or
                start > stop or start < 0 or stop > self.nrows):
            raise IndexError("Index out of range")

        ret = []
        row = start
        remaining = stop - start
        while remaining > 0:
            # Read as many rows as possible from each file in turn
            (subdir, filename, offset, count) = self._offset_from_row(row)
            if count > remaining:
                count = remaining
            f = self.file_open(subdir, filename)
            if binary:
                ret.append(f.extract_binary(offset, count))
            else:
                ret.append(f.extract_string(offset, count))
            remaining -= count
            row += count
        return b"".join(ret)

    def __getitem__(self, row):
        """Extract timestamps from a row, with table[n] notation."""
        if row < 0 or row >= self.nrows:
            raise IndexError("Index out of range")
        (subdir, filename, offset, count) = self._offset_from_row(row)
        f = self.file_open(subdir, filename)
        return f.extract_timestamp(offset)

    def _remove_rows(self, subdir, filename, start, stop):
        """Helper to mark specific rows [start, stop) (row numbers
        relative to the start of this file) as being removed from a
        file, and potentially remove or truncate the file itself."""
        # Close potentially open file in file_open LRU cache
        self.file_open.cache_remove(self, subdir, filename)

        # We keep a file like 0000.removed that contains a list of
        # which rows have been "removed".  Note that we never have to
        # remove entries from this list, because we never decrease
        # self.nrows, and so we will never overwrite those locations in the
        # file.  Only when the list covers the entire extent of the
        # file will that file be removed.
        datafile = os.path.join(self.root, subdir, filename)
        cachefile = datafile + b".removed"
        try:
            with open(cachefile, "rb") as f:
                ranges = pickle.load(f)
        except Exception:
            # Missing (or unreadable) list: start over with an empty one
            ranges = []

        # Append our new range
        ranges.append((start, stop))

        # Merge touching or overlapping ranges.  Overlaps must be merged
        # too; otherwise a file whose rows were all removed could never
        # collapse to a single (0, rows_per_file) range and be deleted.
        merged = []
        for (lo, hi) in sorted(ranges):
            if merged and lo <= merged[-1][1]:
                merged[-1] = (merged[-1][0], max(merged[-1][1], hi))
            else:
                merged.append((lo, hi))

        # If the range covered the whole file, we can delete it now.
        # Note that the last file in a table may be only partially
        # full (smaller than self.rows_per_file).  We purposely leave
        # those files around rather than deleting them, because the
        # remainder will be filled on a subsequent append(), and things
        # are generally easier if we don't have to special-case that.
        if (len(merged) == 1 and
                merged[0][0] == 0 and merged[0][1] == self.rows_per_file):
            # Delete files.  Remove the .removed list whenever it exists,
            # even if it couldn't be read above, so it isn't orphaned.
            try:
                os.remove(cachefile)
            except FileNotFoundError:
                pass
            self._remove_or_truncate_file(subdir, filename, 0)
        else:
            # File needs to stick around.  This means we can get
            # degenerate cases where we have large files containing as
            # little as one row.  Try to punch a hole in the file,
            # so that this region doesn't take up filesystem space.
            offset = start * self.row_size
            count = (stop - start) * self.row_size
            nilmdb.utils.fallocate.punch_hole(datafile, offset, count)

            # Update cache.  Try to do it atomically.
            nilmdb.utils.atomic.replace_file(cachefile,
                                             pickle.dumps(merged, 2))

    def remove(self, start, stop):
        """Remove specified rows [start, stop) from this table.

        If a file is left empty, it is fully removed.  Otherwise, a
        parallel data file is used to remember which rows have been
        removed, and the file is otherwise untouched."""
        if start < 0 or start > stop or stop > self.nrows:
            raise IndexError("Index out of range")

        row = start
        remaining = stop - start
        while remaining:
            # Loop through each file that we need to touch
            (subdir, filename, offset, count) = self._offset_from_row(row)
            if count > remaining:
                count = remaining
            row_offset = offset // self.row_size
            # Mark the rows as being removed
            self._remove_rows(subdir, filename, row_offset, row_offset + count)
            remaining -= count
            row += count
