# -*- coding: utf-8 -*-

"""Class for performing HTTP client requests via libcurl"""

import json
import contextlib

import nilmdb.utils
import nilmdb.client.httpclient
from nilmdb.client.errors import ClientError
from nilmdb.utils.time import timestamp_to_string, string_to_timestamp


def extract_timestamp(line):
    """Parse a line of data text and return its timestamp, which is
    the first whitespace-separated field of the line."""
    first_field = line.split(None, 1)[0]
    return string_to_timestamp(first_field)


class Client():
    """Main client interface to the Nilm database."""

    def __init__(self, url, post_json=False):
        """Initialize client with given URL.  If post_json is true,
        POST requests are sent with Content-Type 'application/json'
        instead of the default 'x-www-form-urlencoded'."""
        self.http = nilmdb.client.httpclient.HTTPClient(url, post_json)
        self.post_json = post_json

    # __enter__/__exit__ allow this class to be a context manager
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection when leaving a 'with' block."""
        self.close()

    def _json_post_param(self, data):
        """Return compact json-encoded version of parameter.  When
        posting as JSON, 'data' is returned unchanged and is not
        encoded a second time here."""
        if self.post_json:
            # If we're posting as JSON, we don't need to encode it further here
            return data
        return json.dumps(data, separators=(',', ':'))

    @staticmethod
    def _encode_chunk(chunk):
        """Return 'chunk' as bytes-like data; str is encoded as UTF-8."""
        if isinstance(chunk, str):
            return chunk.encode('utf-8')
        return chunk

    def close(self):
        """Close the connection; safe to call multiple times"""
        self.http.close()

    def geturl(self):
        """Return the URL we're using"""
        return self.http.baseurl

    def version(self):
        """Return server version"""
        return self.http.get("version")

    def dbinfo(self):
        """Return server database info (path, size, free space)
        as a dictionary."""
        return self.http.get("dbinfo")

    def stream_list(self, path=None, layout=None, extended=False):
        """Return a sorted list of [path, layout] lists.  If 'path' or
        'layout' are specified, only return streams that match those
        exact values.  If 'extended' is True, the returned lists have
        extended info, e.g.: [path, layout, extent_min, extent_max,
        total_rows, total_seconds]."""
        params = {}
        if path is not None:
            params["path"] = path
        if layout is not None:
            params["layout"] = layout
        if extended:
            params["extended"] = 1
        streams = self.http.get("stream/list", params)
        # Sort by path (first element) in human/natural order
        return nilmdb.utils.sort.sort_human(streams, key=lambda s: s[0])

    def stream_get_metadata(self, path, keys=None):
        """Get stream metadata.  If 'keys' is given, it is passed to
        the server as the 'key' parameter to restrict the result."""
        params = {"path": path}
        if keys is not None:
            params["key"] = keys
        return self.http.get("stream/get_metadata", params)

    def stream_set_metadata(self, path, data):
        """Set stream metadata from a dictionary, replacing all existing
        metadata."""
        params = {
            "path": path,
            "data": self._json_post_param(data)
        }
        return self.http.post("stream/set_metadata", params)

    def stream_update_metadata(self, path, data):
        """Update stream metadata from a dictionary"""
        params = {
            "path": path,
            "data": self._json_post_param(data)
        }
        return self.http.post("stream/update_metadata", params)

    def stream_create(self, path, layout):
        """Create a new stream"""
        params = {
            "path": path,
            "layout": layout
        }
        return self.http.post("stream/create", params)

    def stream_destroy(self, path):
        """Delete stream.  Fails if any data is still present."""
        params = {
            "path": path
        }
        return self.http.post("stream/destroy", params)

    def stream_rename(self, oldpath, newpath):
        """Rename a stream."""
        params = {
            "oldpath": oldpath,
            "newpath": newpath
        }
        return self.http.post("stream/rename", params)

    def stream_remove(self, path, start=None, end=None):
        """Remove data from the specified time range.  Returns the
        total of the counts reported by the server."""
        params = {
            "path": path
        }
        if start is not None:
            params["start"] = timestamp_to_string(start)
        if end is not None:
            params["end"] = timestamp_to_string(end)
        total = 0
        # The server may report the removed count in several pieces
        for count in self.http.post_gen("stream/remove", params):
            total += int(count)
        return total

    @contextlib.contextmanager
    def stream_insert_context(self, path, start=None, end=None):
        """Return a context manager that allows data to be efficiently
        inserted into a stream in a piecewise manner.  Data is
        provided as ASCII lines, and is aggregated and sent to the
        server in larger or smaller chunks as necessary.  Data lines
        must match the database layout for the given path, and end
        with a newline.

        Example:
          with client.stream_insert_context('/path', start, end) as ctx:
            ctx.insert(b'1234567890.0 1 2 3 4\\n')
            ctx.insert(b'1234567891.0 1 2 3 4\\n')

        For more details, see help for nilmdb.client.client.StreamInserter

        This may make multiple requests to the server, if the data is
        large enough.

        If the 'with' body raises, the remaining buffered data is not
        sent (blocks already sent are not undone).  The inserter is
        destroyed in either case so it cannot be reused afterwards.
        """
        ctx = StreamInserter(self, path, start, end)
        try:
            yield ctx
            # Only reached when the 'with' body completed normally
            ctx.finalize()
        finally:
            ctx.destroy()

    def stream_insert(self, path, data, start=None, end=None):
        """Insert rows of data into a stream.  'data' should be a
        bytes or str object, or an iterable of such chunks, that
        provides ASCII data matching the database layout for path;
        str data is encoded as UTF-8.  Data is passed through
        stream_insert_context, so it will be broken into
        reasonably-sized chunks and start/end will be deduced if
        missing.  Returns the response to the last insert request
        sent, or None if nothing was sent."""
        with self.stream_insert_context(path, start, end) as ctx:
            if isinstance(data, (bytes, bytearray, str)):
                # A single buffer: insert it whole.  Iterating it
                # would yield ints (bytes) or single characters (str).
                ctx.insert(self._encode_chunk(data))
            else:
                for chunk in data:
                    ctx.insert(self._encode_chunk(chunk))
        return ctx.last_response

    def stream_insert_block(self, path, data, start, end, binary=False):
        """Insert a single fixed block of data into the stream.  It is
        sent directly to the server in one block with no further
        processing.

        If 'binary' is True, provide raw binary data in little-endian
        format matching the path layout, including an int64 timestamp.
        Otherwise, provide ASCII data matching the layout."""
        params = {
            "path": path,
            "start": timestamp_to_string(start),
            "end": timestamp_to_string(end),
        }
        if binary:
            params["binary"] = 1
        return self.http.put("stream/insert", data, params)

    def stream_intervals(self, path, start=None, end=None, diffpath=None):
        """
        Return a generator that yields each stream interval.

        If 'diffpath' is not None, yields only interval ranges that are
        present in 'path' but not in 'diffpath'.
        """
        params = {
            "path": path
        }
        if diffpath is not None:
            params["diffpath"] = diffpath
        if start is not None:
            params["start"] = timestamp_to_string(start)
        if end is not None:
            params["end"] = timestamp_to_string(end)
        return self.http.get_gen("stream/intervals", params)

    def stream_extract(self, path, start=None, end=None,
                       count=False, markup=False, binary=False):
        """
        Extract data from a stream.  Returns a generator that yields
        lines of ASCII-formatted data that matches the database
        layout for the given path.

        If 'count' is True, return a count of matching data points
        rather than the actual data.  The output format is unchanged.

        If 'markup' is True, include comments in the returned data
        that indicate interval starts and ends.

        If 'binary' is True, return chunks of raw binary data, rather
        than lines of ASCII-formatted data.  Raw binary data is
        little-endian and matches the database types (including an
        int64 timestamp).
        """
        params = {
            "path": path,
        }
        if start is not None:
            params["start"] = timestamp_to_string(start)
        if end is not None:
            params["end"] = timestamp_to_string(end)
        if count:
            params["count"] = 1
        if markup:
            params["markup"] = 1
        if binary:
            params["binary"] = 1
        return self.http.get_gen("stream/extract", params, binary=binary)

    def stream_count(self, path, start=None, end=None):
        """
        Return the number of rows of data in the stream that satisfy
        the given timestamps.
        """
        counts = list(self.stream_extract(path, start, end, count=True))
        return int(counts[0])


class StreamInserter():
    """Object returned by stream_insert_context() that manages
    the insertion of rows of data into a particular path.

    The basic data flow is that we are filling a contiguous interval
    on the server, with no gaps, that extends from timestamp 'start'
    to timestamp 'end'.  Data timestamps satisfy 'start <= t < end'.

    Data is provided to .insert() as ASCII formatted data separated by
    newlines, as bytes (a str is also accepted and encoded as UTF-8).
    The chunks of data passed to .insert() do not need to match up
    with the newlines; less or more than one line can be passed.

    1. The first inserted line begins a new interval that starts at
    'start'.  If 'start' is not given, it is deduced from the first
    line's timestamp.

    2. Subsequent lines go into the same contiguous interval.  As lines
    are inserted, this routine may make multiple insertion requests to
    the server, but will structure the timestamps to leave no gaps.

    3. The current contiguous interval can be completed by manually
    calling .finalize(), which the context manager will also do
    automatically.  This will send any remaining data to the server,
    using the 'end' timestamp to end the interval.  If no 'end'
    was provided, it is deduced from the last timestamp seen,
    plus a small delta.

    After a .finalize(), inserting new data goes back to step 1.

    .update_start() can be called before step 1 to change the start
    time for the interval.  .update_end() can be called before step 3
    to change the end time for the interval.
    """

    # See design.md for a discussion of how much data to send.  This
    # is a soft limit -- we might send up to twice as much or so
    _max_data = 2 * 1024 * 1024
    # If more than this remains buffered right after an intermediate
    # send, insert() treats the input as malformed.
    _max_data_after_send = 64 * 1024

    def __init__(self, client, path, start, end):
        """'client' is the client object.  'path' is the database
        path to insert to.  'start' and 'end' are used for the first
        contiguous interval and may be None."""
        # Return value of the most recent stream_insert_block() call
        self.last_response = None

        self._client = client
        self._path = path

        # Start and end for the overall contiguous interval we're
        # filling
        self._interval_start = start
        self._interval_end = end

        # Current data we're building up to send.  Each chunk
        # goes into the list, and gets joined all at once.
        self._block_data = []
        # Total length of everything currently in _block_data
        self._block_len = 0

        # Set to True by destroy()
        self.destroyed = False

    def destroy(self):
        """Ensure this object can't be used again without raising
        an error, and mark it as destroyed."""
        def error(*args, **kwargs):
            raise Exception("don't reuse this context object")
        self._send_block = self.insert = self.finalize = self.send = error
        self.destroyed = True

    def insert(self, data):
        """Insert a chunk of ASCII formatted data.  'data' is bytes,
        or a str which is encoded as UTF-8 first.  The overall data
        must consist of lines terminated by '\\n'.

        Raises ValueError if too much data is still buffered right
        after sending an intermediate block, which usually means the
        data is missing newlines or is malformed.  Errors from sending
        the block are propagated."""
        if isinstance(data, str):
            data = data.encode('utf-8')
        length = len(data)
        maxdata = self._max_data

        if length > maxdata:
            # This could make our buffer more than twice what we
            # wanted to send, so split it up.  This is a bit
            # inefficient, but the user really shouldn't be providing
            # this much data at once.
            for cut in range(0, length, maxdata):
                self.insert(data[cut:(cut + maxdata)])
            return

        # Append this chunk to our list
        self._block_data.append(data)
        self._block_len += length

        # Send the block once we have enough data
        if self._block_len >= maxdata:
            self._send_block(final=False)
            if self._block_len >= self._max_data_after_send:
                raise ValueError("too much data left over after trying"
                                 " to send intermediate block; is it"
                                 " missing newlines or malformed?")

    def update_start(self, start):
        """Update the start time for the next contiguous interval.
        Call this before starting to insert data for a new interval,
        for example, after .finalize()"""
        self._interval_start = start

    def update_end(self, end):
        """Update the end time for the current contiguous interval.
        Call this before .finalize()"""
        self._interval_end = end

    def finalize(self):
        """Stop filling the current contiguous interval.
        All outstanding data will be sent, and the interval end
        time of the interval will be taken from the 'end' argument
        used when initializing this class, or the most recent
        value passed to update_end(), or the last timestamp plus
        a small epsilon value if no other endpoint was provided.

        If more data is inserted after a finalize(), it will become
        part of a new interval and there may be a gap left in-between."""
        self._send_block(final=True)

    def send(self):
        """Send any data that we might have buffered up.  Does not affect
        any other treatment of timestamps or endpoints."""
        self._send_block(final=False)

    def _get_first_noncomment(self, block):
        """Return (start, end) indices of the first full line in
        block that doesn't start with '#'; block[start:end] includes
        the trailing newline.  Empty lines count as non-comment.
        Raise IndexError if there is no such line."""
        start = 0
        while True:
            end = block.find(b'\n', start)
            if end < 0:
                raise IndexError
            if block[start] != b'#'[0]:
                return (start, (end + 1))
            start = end + 1

    def _get_last_noncomment(self, block):
        """Return (start, end) indices of the last full line in
        block that doesn't start with '#'; block[start:end] excludes
        the trailing newline.  Anything after the final newline is
        ignored.  Raise IndexError if there is no such line."""
        end = block.rfind(b'\n')
        if end <= 0:
            raise IndexError
        while True:
            start = block.rfind(b'\n', 0, end)
            # start == -1 means the candidate line begins at offset 0
            if block[start + 1] != b'#'[0]:
                return ((start + 1), end)
            if start == -1:
                raise IndexError
            end = start

    def _send_block(self, final=False):
        """Send data currently in the block.  The data sent will
        consist of full lines only, so some might be left over.

        If 'final' is true, everything buffered is sent and the
        interval state is reset so the next insert starts a new
        interval.  Otherwise, only the lines before the last complete
        non-comment line are sent; that line and everything after it
        stay buffered, and its timestamp ends this block and starts
        the next one.

        Raises ClientError if there is data to send but no usable
        start/end timestamps."""
        # Build the full string to send
        block = b"".join(self._block_data)

        start_ts = self._interval_start
        if start_ts is None:
            # Pull start from the first line
            try:
                (spos, epos) = self._get_first_noncomment(block)
                start_ts = extract_timestamp(block[spos:epos])
            except (ValueError, IndexError):
                pass  # no timestamp is OK, if we have no data

        if final:
            # For a final block, the ending timestamp is either the
            # user-provided end, or the timestamp of the last line
            # plus epsilon; the latter is only deduced when the block
            # ends in a newline.
            # NOTE(review): a final block without a trailing newline is
            # not rejected here.  With an explicit end it is sent as-is;
            # without one it either fails the endpoint check below, or,
            # if it has no complete non-comment line, is dropped unsent.
            end_ts = self._interval_end
            if end_ts is None and block.endswith(b'\n'):
                try:
                    (spos, epos) = self._get_last_noncomment(block)
                    end_ts = extract_timestamp(block[spos:epos])
                    end_ts += nilmdb.utils.time.epsilon
                except (ValueError, IndexError):
                    pass  # no timestamp is OK, if we have no data
            self._block_data = []
            self._block_len = 0

            # Next block is completely fresh
            self._interval_start = None
            self._interval_end = None
        else:
            # An intermediate block, e.g. "line1\nline2\nline3\nline4"
            # We need to save "line3\nline4" for the next block, and
            # use the timestamp from "line3" as the ending timestamp
            # for this one.
            try:
                (spos, epos) = self._get_last_noncomment(block)
                end_ts = extract_timestamp(block[spos:epos])
            except (ValueError, IndexError):
                # If we found no timestamp, give up; we could send this
                # block later when we have more data.
                return
            if spos == 0:
                # Not enough data to send an intermediate block
                return
            if self._interval_end is not None and end_ts > self._interval_end:
                # User gave us bad endpoints; send it anyway, and let
                # the server complain so that the error is the same
                # as if we hadn't done this chunking.
                end_ts = self._interval_end
            leftover = block[spos:]
            self._block_data = [leftover]
            # Count everything that stays buffered (the last line plus
            # any trailing comments or partial line), so insert()'s
            # leftover check sees the real amount.
            self._block_len = len(leftover)
            block = block[:spos]

            # Next block continues where this one ended
            self._interval_start = end_ts

        # Double check endpoints
        if (start_ts is None or end_ts is None) or (start_ts == end_ts):
            # If the block has no non-comment lines, it's OK
            try:
                self._get_first_noncomment(block)
            except IndexError:
                return
            raise ClientError("have data to send, but no start/end times")

        # Send it
        self.last_response = self._client.stream_insert_block(
            self._path, block, start_ts, end_ts, binary=False)

        return
