from datetime import datetime
from typing import Any, Tuple

from ..dstat.experiment_handler import ExperimentHandler
from ..dstat.experiment_process import ExperimentProcess
from ..dstat.utility import abs_mv_to_dac, rel_mv_to_dac, param_test_non_zero_uint16, \
    param_test_uint16, param_test_uint8
from ..experiments.experiment_container import ExperimentContainer


class SWVExperimentHandler(ExperimentHandler):
    """Experiment handler that also derives a ``'current'`` column.

    Used by both the SWV and DPV containers in this module. Each incoming
    record has three fields; besides storing them under ``data_cols``, the
    handler appends ``adc_to_amps(field[1] - field[2])`` to ``data['current']``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.data['current'] = []

    def data_handler(self, data_input: Tuple[datetime, int, bytes]):
        """Unpack one raw record and append its values to ``self.data``.

        Args:
            data_input: ``(timestamp, scan, payload)``. ``payload`` is decoded
                with ``self.struct.unpack``.

        Field ``n`` of the unpacked record is stored under ``self.data_cols[n]``.
        If ``self.data_convert[n]`` is callable the value is converted first,
        otherwise the raw value is stored. The derived ``'current'`` entry is
        computed from the *raw* difference of fields 1 and 2 (forward - reverse).
        """
        unpacked: Tuple[int, int, int]
        date, scan, data = data_input
        unpacked = self.struct.unpack(data)
        self.data['scan'].append(scan)
        self.data['timestamp'].append(date)

        for n, raw in enumerate(unpacked):
            column = self.data[self.data_cols[n]]
            converter = self.data_convert[n]
            # Check explicitly instead of catching TypeError, so a TypeError raised
            # *inside* a real converter is not mistaken for "no converter".
            if callable(converter):
                column.append(converter(raw))
            else:  # No converter for this column
                column.append(raw)
        self.data['current'].append(self.adc_to_amps(unpacked[1] - unpacked[2]))  # forward - reverse


class SWVExperimentContainer(ExperimentContainer):
    """Container describing a Square Wave Voltammetry ('swv') experiment.

    Builds the ``ES...`` command string, the parameter metadata (types, display
    names, limits, converters, defaults) and the state used by ``get_progress``.
    """
    experiment_id = 'swv'
    display_name = 'Square Wave Voltammetry'
    process = ExperimentProcess
    handler = SWVExperimentHandler
    plots = {'current': {'x': 'voltage', 'y': 'current', 'hue': 'scan'}}
    data_bytes = 0

    def __init__(self, params: dict[str, Any], mux: int = 1):
        """Configure the SWV command, parameter metadata and progress tracking.

        Args:
            params: Experiment parameters keyed by the names in ``param_input``.
                When empty, ``progress_max`` is 0 and ``get_progress`` reports 0.
            mux: Forwarded unchanged to the base container.
        """
        super().__init__(params, mux=mux)

        self.data_cols = ['voltage', 'forward', 'reverse']
        self.data_format = 'vAA'

        self.cmd_str += ['ES{clean_s} {dep_s} {clean_mv} {dep_mv} {start} {stop} {step} {pulse} {freq} {scans} ']
        self.param_input |= {'clean_s': int, 'dep_s': int, 'clean_mv': int, 'dep_mv': int,
                             'start': int, 'stop': int, 'step': int, 'pulse': int, 'freq': int,
                             'scans': int}
        self.param_input_display_names |= {'clean_s': 't_clean (s)', 'dep_s': 't_dep (s)',
                                           'clean_mv': 'V_clean (mV)', 'dep_mv': 'V_dep (mV)',
                                           'start': 'V_start (mV)', 'stop': 'V_stop (mV)', 'step': 'V_step (mV)',
                                           'pulse': 'V_pulse (mV)', 'freq': 'f (Hz)',
                                           'scans': 'Scans'}
        self.param_input_limits |= {'clean_s': (0, 'time_max'), 'dep_s': (0, 'time_max'),
                                    'clean_mv': ('mv_min', 'mv_max'), 'dep_mv': ('mv_min', 'mv_max'),
                                    'start': ('mv_min', 'mv_max'), 'stop': ('mv_min', 'mv_max'),
                                    'step': (1, 1000), 'pulse': (1, 1000), 'freq': (1, 'freq_max'),
                                    'scans': (0, 'scans_max')}
        self.param_converters |= {'clean_s': param_test_uint16, 'dep_s': param_test_uint16,
                                  'clean_mv': abs_mv_to_dac, 'dep_mv': abs_mv_to_dac,
                                  'start': abs_mv_to_dac, 'stop': abs_mv_to_dac, 'step': rel_mv_to_dac,
                                  'pulse': rel_mv_to_dac, 'freq': param_test_non_zero_uint16,
                                  'scans': param_test_uint8}
        self.defaults |= {'clean_s': 0, 'dep_s': 0, 'clean_mv': 0, 'dep_mv': 0}
        if params:
            if self.params['scans'] > 0:
                # The progress model treats cyclic scans as start -> stop -> start,
                # i.e. the voltage range is covered twice per scan.
                self.progress_max = 2 * abs(self.params['stop'] - self.params['start'])
            else:
                self.progress_max = abs(self.params['stop'] - self.params['start'])
            self.progress_start = self.params['start']
            self.progress_stop = self.params['stop']
            self.progress_iters = self.params['scans']
            self.progress_scan = 0
            self.progress_lastmv = self.params['start']
            # True once the current scan has turned around and is heading back to start.
            self.progress_returning = False
        else:
            self.progress_max = 0
        self.data_bytes = self.calculate_data_bytes()

    def get_progress(self):
        """Return the estimated progress of the current scan as a percentage.

        Returns 100 once the handler reports ``done``. Returns 0 while no data has
        arrived, or when the sweep has zero length (``progress_max == 0``, e.g.
        start == stop or no params).

        For ``scans > 0`` the first 50% covers the start -> stop leg and the
        remaining 50% the stop -> start leg of the current scan.

        Raises:
            StopIteration: when the latest data point belongs to a newer scan than
                the one being tracked (per-scan tracking is reset first).
        """
        if self.handler_instance.done:
            return 100
        if not self.progress_max:
            # Zero-length sweep or missing params: nothing to divide by.
            return 0
        try:
            if self.params['scans'] > 0:
                latest_scan = self.handler_instance.data['scan'][-1]
                if latest_scan > self.progress_scan:
                    self.progress_scan = latest_scan
                    self.progress_lastmv = self.progress_start
                    self.progress_returning = False
                    # NOTE(review): not caught by the IndexError handler below, so this
                    # propagates to the caller; kept because callers may rely on it.
                    raise StopIteration

                current_mv = self.handler_instance.data['voltage'][-1]
                dist_to_stop = abs(self.progress_stop - current_mv)
                last_dist_to_stop = abs(self.progress_stop - self.progress_lastmv)
                # Latch the turnaround: an unchanged voltage (repeated poll, or the
                # first point sitting at start) must keep the current leg rather
                # than flipping to the return-leg formula.
                if dist_to_stop > last_dist_to_stop:
                    self.progress_returning = True
                self.progress_lastmv = current_mv

                if self.progress_returning:
                    return 100 * (1 - abs(self.progress_start - current_mv) / self.progress_max)
                return 100 * (.5 - dist_to_stop / self.progress_max)

            return (1 - abs(
                self.progress_stop - self.handler_instance.data['voltage'][-1]) / self.progress_max) * 100
        except IndexError:
            return 0


class DPVExperimentContainer(ExperimentContainer):
    """Container describing a Differential Pulse Voltammetry ('dpv') experiment.

    Reuses ``SWVExperimentHandler`` for data handling and builds the ``ED...``
    command string plus the parameter metadata for a single start -> stop sweep.
    """
    experiment_id = 'dpv'
    display_name = 'Differential Pulse Voltammetry'
    process = ExperimentProcess
    handler = SWVExperimentHandler
    # NOTE(review): unlike SWVExperimentContainer, no ``plots`` mapping is declared
    # here although the handler produces a 'current' column; confirm the base-class
    # default is intended.
    data_bytes = 0

    def __init__(self, params: dict[str, Any], mux: int = 1):
        """Configure the DPV command, parameter metadata and progress tracking.

        Args:
            params: Experiment parameters keyed by the names in ``param_input``.
                When empty, ``progress_max`` is 0 and ``get_progress`` reports 0.
            mux: Forwarded unchanged to the base container.
        """
        super().__init__(params, mux=mux)

        self.data_cols = ['voltage', 'forward', 'reverse']
        self.data_format = 'vAA'

        self.cmd_str += ['ED{clean_s} {dep_s} {clean_mv} {dep_mv} {start} {stop} {step} {pulse} {period} {width} ']
        self.param_input |= {'clean_s': int, 'dep_s': int, 'clean_mv': int, 'dep_mv': int,
                             'start': int, 'stop': int, 'step': int, 'pulse': int, 'period': int, 'width': int}
        self.param_input_display_names |= {'clean_s': 't_clean (s)', 'dep_s': 't_dep (s)',
                                           'clean_mv': 'V_clean (mV)', 'dep_mv': 'V_dep (mV)',
                                           'start': 'V_start (mV)', 'stop': 'V_stop (mV)', 'step': 'V_step (mV)',
                                           'pulse': 'V_pulse (mV)', 'period': 'Pulse Period (ms)',
                                           'width': 'Pulse Width (ms)'}
        self.param_input_limits |= {'clean_s': (0, 'time_max'), 'dep_s': (0, 'time_max'),
                                    'clean_mv': ('mv_min', 'mv_max'), 'dep_mv': ('mv_min', 'mv_max'),
                                    'start': ('mv_min', 'mv_max'), 'stop': ('mv_min', 'mv_max'),
                                    'step': (1, 1000), 'pulse': (1, 1000), 'period': (1, 1000),
                                    'width': (1, 1000)}
        self.param_converters |= {'clean_s': param_test_uint16, 'dep_s': param_test_uint16,
                                  'clean_mv': abs_mv_to_dac, 'dep_mv': abs_mv_to_dac,
                                  'start': abs_mv_to_dac, 'stop': abs_mv_to_dac, 'step': rel_mv_to_dac,
                                  'pulse': rel_mv_to_dac, 'period': param_test_uint16,
                                  'width': param_test_non_zero_uint16}
        self.defaults |= {'clean_s': 0, 'dep_s': 0, 'clean_mv': 0, 'dep_mv': 0}

        if params:
            self.progress_max = abs(self.params['stop'] - self.params['start'])
            self.progress_start = self.params['start']
            self.progress_stop = self.params['stop']
            self.progress_scan = 0
            self.progress_iters = 1
        else:
            self.progress_max = 0
        self.data_bytes = self.calculate_data_bytes()

    def get_progress(self):
        """Return the sweep progress as a percentage of the start -> stop range.

        Returns 100 once the handler reports ``done``. Returns 0 while no data has
        arrived, or when the sweep has zero length (``progress_max == 0``, e.g.
        start == stop or no params), which would otherwise divide by zero.
        """
        if self.handler_instance.done:
            return 100
        if not self.progress_max:
            return 0
        try:
            return (1 - abs(
                self.progress_stop - self.handler_instance.data['voltage'][-1]) / self.progress_max) * 100
        except IndexError:
            return 0
