import logging
from multiprocessing.connection import Connection
from typing import Union

from serial import Serial, SerialException

from .experiment_process import BaseExperimentProcess
from dstat_interface.core.dstat.state import DStatState

# Module logger.
logger = logging.getLogger(__name__)
# Child logger for lines echoed back by the device ('#' comments and the
# '@DONE' terminator in the serial handlers of this module).
dstat_logger = logging.getLogger("{}.DSTAT".format(__name__))
# Child logger for experiment-level messages.
# NOTE(review): not used in the code visible here — confirm whether it is still needed.
exp_logger = logging.getLogger("{}.Experiment".format(__name__))

# DStat state object; the mV->DAC helpers below read
# `state.board_instance.re_voltage_scale` from it.
# NOTE(review): presumably shared application-wide (singleton?) — confirm in dstat.state.
state = DStatState()


def abs_mv_to_dac(mv: float) -> int:
    """Convert an absolute potential in mV to a 16-bit DAC code.

    The potential is divided by the board's ``re_voltage_scale`` and scaled by
    65536/3000 codes per mV. It is then offset by 32768, so 0 mV maps to
    mid-scale. The result is truncated toward zero by ``int()``.

    :param mv: absolute potential in millivolts.
    :return: DAC code in ``[0, 65535]``.
    :raises AssertionError: if the resulting code is outside ``[0, 65535]``.
    """
    dac = int(mv / state.board_instance.re_voltage_scale * (65536. / 3000) + 32768)
    # Explicit check rather than ``assert``: ``python -O`` strips assert
    # statements, which would let an out-of-range code reach the device.
    # AssertionError is kept so callers that already catch it still work.
    if not 0 <= dac <= 65535:
        raise AssertionError(f"{mv!r} mV maps to DAC code {dac}, outside 0..65535")
    return dac


def rel_mv_to_dac(mv: float) -> int:
    """Convert a relative potential (a step or amplitude) in mV to DAC counts.

    Uses the same scaling as :func:`abs_mv_to_dac` but without the mid-scale
    offset. The result is truncated toward zero by ``int()``.

    :param mv: relative potential in millivolts.
    :return: DAC count in ``[1, 65535]``.
    :raises AssertionError: if the count is 0 or out of range. DStat
        parameters that take relative mV don't accept 0.
    """
    dac = int(mv / state.board_instance.re_voltage_scale * (65536. / 3000))
    # Explicit check rather than ``assert`` so it survives ``python -O``;
    # AssertionError is kept for backward compatibility with callers.
    if not 0 < dac <= 65535:
        raise AssertionError(f"{mv!r} mV maps to DAC count {dac}, outside 1..65535")
    return dac


def param_test_uint16(u: int) -> int:
    """Validate that ``u`` fits in an unsigned 16-bit integer and return it as ``int``.

    :raises AssertionError: if ``u`` is outside ``[0, 65535]``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= u <= 65535:
        raise AssertionError(f"{u!r} is outside uint16 range 0..65535")
    return int(u)


def param_test_non_zero_uint16(u: int) -> int:
    """Validate that ``u`` is a non-zero unsigned 16-bit integer and return it as ``int``.

    :raises AssertionError: if ``u`` is outside ``[1, 65535]``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 < u <= 65535:
        raise AssertionError(f"{u!r} is outside non-zero uint16 range 1..65535")
    return int(u)


def param_test_uint8(u: int) -> int:
    """Validate that ``u`` fits in an unsigned 8-bit integer and return it as ``int``.

    Unlike :func:`param_test_non_zero_uint8`, 0 is accepted. This mirrors the
    ``param_test_uint16`` / ``param_test_non_zero_uint16`` pair. Previously
    this function rejected 0 and was identical to the non-zero variant.

    :raises AssertionError: if ``u`` is outside ``[0, 255]``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= u <= 255:
        raise AssertionError(f"{u!r} is outside uint8 range 0..255")
    return int(u)


def param_test_non_zero_uint8(u: int) -> int:
    """Validate that ``u`` is a non-zero unsigned 8-bit integer and return it as ``int``.

    :raises AssertionError: if ``u`` is outside ``[1, 255]``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 < u <= 255:
        raise AssertionError(f"{u!r} is outside non-zero uint8 range 1..255")
    return int(u)


class VersionCheckProcess(BaseExperimentProcess):
    """Query the DStat for its version (``V``) and mux (``X``) information.

    Parsed replies are forwarded over ``self.data_pipe`` as
    ``('V', (pcb, firmware_hex_or_False))`` and ``('X', int)`` tuples.
    ``self.serial``, ``self.data_pipe`` and ``self.commands`` are presumably
    provided by ``BaseExperimentProcess`` — confirm there.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.commands += ['V', 'X']

    @staticmethod
    def _format_firmware(firmware: str) -> str:
        """Render the decimal firmware field as lowercase hex without a ``0x`` prefix."""
        # ``hex(n).lstrip('0x')`` strips the *characters* '0' and 'x', not the
        # prefix, so a value of 0 came out as ''. ``format(n, 'x')`` gives '0'
        # and matches the old output for every other non-negative value.
        return format(int(firmware), 'x')

    def serial_handler(self):
        """Consume device output until the ``@DONE`` terminator.

        ``#`` lines are logged at INFO. ``V`` (version) and ``X`` (mux) replies
        are parsed and sent over ``self.data_pipe``.

        :return: ``True`` after ``@DONE``; ``False`` if a ``SerialException`` occurs.
        """
        try:
            while True:
                for line in self.serial:
                    if line.lstrip().startswith(b"#"):
                        dstat_logger.info(line.strip().decode('ascii'))
                    elif line.lstrip().startswith(b"@DONE"):
                        dstat_logger.debug(line.strip().decode('ascii'))
                        return True
                    elif line.startswith(b'V'):
                        ver_str = line.lstrip(b'V').decode('ascii')
                        # Expected form "<pcb>-<firmware>"; rpartition leaves pcb
                        # empty when there is no '-' (or nothing before it).
                        pcb, sep, firmware = ver_str.strip().rpartition('-')

                        if pcb == "":
                            # Only one field reported: treat it as the PCB version.
                            pcb = firmware
                            logger.info('Your firmware does not support PCB version detection.')
                            self.data_pipe.send(('V', (pcb, False)))
                        else:
                            firmware_hex = self._format_firmware(firmware)
                            logger.info(f"Firmware Version: {firmware_hex}")
                            self.data_pipe.send(('V', (pcb, firmware_hex)))

                        logger.info(f'PCB Version: {pcb}')
                    elif line.startswith(b'X'):
                        mux_str = line.lstrip(b'X').decode('ascii')
                        self.data_pipe.send(('X', int(mux_str)))

        except SerialException:
            return False


class SettingsProcess(BaseExperimentProcess):
    """Read (``task='r'``) or write (``task='w'``) the DStat's stored settings.

    For ``'w'``, an ``SW`` command is queued carrying the values of
    ``settings`` (space-separated, in dict order, with a trailing space). For
    ``'r'``, ``SR`` is queued and the parsed reply is sent over
    ``self.data_pipe`` as a ``{key: int}`` dict. Any other ``task`` queues no
    command.
    """

    def __init__(self, task, settings: Union[None, dict[str, str]] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.task = task
        self.settings = settings

        # NOTE(review): these overwrite whatever BaseExperimentProcess.__init__
        # may have set; presumably they are assigned when the process runs.
        self.serial: Union[None, Serial] = None
        self.ctrl_pipe: Union[None, Connection] = None
        self.data_pipe: Union[None, Connection] = None

        if self.task == 'w':
            if self.settings is None:
                # Previously surfaced as an opaque AttributeError on None.values().
                raise ValueError("SettingsProcess task 'w' requires a settings dict")
            to_write = ' '.join(str(value) for value in self.settings.values()) + ' '
            logger.debug('to_write = %s', to_write)

            self.commands += ['SW' + to_write]

        elif self.task == 'r':
            self.commands += ['SR']

    @staticmethod
    def _parse_settings(line: bytes) -> dict:
        """Parse an ``S<key>.<value>:<key>.<value>...`` reply into ``{key: int(value)}``.

        ``line`` must already be left-stripped and start with ``S``.
        """
        input_line = line[1:].decode('ascii')
        # Empty segments (e.g. from a trailing ':') are skipped instead of
        # failing to unpack.
        pairs = (item.split('.') for item in input_line.rstrip().split(':') if item)
        return {key: int(value) for key, value in pairs}

    def serial_handler(self):
        """Consume device output until the ``@DONE`` terminator.

        ``#`` lines are logged at INFO. ``S`` replies are parsed and sent over
        ``self.data_pipe``.

        :return: ``True`` after ``@DONE``; ``False`` if a ``SerialException`` occurs.
        """
        ser_logger = logging.getLogger(f'{__name__}._serial_process')
        try:
            while True:
                for line in self.serial:
                    stripped = line.lstrip()
                    if stripped.startswith(b"#"):
                        ser_logger.info(line.strip().decode('ascii'))

                    elif stripped.startswith(b"@DONE"):
                        ser_logger.debug(line.strip().decode('ascii'))
                        return True

                    elif stripped.startswith(b'S'):
                        self.data_pipe.send(self._parse_settings(stripped))
        except SerialException:
            return False
