from typing import Any

from ..dstat.experiment_handler import ExperimentHandler
from ..dstat.experiment_process import ExperimentProcess
from ..dstat.utility import abs_mv_to_dac, rel_mv_to_dac, param_test_uint16, param_test_non_zero_uint8
from ..experiments.experiment_container import ExperimentContainer


class LSVExperimentContainer(ExperimentContainer):
    """Experiment container for a linear scan (sweep) voltammetry run.

    Appends the ``EL`` command template to ``cmd_str`` and registers the
    parameter types, display names, limits, converters and defaults for it.
    Parameters are in seconds (``clean_s``/``dep_s``), mV (potentials) and
    mV/s (``slope``), per the display names below.
    """
    experiment_id = 'lsv'
    display_name = 'Linear Scan Voltammetry'
    process = ExperimentProcess
    handler = ExperimentHandler
    data_bytes = 0

    def __init__(self, params: dict[str, Any], mux: int = 1):
        super().__init__(params, mux=mux)

        self.data_cols = ['voltage', 'current']
        self.data_format = 'vA'

        self.cmd_str += ['EL{clean_s} {dep_s} {clean_mv} {dep_mv} {start} {stop} {slope} ']
        self.param_input |= {'clean_s': int, 'dep_s': int, 'clean_mv': int, 'dep_mv': int,
                             'start': int, 'stop': int, 'slope': int}
        self.param_input_display_names |= {'clean_s': 't_clean (s)', 'dep_s': 't_dep (s)',
                                           'clean_mv': 'V_clean (mV)', 'dep_mv': 'V_dep (mV)',
                                           'start': 'V_start (mV)', 'stop': 'V_stop (mV)', 'slope': 'Slope (mV/s)', }
        self.param_input_limits |= {'clean_s': (0, 'time_max'), 'dep_s': (0, 'time_max'),
                                    'clean_mv': ('mv_min', 'mv_max'), 'dep_mv': ('mv_min', 'mv_max'),
                                    'start': ('mv_min', 'mv_max'), 'stop': ('mv_min', 'mv_max'),
                                    'slope': (1, 5000)}
        self.param_converters |= {'clean_s': param_test_uint16, 'dep_s': param_test_uint16,
                                  'clean_mv': abs_mv_to_dac, 'dep_mv': abs_mv_to_dac,
                                  'start': abs_mv_to_dac, 'stop': abs_mv_to_dac, 'slope': rel_mv_to_dac}
        self.defaults |= {'clean_s': 0, 'dep_s': 0, 'clean_mv': 0, 'dep_mv': 0}
        if params:
            # Progress is measured as the remaining distance (mV) to the stop potential.
            self.progress_max = abs(self.params['stop'] - self.params['start'])
            self.progress_end = self.params['stop']
        else:
            self.progress_max = 0
            # Keep the attribute defined so get_progress never hits AttributeError.
            self.progress_end = 0
        self.progress_iters = 1
        self.data_bytes = self.calculate_data_bytes()

    def get_progress(self) -> float:
        """Return sweep progress as a percentage.

        Returns 100 once the handler reports ``done``. Returns 0 before any
        voltage sample has arrived, or when the sweep has zero span
        (``start == stop`` or no params). A zero span would otherwise
        divide by zero.
        """
        if self.handler_instance.done:
            return 100
        if not self.progress_max:
            return 0
        try:
            voltage = self.handler_instance.data['voltage'][-1]
        except IndexError:
            # No samples yet.
            return 0
        return (1 - abs(self.progress_end - voltage) / self.progress_max) * 100


class CVExperimentContainer(ExperimentContainer):
    """Experiment container for a cyclic voltammetry run.

    Appends the ``EC`` command template to ``cmd_str`` and registers the
    parameter types, display names, limits, converters and defaults for it.
    Potentials are in mV and ``slope`` in mV/s, per the display names below.

    NOTE(review): progress tracking assumes each scan sweeps
    start -> v1 -> v2 -> start with ``start`` between ``v1`` and ``v2``.
    That is the geometry the progress logic is built on; confirm it against
    the firmware.
    """
    experiment_id = 'cv'
    display_name = 'Cyclic Voltammetry'
    process = ExperimentProcess
    handler = ExperimentHandler
    plots = {'current': {'x': 'voltage', 'y': 'current', 'hue': 'scan'}}
    data_bytes = 0

    def __init__(self, params: dict[str, Any], mux: int = 1):
        super().__init__(params, mux=mux)

        self.data_cols = ['voltage', 'current']
        self.data_format = 'vA'

        self.cmd_str += ['EC{clean_s} {dep_s} {clean_mv} {dep_mv} {v1} {v2} {start} {scans} {slope} ']
        self.param_input |= {'clean_s': int, 'dep_s': int, 'clean_mv': int, 'dep_mv': int, 'v1': int, 'v2': int,
                             'start': int, 'scans': int, 'slope': int}
        self.param_input_display_names |= {'clean_s': 't_clean (s)', 'dep_s': 't_dep (s)',
                                           'clean_mv': 'V_clean (mV)', 'dep_mv': 'V_dep (mV)',
                                           'start': 'V_start (mV)', 'v1': 'V_1 (mV)', 'v2': 'V_2 (mV)',
                                           'scans': 'Scans', 'slope': 'Slope (mV/s)'}
        self.param_input_limits |= {'clean_s': (0, 'time_max'), 'dep_s': (0, 'time_max'),
                                    'clean_mv': ('mv_min', 'mv_max'), 'dep_mv': ('mv_min', 'mv_max'),
                                    'start': ('mv_min', 'mv_max'), 'v1': ('mv_min', 'mv_max'),
                                    'v2': ('mv_min', 'mv_max'),
                                    'scans': (1, 'scans_max'), 'slope': (1, 5000)}
        self.param_converters |= {'clean_s': param_test_uint16, 'dep_s': param_test_uint16,
                                  'clean_mv': abs_mv_to_dac, 'dep_mv': abs_mv_to_dac,
                                  'v1': abs_mv_to_dac, 'v2': abs_mv_to_dac, 'start': abs_mv_to_dac,
                                  'scans': param_test_non_zero_uint8,
                                  'slope': rel_mv_to_dac}
        self.defaults |= {'clean_s': 0, 'dep_s': 0, 'clean_mv': 0, 'dep_mv': 0}
        # Last value returned by get_progress; reused when no new voltage arrived.
        self.progress_last = 0.0
        if params:
            # Total swept distance (mV) per scan when start lies between v1 and v2.
            self.progress_max = 2 * abs(self.params['v1'] - self.params['v2'])
            self.progress_start = self.params['start']
            self.progress_v1 = self.params['v1']
            self.progress_v2 = self.params['v2']
            self.progress_iters = self.params['scans']
            self.progress_scan = 0
            self.progress_lastmv = self.params['start']
        else:
            self.progress_max = 0

        self.data_bytes = self.calculate_data_bytes()

    def get_progress(self) -> float:
        """Return progress through the current scan as a percentage in [0, 100].

        Returns 100 once the handler reports ``done``. Returns 0 before any
        sample has arrived, or when ``v1 == v2`` (zero sweep span).

        Raises:
            StopIteration: when the latest sample belongs to a new scan. The
                per-scan tracking state is reset first. NOTE(review): presumably
                the caller catches this to advance its iteration counter
                (see ``progress_iters``) — confirm.
        """
        try:
            if self.handler_instance.done:
                return 100

            latest_scan = self.handler_instance.data['scan'][-1]
            if latest_scan > self.progress_scan:
                self.progress_scan = latest_scan
                self.progress_lastmv = self.progress_start
                self.progress_last = 0.0
                raise StopIteration

            current_mv = self.handler_instance.data['voltage'][-1]
        except IndexError:
            # No samples yet.
            return 0

        if not self.progress_max:
            # v1 == v2: there is no sweep span to measure progress against.
            return 0

        if current_mv == self.progress_lastmv:
            # No movement since the last call (repeated poll or first sample
            # at start): the sweep direction is unknown, so keep the last value
            # rather than misclassifying the point as the v1 -> v2 leg.
            return self.progress_last

        start, v1, v2 = self.progress_start, self.progress_v1, self.progress_v2
        # Distance covered at the end of the start -> v1 leg. It depends on
        # where start sits; it is not a fixed quarter of the scan.
        first_leg = abs(v1 - start)

        if abs(v1 - self.progress_lastmv) > abs(v1 - current_mv):
            # Moving towards v1: either the first leg (start -> v1) or the last
            # leg (v2 -> start). On the last leg, v1 and start lie on the same
            # side of the current potential.
            if (v1 - current_mv > 0) == (start - current_mv > 0):
                fraction = 1 - abs(start - current_mv) / self.progress_max
            else:
                fraction = (first_leg - abs(v1 - current_mv)) / self.progress_max
        else:
            # Middle leg v1 -> v2.
            fraction = (first_leg + abs(v1 - v2) - abs(v2 - current_mv)) / self.progress_max

        # Clamp against overshoot/noise so the value stays a valid percentage.
        progress = min(max(100 * fraction, 0.0), 100.0)
        self.progress_lastmv = current_mv
        self.progress_last = progress
        return progress
