from datetime import datetime
import logging
import struct
from typing import Callable, Union, Sequence, Tuple

from .comm import SerialConnection


# Module logger plus two child loggers, "<module>.DSTAT" and
# "<module>.Experiment"; the dotted names place them beneath `logger`
# in the logging hierarchy.
logger = logging.getLogger(__name__)
dstat_logger = logging.getLogger(f"{__name__}.DSTAT")
exp_logger = logging.getLogger(f"{__name__}.Experiment")


class ExperimentHandler(object):
    """Collect and convert streamed experiment data from a serial connection.

    Each data packet is a ``(timestamp, scan, payload)`` tuple. ``payload`` is
    unpacked with a little-endian :mod:`struct` format built from
    ``data_format``: a lowercase code becomes an unsigned 16-bit field
    (``'H'``) and an uppercase code a signed 32-bit field (``'l'``). Every
    unpacked field is passed through the converter registered for its format
    code (see :meth:`get_data_converter`) or stored raw if there is none.

    Args:
        ser: Connection polled through ``get_data()``, ``get_proc()`` and
            ``close()``.
        data_cols: Column names, one per character of ``data_format``; each
            becomes a key of :attr:`data`.
        data_format: One format code per field of a data packet.
    """

    def __init__(self, ser: SerialConnection, data_cols: Sequence[str] = ('voltage', 'current'),
                 data_format: str = 'vA'):

        self.ser = ser
        # Per-column lists of received values; 'timestamp' and 'scan' come
        # from the packet tuple, the rest from the unpacked payload.
        self.data = {'timestamp': [], 'scan': []}
        self.data |= {i: [] for i in data_cols}
        self.data_cols = data_cols
        # One converter (or None for "store raw") per payload field.
        self.data_convert = [self.get_data_converter(i) for i in data_format]
        struct_formats = ['H' if i.islower() else 'l' for i in data_format]
        self.struct = struct.Struct(f'<{"".join(struct_formats)}')
        # Set once get_all_data() has drained the queue.
        self.done = False

        # Calibration parameters read by adc_to_amps().
        self.adc_trim = 0
        self.adc_gain = 3000
        self.adc_pga = 2

        # experiment_running_data() discards the first packet it receives.
        self.first_skipped = False

    def get_data_converter(self, fmt: str) -> Union[Callable, None]:
        """Return the converter for format code *fmt*, or ``None`` if it has none.

        ``'v'`` -> :meth:`dac_to_volts`, ``'A'`` -> :meth:`adc_to_amps`,
        ``'m'`` -> :meth:`ms_to_s`; any other code means values are kept raw.
        """
        converters = {'v': self.dac_to_volts, 'A': self.adc_to_amps, 'm': self.ms_to_s}
        return converters.get(fmt)

    @staticmethod
    def dac_to_volts(dac: int) -> float:
        """Map an unsigned 16-bit DAC code onto a span centred on code 32768.

        Code 32768 maps to 0.0, code 0 to -1500.0.
        NOTE(review): the +/-1500 span suggests millivolts despite the method
        name -- confirm the unit.
        """
        return (dac - 32768) * 3000. / 65536

    def adc_to_amps(self, adc: int) -> float:
        """Convert a raw ADC reading to current using the trim/gain/PGA settings.

        8388607 == 2**23 - 1, i.e. the positive full scale of a signed 24-bit
        reading; 1.5 is the full-scale numerator used with ``adc_gain``.
        """
        return (adc + self.adc_trim) * (1.5 / self.adc_gain / 8388607) / (self.adc_pga / 2)

    @staticmethod
    def ms_to_s(ms: int) -> float:
        """Convert milliseconds to seconds."""
        return ms/1000.

    def experiment_running_data(self) -> bool:
        """Poll the connection once for a data packet and record it in :attr:`data`.

        The very first packet received is discarded (tracked by
        :attr:`first_skipped`).

        Returns:
            ``True`` to keep polling; ``False`` if the connection raised
            :class:`EOFError` or :class:`OSError` (the error is logged).
        """
        try:
            incoming = self.ser.get_data()
        except (EOFError, OSError) as err:  # IOError is an alias of OSError
            logger.error(err)
            return False

        if incoming is not None:
            if not self.first_skipped:
                self.first_skipped = True
            else:
                self.data_handler(incoming)
        return True

    def get_all_data(self) -> None:
        """Process every packet still queued on the connection, then set :attr:`done`.

        NOTE(review): unlike experiment_running_data() this does not consult
        :attr:`first_skipped`, so if no packet was seen while running, the
        first packet is kept here -- confirm that is intended.
        """
        while (incoming := self.ser.get_data()) is not None:
            self.data_handler(incoming)
        self.done = True

    def data_handler(self, data_input: Tuple[datetime, int, bytes]) -> None:
        """Unpack one packet and append its (converted) values to :attr:`data`.

        Args:
            data_input: ``(timestamp, scan, payload)``; ``payload`` must be
                exactly ``self.struct.size`` bytes long.

        Raises:
            struct.error: if ``payload`` has the wrong length.
            IndexError: if ``data_cols`` has fewer names than payload fields.
        """
        date, scan, payload = data_input
        values = self.struct.unpack(payload)
        self.data['scan'].append(scan)
        self.data['timestamp'].append(date)

        for n, value in enumerate(values):
            convert = self.data_convert[n]
            # Columns without a converter keep the raw integer. Checking for
            # None explicitly (instead of calling None and catching TypeError)
            # avoids raising per value and cannot mask a TypeError raised by a
            # real converter.
            if convert is not None:
                value = convert(value)
            self.data[self.data_cols[n]].append(value)

    def experiment_running_proc(self) -> bool:
        """Poll the connection once for a process/status signal.

        ``"DONE"``, ``"SERIAL_ERROR"`` and ``"ABORT"`` end polling;
        ``"SERIAL_ERROR"`` additionally closes the connection. Any other
        non-None signal is logged as unrecognized and also ends polling.

        Returns:
            ``True`` if no signal was pending (keep polling), otherwise
            ``False``; also ``False`` if the connection raised
            :class:`EOFError` or :class:`IOError` (logged as a warning).
        """
        try:
            proc_buffer = self.ser.get_proc()
            if proc_buffer is None:
                return True

            if proc_buffer == "SERIAL_ERROR":
                self.ser.close()
            elif proc_buffer not in ("DONE", "ABORT"):
                logger.warning("Unrecognized experiment return code: %s",
                               proc_buffer)
            return False

        except EOFError as err:
            logger.warning("EOFError: %s", err)
            return False
        except IOError as err:
            logger.warning("IOError: %s", err)
            return False
