import logging
import struct
from abc import ABC, abstractmethod, ABCMeta
from typing import Type, Callable, Union, Any, List

from ..dstat.comm import SerialConnection
from ..dstat.experiment_handler import ExperimentHandler
from ..dstat.experiment_process import BaseExperimentProcess
from ..dstat.state import DStatState
from ..dstat.boards import V1_2Board

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared state object; ExperimentContainer reads `state.board_instance` from it
# (gain table and ADC PGA tables) when parsing parameters and starting a handler.
state = DStatState()
# Board description used by ExperimentContainer.__init__ as the source of the
# ADC sample-rate and PGA label/code lists.
board = V1_2Board()


class ExperimentContainer(ABC):
    """Abstract description of a single experiment type.

    An instance bundles the user-supplied parameters, their defaults and
    documentation, the command-string templates that are filled in from those
    parameters, and the handler/process classes used to run the experiment.
    Concrete subclasses must provide ``experiment_id``, ``display_name``,
    ``data_bytes``, ``handler``, ``process`` and ``get_progress``, and must
    override ``__init__`` (normally calling ``super().__init__``).
    """

    # Names and Python types of the board-level parameters.
    dstat_params = {'buffer_false': bool, 'adc_rate': str,
                    'adc_pga': str, 'gain': int,
                    'short_true': bool, 'mux_channel': int}

    @abstractmethod
    def __init__(self, params: dict[str, Any], mux: int = 1):
        """Set up the parameter metadata shared by all experiments.

        Args:
            params: Parameter values for this experiment, keyed by name.
            mux: Number of multiplexer channels. When greater than 1, a
                ``mux_channel`` parameter (choices ``0 .. mux - 1``, default
                0) and an ``'EX{mux_channel} '`` command template are added.
        """
        self.mux = mux

        self.data_cols: Union[None, list[str]] = None
        self.data_format: str = ''
        # Command templates filled in with str.format(**self.params).
        # '{buffer}' is only present in self.params after parse_params() ran,
        # and the ':x' / ':d' specs need the integer codes set there.
        # Entries may also be (template, parameter-name-list) pairs, see
        # add_missing_from_defaults().
        self.cmd_str: list[Union[str, tuple[str, list[str]]]] = [
            "EA{buffer} {adc_rate:x} {adc_pga:x} ",
            "EG{gain} {short_true:d} ",
        ]

        # Label -> code lookup tables built from the module-level board.
        self.adc_rates = dict(zip(board.adc_rate_labels, board.adc_rate_codes))
        self.adc_pga = dict(zip(board.adc_pga_labels, board.adc_pga_codes))

        # For each parameter: either a type/callable, or the list of choices.
        self.param_input: dict[str, Union[Callable, List]] = {
            'buffer_false': bool,
            'adc_rate': list(self.adc_rates.keys()),
            'adc_pga': list(self.adc_pga.keys()),
            'gain': int,
            'short_true': bool,
        }
        self.param_input_limits: dict[str, tuple[Union[str, int], Union[str, int]]] = {}
        self.param_input_display_names: dict[str, str] = {}
        self.param_tables: dict[str, list[str]] = {}

        self.param_docs: dict[str, str] = {
            'buffer_false': 'Disable ADC buffer',
            'adc_rate': 'ADC sample rate',
            'adc_pga': 'ADC PGA value',
            'gain': 'IV converter gain',
            'short_true': 'Short RE and CE internally',
        }
        # Optional per-parameter conversion callables applied in parse_params().
        self.param_converters: dict[str, Callable] = {}
        self.defaults: dict = {'buffer_false': False, 'short_true': False,
                               'adc_rate': '60', 'adc_pga': '2',
                               'gain': 3000}
        self.params = params
        self.progress_max = 0
        self.progress_iters = 0
        self.progress_scan = 0

        self.handler_instance: Union[None, ExperimentHandler] = None

        # Filled by add_missing_from_defaults(): formatted strings, or
        # (formatted string, value list) pairs for tuple templates.
        self.commands: List[Union[str, tuple[str, list]]] = []
        # NOTE(review): the return value is discarded and data_format is still
        # '' here, so this call has no visible effect in this class. Kept in
        # case a subclass overrides calculate_data_bytes() with side effects.
        self.calculate_data_bytes()

        if mux > 1:
            self.cmd_str += ['EX{mux_channel} ']
            self.param_input['mux_channel'] = list(range(mux))
            self.param_docs['mux_channel'] = 'Multiplexer Channel'
            self.defaults['mux_channel'] = 0

    @property
    @abstractmethod
    def experiment_id(self) -> str:
        """Identifier of the experiment type (supplied by subclasses)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Display name of the experiment type (supplied by subclasses)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def data_bytes(self) -> int:
        """Value copied onto the process object as ``data_bytes`` in get_proc().

        Presumably the size in bytes of one data record, matching
        calculate_data_bytes() - confirm against the subclasses.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def handler(self) -> Type[ExperimentHandler]:
        """Handler class instantiated by start_handler()."""
        raise NotImplementedError

    @property
    @abstractmethod
    def process(self) -> Type[BaseExperimentProcess]:
        """Process class instantiated by get_proc()."""
        raise NotImplementedError

    @property
    def plots(self) -> dict[str, dict[str, str]]:
        """Plot definitions: plot name -> mapping of axis to data column."""
        return {'current': {'x': 'voltage', 'y': 'current'}}

    def calculate_data_bytes(self) -> int:
        """Return the packed size in bytes described by ``self.data_format``.

        Each lowercase character counts as an unsigned 16-bit value ('H'),
        every other character as a signed 32-bit value ('l'), using
        little-endian standard sizes (no padding).
        """
        codes = ''.join('H' if ch.islower() else 'l' for ch in self.data_format)
        return struct.calcsize(f'<{codes}')

    def add_missing_from_defaults(self):
        """Fill unset parameters from ``self.defaults`` and rebuild ``self.commands``.

        ``self.params`` is replaced by a copy of the defaults updated with the
        current values, so explicitly given values win. Every entry of
        ``self.cmd_str`` is then formatted with the merged parameters:

        * a plain string becomes the formatted string;
        * a ``(template, param_names)`` pair becomes
          ``(formatted template, values)``, where ``values`` is the
          concatenation of ``self.params[name]`` for each name.

        Raises:
            KeyError: If a template or pair references a missing parameter.
        """
        given = self.params
        self.params = self.defaults.copy()
        self.params |= given

        self.commands = []
        for entry in self.cmd_str:
            # Dispatch on type instead of catching ValueError from unpacking:
            # a two-character string would unpack "successfully", and a
            # ValueError raised by format() must not be mistaken for
            # "this entry is a plain string".
            if isinstance(entry, str):
                self.commands.append(entry.format(**self.params))
                continue

            cmd, table_params = entry
            param_list = []
            for name in table_params:
                # NOTE(review): `+=` extends with the items of the value, so
                # each referenced parameter is expected to be a list.
                param_list += self.params[name]
            self.commands.append((cmd.format(**self.params), param_list))

    def parse_params(self):
        """Convert ``self.params`` in place to the values used in commands.

        Applies ``self.param_converters`` to the parameters they name, replaces
        the ``adc_rate`` and ``adc_pga`` labels by their codes, replaces
        ``gain`` by its index in ``state.board_instance.gain``, and sets
        ``params['buffer']`` to ``'0'`` when ``buffer_false`` is truthy and
        ``'2'`` otherwise. The resulting parameters are logged at INFO level.

        NOTE(review): this appears not to be idempotent - a second call would
        look up the already-converted codes/index as if they were labels.
        """
        self.params |= {key: self.param_converters[key](value)
                        for key, value in self.params.items()
                        if key in self.param_converters}

        self.params['adc_rate'] = self.adc_rates[self.params['adc_rate']]
        self.params['adc_pga'] = self.adc_pga[self.params['adc_pga']]
        self.params['gain'] = state.board_instance.gain.index(self.params['gain'])

        self.params['buffer'] = '0' if self.params['buffer_false'] else '2'
        logger.info(self.params)

    def get_proc(self) -> BaseExperimentProcess:
        """Create the process object for this experiment.

        Instantiates ``self.process``, parses the parameters, passes every
        command template together with the parsed parameters to the process,
        and copies ``self.data_bytes`` onto it.
        """
        proc = self.process()
        self.parse_params()

        for template in self.cmd_str:
            proc.parse_command_string(template, self.params)

        proc.data_bytes = self.data_bytes
        return proc

    def start_handler(self, ser: SerialConnection):
        """Instantiate the handler on ``ser`` and store it in ``handler_instance``.

        The handler's ``adc_gain`` and ``adc_pga`` are looked up from
        ``state.board_instance`` using ``params['gain']`` as an index and
        ``params['adc_pga']`` as a PGA code, i.e. the values produced by
        parse_params().
        """
        board_instance = state.board_instance
        self.handler_instance = self.handler(ser=ser, data_cols=self.data_cols,
                                             data_format=self.data_format)
        self.handler_instance.adc_gain = board_instance.gain[self.params['gain']]
        pga_index = board_instance.adc_pga_codes.index(self.params['adc_pga'])
        self.handler_instance.adc_pga = board_instance.adc_pga[pga_index]

    @abstractmethod
    def get_progress(self) -> float:
        """Return the experiment's current progress (supplied by subclasses)."""
