from copy import copy
from io import IOBase
from math import ceil, floor
import logging
import struct
from time import sleep
import re
from typing import Callable, Union

import numpy as np

# Logger shared by the simulator classes in this module.
logger = logging.getLogger(name="dstat.simulator")
# Captures a run of decimal digits; used to read the length in a "!<n>" header.
digit_pattern = re.compile(br"(\d+)")


class Simulator(object):
    """Simulated DStat firmware speaking the host's serial command protocol.

    Bytes written by the host are passed to :meth:`input`, which runs a small
    state machine until the input is used up:

    * ``main`` waits for a ``!<n>`` header and answers ``@ACK <n>``;
    * ``command`` reads the ``n``-byte body plus its delimiter, answers
      ``@RCV <n>``, dispatches on the first letter and finishes with ``@DONE``;
    * ``input_params`` collects extra parameters requested with ``@RQP <k>``,
      echoing each one as ``@RCVC <value>``;
    * ``abort_wait`` idles until an ``a`` byte arrives, then sends ``@DONE``.

    Every reply is appended to :attr:`output` for the host side to read.
    """

    # Values reported by the ``SR`` command; ``SW`` overwrites them in this key order.
    default_settings = {'max5443_offset': 0,
                        'tcs_enabled': 1,
                        'tcs_clear_threshold': 10000,
                        'r100_trim': 0,
                        'r3k_trim': 0,
                        'r30k_trim': 0,
                        'r300k_trim': 0,
                        'r3M_trim': 0,
                        'r30M_trim': 0,
                        'r100M_trim': 0,
                        'eis_cal1': 0,
                        'eis_cal2': 0,
                        'dac_units_true': 1}

    def __init__(self):
        # State handler invoked repeatedly by input() while input remains.
        self.current_state = self.main
        # Handler to run once every requested extra parameter has arrived.
        self.next_state: Union[None, Callable] = None
        self.output: bytes = b''       # replies queued for the host
        self.input_str: bytes = b''    # unconsumed bytes received from the host
        self.cmd_length: int = 0       # body length announced by the last "!<n>" header
        self.command_str: bytes = b''  # body of the command currently being parsed

        self.waiting_for_params: int = 0  # extra parameters still expected
        self.extra_params = []            # extra parameters received so far (bytes tokens)
        self.current_cmd: bytes = b''     # experiment letter being executed
        self.current_params: list = []    # arguments for the post-parameter handler

        self.settings_dict = copy(Simulator.default_settings)

    # ------------------------------------------------------------------ helpers

    def _get_input(self, size=1) -> bytes:
        """Remove and return the next *size* bytes of pending input.

        Raises:
            IndexError: if fewer than *size* bytes are buffered.
        """
        if size > len(self.input_str):
            logger.error("Not enough characters in input buffer")
            raise IndexError("Not enough characters in input buffer")

        string = self.input_str[:size]
        self.input_str = self.input_str[size:]
        return string

    def _get_command(self):
        """Remove and return the first byte of the command body."""
        string = self.command_str[:1]
        self.command_str = self.command_str[1:]
        return string

    def _get_params(self, n):
        """Return the first *n* whitespace-separated tokens of the command body.

        Anything after the first *n* tokens is pushed back onto the FRONT of
        the pending input, so input that followed this command is not lost.
        The command body itself is left unchanged.

        Raises:
            IndexError: if the body holds fewer than *n* tokens.
        """
        params = self.command_str.split(None, n)
        if len(params) > n:
            self.input_str = params.pop() + self.input_str
        elif len(params) < n:
            raise IndexError("Not enough characters in command buffer")
        return params

    @staticmethod
    def _text(token: bytes) -> str:
        """Decode a protocol token for echoing; non-ASCII bytes become escapes."""
        return token.decode('ascii', 'backslashreplace')

    @staticmethod
    def _sweep(begin, end, step):
        """Return DAC codes from *begin* toward *end* (exclusive).

        The step magnitude is ``abs(step)`` (at least 1, so a small slope does
        not yield a zero step) and its sign follows the sweep direction, so
        descending segments produce points as well.
        """
        magnitude = max(abs(int(step)), 1)
        return np.arange(begin, end, magnitude if end >= begin else -magnitude)

    def _emit_point(self, fmt, *values):
        """Queue one data record: a ``B`` line, the *fmt*-packed values, a newline."""
        self.output += b"B\n" + struct.pack(fmt, *values) + b"\n"

    def _finish_command(self):
        """Send ``@DONE`` and return to ``main`` unless an abort is awaited."""
        if self.current_state == self.abort_wait:
            logger.info("abort_wait state")
        else:
            self.output += b"@DONE\n"
            self.current_state = self.main

    # ------------------------------------------------------------------ states

    def input(self, string):
        """Feed bytes from the host and run the state machine until they are consumed."""
        logger.debug("input: {}".format(string))
        self.input_str = string
        while len(self.input_str) > 0:
            self.current_state()

    def main(self):
        """Idle state: wait for a ``!<length>`` header announcing a command."""
        char = self._get_input()
        if char != b'!':
            return  # bytes outside a header are ignored

        match = digit_pattern.match(self.input_str)
        if match is None:
            logger.error("No character count received")
            self.input_str = b''
            return

        digits = match.group(1)
        self.cmd_length = int(digits)
        # Skip the digits and the single delimiter byte that follows them.
        self.input_str = self.input_str[len(digits) + 1:]
        self.output += f'@ACK {self.cmd_length}\n'.encode('ascii')
        if self.cmd_length == 0:
            self.output += b'@RCV 0\n'
            return
        # input() keeps calling the current state while bytes remain.
        self.current_state = self.command

    def command(self):
        """Read one command body and dispatch it on its first letter.

        Raises:
            IndexError: if the announced body has not fully arrived.
        """
        def restart():
            logger.info("USB restart received")

        def version():
            logger.info("Version Check")
            self.output += f'V1.2.3-{0xfffffff:d}\n'.encode('ascii')

        command_map = {b'E': self.experiment,
                       b'S': self.settings,
                       b'R': restart,
                       b'V': version}

        # The announced length excludes the trailing delimiter, hence the +1.
        self.command_str = self._get_input(self.cmd_length + 1)
        self.output += f'@RCV {self.cmd_length}\n'.encode('ascii')

        char = self.command_str[:1]
        self.command_str = self.command_str[1:-1]  # Remove delimiter

        handler = command_map.get(char)
        if handler is None:
            logger.warning("Unrecognized command {}".format(char))
            self.output += f"#ERR: Command {self._text(char)} not recognized\n".encode('ascii')
        elif handler():
            # A truthy result means the handler requested extra parameters (@RQP).
            if self.waiting_for_params > 0:
                self.current_state = self.input_params
                return
            self.next_state()  # nothing to wait for: finish immediately

        self._finish_command()
        logger.info("Command {} Finished".format(char))

    def abort_wait(self):
        """Wait for an ``a`` byte, then send ``@DONE`` and return to ``main``.

        Experiments call this directly to enter the state; when no input is
        buffered at that moment it simply returns and waits for the next write.
        """
        logger.info("abort_wait() called")
        self.current_state = self.abort_wait
        if not self.input_str:
            return
        if self._get_input() == b'a':
            logger.info("Abort signal received")
            self.output += b"@DONE\n"
            self.current_state = self.main

    def input_params(self):
        """Consume one whitespace-delimited extra parameter sent after ``@RQP``.

        Each parameter is echoed as ``@RCVC <value>``.  When the last expected
        one arrives, the pending ``next_state`` handler runs and the command
        is finished with ``@DONE``.
        """
        tokens = self.input_str.split(None, 1)
        if not tokens:
            self.input_str = b''  # whitespace only: nothing to record
            return

        param = tokens[0]
        self.input_str = tokens[1] if len(tokens) > 1 else b''
        self.extra_params.append(param)
        self.output += b'@RCVC ' + param + b'\n'
        self.waiting_for_params -= 1

        if self.waiting_for_params <= 0:
            self.next_state()
            self._finish_command()

    # ---------------------------------------------------------------- commands

    def settings(self):
        """Handle an ``S`` command: ``D`` reset, ``F`` firmware, ``R`` read, ``W`` write."""
        def reset():
            logger.info("Settings reset")
            self.settings_dict = copy(Simulator.default_settings)

        def firmware():
            logger.info("Firmware update mode")

        def read():
            # Reply format: "S<key>.<value>:<key>.<value>...\n"
            fields = b":".join(f"{key}.{value}".encode('ascii')
                               for key, value in self.settings_dict.items())
            self.output += b"S" + fields + b"\n"

        def write():
            # One integer per setting, applied in the dict's key order.
            params = self._get_params(len(self.settings_dict))
            for key, value in zip(list(self.settings_dict), params):
                self.settings_dict[key] = int(value)

        settings_map = {b'D': reset,
                        b'F': firmware,
                        b'R': read,
                        b'W': write
                        }

        char = self.command_str[:1]
        self.command_str = self.command_str[1:]

        handler = settings_map.get(char)
        if handler is None:
            logger.warning("Settings control %s not found", char)
            return
        handler()

    def experiment(self):
        """Handle an ``E`` command, or finish one whose extra parameters arrived.

        Returns:
            True when the experiment requested extra parameters (``R``);
            otherwise None.
        """
        def ads1255():
            params = self._get_params(3)
            logger.info("ADS1255 params: %s", params)
            self.output += ('#A: ' + ' '.join(map(self._text, params)) + '\n').encode('ascii')

        def gain():
            params = self._get_params(2)
            logger.info("IV gain: %s", params)
            self.output += ('#G: ' + ' '.join(map(self._text, params)) + '\n').encode('ascii')

        def lsv():
            params = self._get_params(7)
            start, stop, slope = (int(p) for p in params[-3:])

            logger.info("LSV params: %s", params)

            for i in self._sweep(start, stop, slope // 10):
                self._emit_point('<Hl', i, 500 * (i - 32698))

        def cv():
            params = self._get_params(9)
            v1, v2, start, scans, slope = (int(p) for p in params[-5:])
            step = slope // 100

            logger.info("CV params: %s", params)

            for scan in range(scans):
                # start -> v1 -> v2 -> start; each leg may rise or fall.
                for begin, end in ((start, v1), (v1, v2), (v2, start)):
                    for i in self._sweep(begin, end, step):
                        self._emit_point('<Hl', i, 500 * (i - 32698) + 100 * scan)
                self.output += b"S\n"
            self.output += b"D\n"

        def swv():
            params = self._get_params(10)
            start = int(params[-6])
            stop = int(params[-5])
            step = int(int(params[-4]) * 3000 / 65536.)
            scans = max(int(params[-1]), 1)

            logger.info("SWV params: %s", params)

            for scan in range(scans):
                # Forward sweep, then the return sweep back to start.
                for begin, end in ((start, stop), (stop, start)):
                    for i in self._sweep(begin, end, step):
                        self._emit_point('<Hll', i, 500 * (i - 32698) + 100 * scan, 0)
                self.output += b"S\n"
            self.output += b"D\n"

        def dpv():
            params = self._get_params(10)
            start = int(params[-6])
            stop = int(params[-5])

            logger.info("DPV params: %s", params)

            for i in self._sweep(start, stop, 1):
                self._emit_point('<Hll', i, 500 * (i - 32698), 0)

            self.output += b"D\n"

        def pmt():
            logger.info("PMT Idle mode entered")
            self.abort_wait()

        def pot():
            params = self._get_params(2)
            seconds = int(params[0])
            logger.info("POT params: %s", params)

            if seconds == 0:
                self.abort_wait()  # run until the host sends an abort
            else:
                for i in range(seconds):
                    self._emit_point('<Hll', i, 0, 100)

        def ca():
            params = self._get_params(2)
            steps = int(params[0])
            tcs = int(params[1])

            # Request 2 values per step; ca_params later reads the second half
            # as per-step durations.
            self.current_params = [steps, tcs]
            self.extra_params = []
            self.waiting_for_params = steps * 2
            self.output += f"@RQP {steps * 2}\n".encode("ascii")
            self.next_state = self.experiment
            return True

        def ca_params(steps, tcs):
            times = [int(i) for i in self.extra_params[steps:]]
            seconds = sum(times)

            # One point per second of total step time, paced in real time.
            for i in range(seconds):
                logger.info(i)
                self._emit_point('<HHl', i, 0, 100 * i)
                sleep(1)
            self.extra_params = []

        def sync():
            params = self._get_params(1)
            logger.info('Shutter Sync %s Hz', params[0])

        def shut_off():
            logger.info('Shutter Sync Off')

        def shut_close():
            logger.info('Shutter closed')

        def shut_open():
            logger.info('Shutter open')

        experiment_map = {b'A': ads1255,
                          b'G': gain,
                          b'L': lsv,
                          b'C': cv,
                          b'S': swv,
                          b'D': dpv,
                          b'M': pmt,
                          b'P': pot,
                          b'R': ca,
                          b'Z': sync,
                          b'z': shut_off,
                          b'1': shut_close,
                          b'2': shut_open
                          }

        experiment_post_params_map = {b'R': ca_params}

        if self.next_state:
            # Re-entered from input_params: finish the experiment that asked for them.
            post_handler = experiment_post_params_map[self.current_cmd]
            args = self.current_params
            self.current_params = []
            self.next_state = None
            post_handler(*args)
            return None

        char = self._get_command()
        handler = experiment_map.get(char)
        if handler is None:
            logger.warning('Unrecognized exp command %s', char)
            self.output += f'#ERR: Command {self._text(char)} not recognized\n'.encode('ascii')
            return None

        self.current_cmd = char  # needed to find the post-parameter handler
        return handler()


class SerialSim(IOBase):
    """In-memory stand-in for a serial port, backed by a :class:`Simulator`.

    Writes are handed straight to the simulator's ``input``; reads drain the
    bytes it has queued in ``sim.output``.
    """

    def __init__(self, *args, **kwargs):
        # Any port arguments are accepted and ignored.
        self.sim = Simulator()
        self.is_open = True

    def open(self):
        """Mark the port as open."""
        self.is_open = True

    def close(self):
        """Mark the port as closed and discard unread simulator output."""
        self.is_open = False
        self.reset_input_buffer()

    def write(self, string: bytes):
        """Pass *string* to the simulator for processing."""
        self.sim.input(string)

    def read(self, size=1) -> bytes:
        """Remove and return up to *size* bytes of pending output."""
        pending = self.sim.output
        self.sim.output = pending[size:]
        return pending[:size]

    def reset_input_buffer(self):
        """Drop every byte still waiting to be read."""
        self.sim.output = b""

    def next(self) -> bytes:
        """Return the next line; raise StopIteration once no output remains."""
        if not self.sim.output:
            raise StopIteration
        return self.readline()

    def readline(self, size=-1) -> bytes:
        """Remove and return the next newline-terminated line.

        When *size* is positive, the newline must fall within the first *size*
        bytes.  If no complete line is available, nothing is consumed and
        ``b""`` is returned.
        """
        buffered = self.sim.output
        if not buffered:
            return b""

        window = buffered[:size] if size > 0 else buffered
        newline_at = window.find(b'\n')
        if newline_at < 0:
            return b""

        cut = newline_at + 1
        self.sim.output = buffered[cut:]
        return buffered[:cut]

    def __iter__(self):
        return self

    def _update_dtr_state(self):
        # No-op: the simulated port has no DTR line to update.
        pass
