from datetime import datetime
from typing import Any

from ..dstat.experiment_handler import ExperimentHandler
from ..dstat.experiment_process import ExperimentProcess
from ..dstat.utility import abs_mv_to_dac, rel_mv_to_dac, param_test_uint16, param_test_non_zero_uint8
from ..experiments.experiment_container import ExperimentContainer


class CAExperimentHandler(ExperimentHandler):
    """Chronoamperometry data handler that also maintains a derived 'time' column."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Derived column, appended to by data_handler after every sample.
        self.data['time'] = []

    def data_handler(self, data_input: tuple[datetime, int, bytes]):
        """Unpack one raw sample and append it to the data columns.

        :param data_input: ``(timestamp, scan number, raw sample bytes)``.

        Each value unpacked with ``self.struct`` goes into the column named by the
        same index in ``self.data_cols``. It is first passed through the matching
        ``self.data_convert`` entry. If that call raises ``TypeError`` (for example
        when there is no converter), the raw value is stored instead. Finally the
        latest ``time_s`` and ``time_ms`` values are summed into ``'time'``.
        """
        stamp, scan_number, payload = data_input
        values: tuple[int, int, int] = self.struct.unpack(payload)

        self.data['scan'].append(scan_number)
        self.data['timestamp'].append(stamp)

        for idx, raw_value in enumerate(values):
            column = self.data[self.data_cols[idx]]
            try:
                column.append(self.data_convert[idx](raw_value))
            except TypeError:  # no usable converter for this column
                column.append(raw_value)

        # NOTE(review): the sum is only in seconds if the time_ms converter has
        # already scaled milliseconds to seconds — confirm against data_convert.
        time_column = self.data['time']
        latest_seconds = self.data['time_s'][-1]
        latest_millis = self.data['time_ms'][-1]
        time_column.append(latest_seconds + latest_millis)


class CAExperimentContainer(ExperimentContainer):
    """Chronoamperometry (CA) experiment definition.

    Declares the CA parameters (a table of step durations and step voltages),
    the device command template and the data columns. It also reports run
    progress from the most recent ``time_s`` sample.
    """
    experiment_id = 'ca'
    display_name = 'Chronoamperometry'
    process = ExperimentProcess
    handler = CAExperimentHandler
    data_bytes = 0

    def __init__(self, params: dict[str, Any], mux: int = 1):
        super().__init__(params, mux=mux)

        # Column names for the unpacked sample fields. data_format presumably
        # holds one format code per column, interpreted by the base class — TODO confirm.
        self.data_cols = ['time_s', 'time_ms', 'current']
        self.data_format = 'smA'

        # 'ER' step command. The parameter lists named in the tuple are presumably
        # appended by the base class after the formatted prefix — TODO confirm.
        self.cmd_str += [('ER{n_steps} 0 ', ('voltages', 'times'))]
        self.param_input |= {'times': list, 'voltages': list}
        self.param_input_display_names |= {'times': 'Time (s)', 'voltages': 'Voltage (mV)'}
        # Both lists are edited together as the rows of one "Steps:" table.
        self.param_tables = {'Steps:': ['times', 'voltages']}
        # String bounds ('time_max', 'mv_min', 'mv_max') are presumably resolved
        # to device limits by the base class — TODO confirm.
        self.param_input_limits |= {'times': (0, 'time_max'), 'voltages': ('mv_min', 'mv_max')}
        self.param_converters |= {
            'times': lambda x: [param_test_uint16(i) for i in x],
            'voltages': lambda x: [abs_mv_to_dac(i) for i in x],
        }
        self.defaults |= {'times': 0, 'voltages': 0}

        if params:
            # Total run length: the sum of all step durations (seconds, per the
            # 'Time (s)' display name).
            self.progress_end = sum(self.params['times'])
            self.params['n_steps'] = len(self.params['times'])
        else:
            # get_progress reads progress_end, so it must be defined on this path
            # too. progress_max is kept for any code that may read it.
            self.progress_end = 0
            self.progress_max = 0
        self.progress_iters = 1
        self.data_bytes = self.calculate_data_bytes()

    def get_progress(self) -> float:
        """Return run progress as a percentage in the range [0, 100].

        Progress is the latest ``time_s`` sample divided by the total step
        duration (``progress_end``). Returns 100 once the handler reports done.
        Returns 0 before any sample has arrived, or when the total duration is
        zero (no steps, or all steps of zero length).
        """
        if self.handler_instance.done:
            return 100.0
        if not self.progress_end:
            # Avoid dividing by zero for an empty or all-zero step table.
            return 0.0
        try:
            elapsed = self.handler_instance.data['time_s'][-1]
        except IndexError:  # no samples received yet
            return 0.0
        # Clamp, so that a timestamp past the end reports 100% instead of
        # decreasing again.
        fraction = max(0.0, min(elapsed / self.progress_end, 1.0))
        return fraction * 100



