import logging
import time
from abc import ABC, abstractmethod
from multiprocessing.connection import Connection
from typing import Union
from functools import singledispatchmethod

import serial

# Module logger, plus child loggers for "#"-prefixed messages echoed by the
# instrument (DSTAT) and for experiment lifecycle events (Experiment).
logger = logging.getLogger(__name__)
dstat_logger = logging.getLogger(f"{__name__}.DSTAT")
exp_logger = logging.getLogger(f"{__name__}.Experiment")


class BaseExperimentProcess(ABC):
    """Base class for an experiment that drives a DStat over a serial port.

    Commands are queued with :meth:`parse_command_string` and sent in order
    by :meth:`run`; after each command, the subclass's :meth:`serial_handler`
    consumes the instrument's output.
    """

    def __init__(self, *args, **kwargs):
        # Size of the raw payload a serial_handler reads per data record.
        self.data_bytes = 0
        # Assigned by run(); None until then.
        self.serial: Union[None, serial.Serial] = None
        self.ctrl_pipe: Union[None, Connection] = None
        self.data_pipe: Union[None, Connection] = None

        # Queued commands: plain strings, or (command, [param strings])
        # tuples built by parse_command_string.
        self.commands = []

    @singledispatchmethod
    def parse_command_string(self, cmd_str, params: dict):
        """Format *cmd_str* with *params* and append it to ``self.commands``.

        Dispatches on the type of *cmd_str*:

        * ``str`` -- formatted with ``str.format(**params)`` and queued.
        * ``tuple`` -- ``(cmd, keys)``: ``cmd`` is formatted as above, and
          each key in ``keys`` names an iterable in *params* whose items are
          stringified, in order, to form the command's parameter list.

        Raises:
            NotImplementedError: *cmd_str* is of any other type.
        """
        raise NotImplementedError(
            f"cmd_str has unrecognized type {type(cmd_str)}"
        )

    @parse_command_string.register
    def _(self, cmd_str: str, params: dict):
        self.commands.append(cmd_str.format(**params))

    @parse_command_string.register
    def _(self, cmd_str: tuple, params: dict):
        cmd, keys = cmd_str
        cmd = cmd.format(**params)
        # Flatten the values of every referenced parameter, in key order.
        data_parsed = [str(value) for key in keys for value in params[key]]
        self.commands.append((cmd, data_parsed))

    def write_command(self, cmd, params=None, retry=5):
        """Send one command to the DStat, retrying the whole exchange.

        The exchange implemented here is: send ``!<len(cmd)>`` and expect
        ``@ACK <len(cmd)>``; send the command and expect ``@RCV <len(cmd)>``;
        if *params* is given, expect ``@RQP <len(params)>`` and then send each
        parameter followed by a space, expecting ``@RCVC <param>`` back.
        Reply lines starting with ``#`` are logged to the DSTAT logger and
        skipped.

        Args:
            cmd: Command string, without the trailing newline.
            params: Optional sized sequence of parameters sent after *cmd*.
            retry: Maximum number of attempts at the full exchange.

        Returns:
            True if the exchange completed, False otherwise. Exceptions from
            the port's ``write``/``reset_input_buffer`` calls are not caught
            here and propagate to the caller.
        """
        n = len(cmd)
        n_params = None if params is None else len(params)

        def get_reply(retries=3):
            # Return the next non-comment reply; raise SerialException after
            # `retries` empty reads (readline returned nothing).
            while True:
                reply = self.serial.readline().rstrip().decode(
                    'ascii', errors='replace')
                logger.debug("reply: %s", reply)
                if reply.startswith('#'):
                    dstat_logger.info(reply)
                elif reply == "":
                    retries -= 1
                    if retries <= 0:
                        raise serial.SerialException()
                else:
                    return reply

        def exchange(send=None, tries=5):
            # Call `send` (if given), then read a reply; on SerialException
            # retry up to `tries` more times. Returns None when the tries are
            # exhausted so the outer loop restarts the exchange instead of
            # spinning here forever.
            for _ in range(tries + 1):
                if send is not None:
                    send()
                try:
                    return get_reply()
                except serial.SerialException:
                    pass
            return None

        def send_handshake():
            # Drop stale input before announcing the command length.
            time.sleep(0.2)
            self.serial.reset_input_buffer()
            self.serial.write(f'!{n}\n'.encode('ascii'))
            time.sleep(.1)

        def send_command():
            self.serial.write(f'{cmd}\n'.encode('ascii'))
            logger.debug("sent: %s", cmd)

        def send_param(value, tries=5):
            # Send one parameter until the device confirms it. Any other reply
            # (including @RCVE) or a read failure uses up one try, so a
            # misbehaving device cannot keep us looping indefinitely.
            for _ in range(tries):
                self.serial.write(f"{value} ".encode('ascii'))
                try:
                    if get_reply() == f"@RCVC {value}":
                        return True
                except serial.SerialException:
                    pass
            return False

        for _ in range(retry):
            reply = exchange(send_handshake)
            if reply != f"@ACK {n}":
                logger.warning("Expected ACK got: %s", reply)
                continue

            reply = exchange(send_command)
            if reply != f"@RCV {n}":
                logger.warning("Expected RCV got: %s", reply)
                continue

            if params is None:
                return True

            reply = exchange()
            if reply != f"@RQP {n_params}":
                logger.warning("Expected RQP %s got: %s", n_params, reply)
                continue

            for value in params:
                if not send_param(value):
                    logger.error('Communication failure')
                    return False
            return True
        return False

    def run(self, ser: serial.Serial, ctrl_pipe: Connection, data_pipe: Connection):
        """Execute experiment. Connects and sends handshake signal to DStat
        then sends self.commands. Don't call directly as a process in Windows,
        use run_wrapper instead.

        Each command is sent with :meth:`write_command`, then
        :meth:`serial_handler` processes the instrument's output.

        Returns:
            ``'DONE'`` if every command completed, ``'ABORT'`` if a command
            failed or the handler returned a falsy value, ``'SERIAL_ERROR'``
            if a ``serial.SerialException`` escaped.
        """
        self.serial = ser
        self.ctrl_pipe = ctrl_pipe
        self.data_pipe = data_pipe

        exp_logger.info('Experiment running')
        status = 'DONE'

        try:
            for command in self.commands:
                # Queued commands are plain strings or (command, params).
                if isinstance(command, str):
                    cmd, data = command, None
                else:
                    cmd, data = command
                logger.info('Command: %s', cmd)

                if not self.write_command(cmd, params=data):
                    status = 'ABORT'
                    break

                if not self.serial_handler():
                    status = 'ABORT'
                    break

                time.sleep(0.25)

        except serial.SerialException:
            status = 'SERIAL_ERROR'
        finally:
            # Discard any control messages left unread on the pipe.
            while self.ctrl_pipe.poll():
                self.ctrl_pipe.recv()
        return status

    @abstractmethod
    def serial_handler(self):
        """Consume the instrument's output after a command has been sent.

        A falsy return value makes :meth:`run` stop with status ``'ABORT'``.
        """


class ExperimentProcess(BaseExperimentProcess):
    """Experiment that forwards binary data records from the DStat to
    ``self.data_pipe``."""

    def __init__(self):
        """Initialise base state and the per-experiment counters."""
        super().__init__()
        self.datapoint = 0
        self.scan = 0
        self.time = 0

    def _poll_ctrl(self):
        """Handle at most one pending control message from ``self.ctrl_pipe``.

        ``'a'`` forwards an abort byte to the instrument. ``"DISCONNECT"``
        also sends the abort byte, clears the serial input buffer and waits
        briefly.

        Returns:
            False if a ``"DISCONNECT"`` was handled, True otherwise.
        """
        if not self.ctrl_pipe.poll():
            return True
        ctrl = self.ctrl_pipe.recv()
        logger.debug("serial_handler: %s", ctrl)
        if ctrl == "DISCONNECT":
            self.serial.write(b'a')
            self.serial.reset_input_buffer()
            logger.info("serial_handler: ABORT pressed!")
            time.sleep(.3)
            return False
        if ctrl == 'a':
            self.serial.write(b'a')
        return True

    def serial_handler(self):
        """Handles incoming serial transmissions from DStat. Returns False
        if stop button pressed and sends abort signal to instrument. Sends
        data to self.data_pipe as result of self.data_handler).

        Line handling:

        * ``B...`` -- the next ``self.data_bytes`` raw bytes are sent to
          ``self.data_pipe`` as ``(seconds since first record, scan, payload)``.
        * ``S...`` -- advances the local scan counter.
        * ``#...`` -- logged to the DSTAT logger.
        * ``@DONE`` -- the command finished.

        Control messages are checked before reading and between every line,
        so a ``"DISCONNECT"`` arriving mid-stream is acted on, not dropped.

        Returns:
            True on ``@DONE``; False on ``"DISCONNECT"`` or a
            ``serial.SerialException``.
        """
        scan = 0
        start = None
        try:
            while True:
                if not self._poll_ctrl():
                    return False

                for line in self.serial:
                    if not self._poll_ctrl():
                        return False

                    stripped = line.lstrip()
                    if line.startswith(b'B'):
                        if start is None:
                            start = time.perf_counter()
                        data = (time.perf_counter() - start, scan,
                                self.serial.read(size=self.data_bytes))
                        self.data_pipe.send(data)

                    elif stripped.startswith(b'S'):
                        scan += 1

                    elif stripped.startswith(b"#"):
                        dstat_logger.info(
                            line.strip().decode('ascii', errors='replace'))

                    elif stripped.startswith(b"@DONE"):
                        dstat_logger.debug(
                            line.strip().decode('ascii', errors='replace'))
                        return True

        except serial.SerialException:
            return False
