import serial
import time
import sys
import asyncio

from .errors import CaptureError
import logging

import numpy as np
from joule.utilities import (timestamp_to_human,
                              seconds_to_timestamp, time_now)

# One row = ten 2-byte words: 8 channel samples, a 2-byte status word and a
# 2-byte alignment marker (0x7F80) that terminates the row.
ROW_BYTES = 8 * 2 + 2 + 2
# Rows requested per serial read; at the 3 kHz rate assumed by
# SerialCapture.data_ts_inc this is roughly 1 second of data per block.
ROWS_PER_BLOCK = 3000
BLOCK_SIZE = ROW_BYTES * ROWS_PER_BLOCK


class SerialCapture():
    """Capture framed ADC rows from a serial device as timestamped blocks.

    Each row on the wire is ROW_BYTES long and is decoded as ten
    little-endian int16 words. Words 1-8 are the channel samples and word 9
    is the alignment marker (bytes 0x7F, 0x80). Word 0 is presumably the
    status word. Timestamps are generated locally, ``data_ts_inc``
    microseconds apart, and compared against the wall clock by
    ``_align_clock``.
    """

    # alignment bytes 0x7F, 0x80 read as a little-endian int16 (0x807F)
    _ALIGN_WORD = -32641

    def __init__(self, device_path):
        """Open *device_path* at 3,000,000 baud with a 2 second read timeout."""
        # tunable constants
        self.data_ts_inc = 1e6 / 3000.0  # 3KHz sampling (microseconds/row)
        self.max_gap = seconds_to_timestamp(10)
        self.device = serial.Serial(device_path, 3000000, timeout=2)
        self.stop_requested = False

    async def run(self, output,
                  max_gap=10,
                  align=True,
                  nrows=0):
        """Stream timestamped blocks to *output* until stopped.

        Args:
            output: object with an awaitable ``write(array)`` method. Each
                call receives an (N, 9) array of [ts, ch0 .. ch7] rows.
            max_gap: allowed drift in seconds between data timestamps and
                the wall clock before ``_align_clock`` intervenes.
            align: if True, check data time against the wall clock before
                every read.
            nrows: stop once at least this many rows have been written
                (0 means run until ``stop()`` is called).

        Raises:
            CaptureError: if the device times out while re-synchronizing.
            Errors raised by the serial read itself are not caught and end
            the capture.
        """
        loop = asyncio.get_event_loop()
        # pyserial calls block (reads wait up to the port timeout). Run them
        # in the default executor so other tasks on the loop keep running.
        await loop.run_in_executor(None, self._start_meter)
        # set up timestamps
        self.clock_ts = time_now()
        self.data_ts = self.clock_ts
        rows_processed = 0
        # convert gap to microseconds
        self.max_gap = seconds_to_timestamp(max_gap)

        while not self.stop_requested:
            if align:
                self._align_clock()
            # device read errors are deliberately not caught here
            data = await loop.run_in_executor(None, self.device.read,
                                              BLOCK_SIZE)
            # recover on data parsing errors by restarting the stream
            try:
                np_data = self._parse(data)
                await output.write(np_data)
                rows_processed += len(np_data)
                if nrows > 0 and rows_processed >= nrows:
                    self.stop_requested = True
            except CaptureError as e:
                logging.error("Capture Error: [%s] restarting data stream", e)
                sys.stderr.flush()
                await loop.run_in_executor(None, self._start_meter)
                self.clock_ts = time_now()
                self.data_ts = self.clock_ts
            sys.stderr.flush()

    def stop(self):
        """Ask ``run`` to exit after the block it is currently processing."""
        self.stop_requested = True

    def close(self):
        """Turn the ADC off and release the serial port.

        Call this only after ``run`` has returned, because ``run`` performs
        device I/O on an executor thread. The port is closed even if the
        final write fails.
        """
        try:
            self.device.write(b'o')
        finally:
            self.device.close()

    def _start_meter(self):
        """(Re)start the ADC and sync the input stream to a row boundary.

        This is blocking: it sleeps 0.1 s and reads byte-by-byte until the
        alignment marker is seen.

        Raises:
            CaptureError: propagated from ``_realign`` on a read timeout.
        """
        # turn off the ADC in case it was left on
        self.device.write(b'o')
        time.sleep(0.1)
        # clear out old data of uncertain provenance
        self.device.reset_input_buffer()
        # now we can turn on the ADC for real
        self.device.write(b's')
        # realign the buffer to a complete line
        self._realign()

    def _align_clock(self):
        """Reconcile the data timestamp with the wall clock.

        Exits the process if data time runs more than ``max_gap`` ahead of
        the clock. If data time falls more than ``max_gap`` behind, it jumps
        forward to the clock.
        """
        self.clock_ts = time_now()
        if (self.data_ts - self.max_gap) > self.clock_ts:
            print("Data is coming in too fast: data time "
                  "is %s but clock time is only %s" %
                  (timestamp_to_human(self.data_ts),
                   timestamp_to_human(self.clock_ts)))
            # sys.exit, not the site-provided exit() builtin, which is not
            # guaranteed to exist (e.g. python -S, frozen executables)
            sys.exit(1)  # can't recover

        if (self.data_ts + self.max_gap) < self.clock_ts:
            print("Skipping data timestamp forward from "
                  "%s to %s to match clock time\n" % (
                      timestamp_to_human(self.data_ts),
                      timestamp_to_human(self.clock_ts)))
            self.data_ts = self.clock_ts

    def _realign(self):
        """Consume bytes until the 0x7F, 0x80 alignment marker is read.

        Raises:
            CaptureError: if a single-byte read times out.
        """
        # only the last two bytes matter, so keep a 2-byte sliding window
        prev = cur = None
        while not (prev == 0x7F and cur == 0x80):
            char = self.device.read(1)
            if len(char) == 0:  # timeout, exit the capture
                raise CaptureError("Timeout")
            prev, cur = cur, char[0]

    def _parse(self, raw):
        """Decode *raw* bytes into an (N, 9) array of [ts, ch0 .. ch7] rows.

        On success ``self.data_ts`` advances by N * ``data_ts_inc``.

        Raises:
            CaptureError: "short read" if *raw* is not a whole number of
                rows, or "checksum error" if any row lacks the alignment
                marker.
        """
        # Validate the length before decoding. An odd byte count would make
        # numpy raise ValueError, which run() does not catch.
        if len(raw) % ROW_BYTES != 0:
            raise CaptureError("short read")
        rows = len(raw) // ROW_BYTES
        # 2-byte little-endian words, 10 per row (the marker check below
        # depends on this byte order)
        data = np.frombuffer(raw, dtype='<i2').reshape(rows, 10)

        # check for corruption: the last word of every row must be 0x7F80
        if not np.all(data[:, 9] == self._ALIGN_WORD):
            raise CaptureError("checksum error")

        # calculate timestamps only for validated data
        top_ts = self.data_ts + rows * self.data_ts_inc
        ts = np.linspace(self.data_ts, top_ts, rows,
                         endpoint=False).astype(np.uint64).reshape(rows, 1)
        self.data_ts = top_ts  # update the timestamp for the next run

        # result rows have the form [ts, d0, d1, ..., d7]
        return np.hstack((ts, data[:, 1:9]))
