#!/usr/bin/env python
#     DStat Interface - An interface for the open hardware DStat potentiostat
#     Copyright (C) 2014  Michael D. M. Dryden - 
#     Wheeler Microfluidics Laboratory <http://microfluidics.utoronto.ca>
#         
#     
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation, either version 3 of the License, or
#     (at your option) any later version.
#     
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.
#     
#     You should have received a copy of the GNU General Public License
#     along with this program.  If not, see <http://www.gnu.org/licenses/>.

import logging
import multiprocessing as mp
import time
from datetime import datetime
from typing import Union, Tuple

import serial
from pkg_resources import parse_version
from serial.tools import list_ports

from dstat_interface.core.dstat import boards
from dstat_interface.core.dstat.state import DStatState
from .utility import SettingsProcess, VersionCheckProcess
from .experiment_process import BaseExperimentProcess
from .simulator import SerialSim

logger = logging.getLogger(__name__)
# Child loggers "<module>.DSTAT" and "<module>.Experiment".
dstat_logger = logger.getChild('DSTAT')
exp_logger = logger.getChild('Experiment')

# Module-wide state populated by dstat_connect(), read_settings() and
# write_settings() (serial connection, versions, board, settings).
state = DStatState()


class AlreadyConnectedError(Exception):
    """Raised when opening a SerialConnection that is already connected."""

    def __init__(self):
        message = 'Serial instance already connected.'
        super().__init__(message)


class NotConnectedError(Exception):
    """Raised when a SerialConnection operation is attempted before open()."""

    def __init__(self):
        message = 'Serial instance not connected.'
        super().__init__(message)


class TransmitError(Exception):
    """Raised when the DStat never acknowledges the connection handshake."""

    def __init__(self):
        message = 'No reply received.'
        super().__init__(message)


def _serial_open(ser_port, ser_logger, attempts=5):
    """Open ``ser_port`` (or a SerialSim for 'simulator'), retrying up to ``attempts`` times."""
    exc = None
    for _ in range(attempts):
        time.sleep(1)  # Give OS time to enumerate
        try:
            if ser_port == 'simulator':
                ser = SerialSim(ser_port, timeout=1)
            else:
                ser = serial.Serial(ser_port, timeout=1)
        except serial.SerialException as e:
            exc = e
        else:
            ser_logger.info('Connecting')
            return ser
    if exc:
        raise exc
    raise serial.SerialException(f'Could not connect on {ser_port}')


def _serial_handshake(ser, poll_s, attempts=10):
    """Send ``!0`` until the device replies ``@ACK 0`` then ``@RCV 0``; raise TransmitError otherwise."""
    ser.write(b'!0 ')
    for _ in range(attempts):
        if ser.readline().rstrip() == b'@ACK 0':
            if ser.readline().rstrip() == b'@RCV 0':
                return
        else:
            # No ACK: clear whatever arrived and resend the handshake.
            time.sleep(.5)
            ser.reset_input_buffer()
            ser.write(b'!0 ')
            time.sleep(poll_s)
    raise TransmitError()


def _serial_shutdown(ser):
    """Deassert RTS/DTR and close ``ser``."""
    ser.rts = False
    ser._update_dtr_state()  # Need DTR update on Windows
    ser.close()


def _serial_process(ser_port: str, proc_pipe: mp.connection.Connection, ctrl_pipe: mp.connection.Connection,
                    data_pipe: mp.connection.Connection, poll_s: float = .1):
    """Own the serial port in a child process and service requests from the parent.

    Opens the port, performs the ``!0`` handshake, then loops:
    control strings from ``ctrl_pipe`` are written to the device ('a' and
    "DISCONNECT" send an abort; "DISCONNECT" also closes the port and
    returns), and objects from ``proc_pipe`` are executed with
    ``obj.run(ser, ctrl_pipe, data_pipe)`` and their return code is sent back
    on ``proc_pipe``.

    If the port cannot be opened or the handshake fails, 'SERIAL_ERROR' is
    sent on ``proc_pipe`` before re-raising, so a parent blocked in
    ``get_proc(block=True)`` is released instead of hanging.

    Arguments:
    ser_port -- serial port address, or 'simulator'
    proc_pipe -- child end of the process pipe (experiments in, results out)
    ctrl_pipe -- child end of the control pipe
    data_pipe -- child end of the data pipe, passed to each experiment
    poll_s -- idle polling interval in seconds

    Raises:
    serial.SerialException -- port could not be opened, or failed during handshake
    TransmitError -- device never acknowledged the handshake
    """
    ser_logger = logging.getLogger("{}._serial_process".format(__name__))

    try:
        ser = _serial_open(ser_port, ser_logger)
    except serial.SerialException:
        proc_pipe.send('SERIAL_ERROR')
        raise

    try:
        _serial_handshake(ser, poll_s)
    except (TransmitError, serial.SerialException):
        # Release the port instead of leaking it when the device does not answer.
        ser.close()
        proc_pipe.send('SERIAL_ERROR')
        raise

    while True:
        # These can only be called when no experiment is running, otherwise ctrl_pipe passes to proc
        if ctrl_pipe.poll():
            ctrl_buffer: str = ctrl_pipe.recv()
            if ctrl_buffer in ('a', "DISCONNECT"):
                proc_pipe.send("ABORT")
                try:
                    ser.write(b'a')
                except serial.SerialException:
                    # Port is gone: report it like a failed experiment does, and release it.
                    proc_pipe.send("DISCONNECT")
                    ser.close()
                    return 0
                ser_logger.info("ABORT")

                if ctrl_buffer == "DISCONNECT":
                    ser_logger.info("DISCONNECT")
                    _serial_shutdown(ser)
                    proc_pipe.send("DISCONNECT")
                    return 0
            else:
                ser.write(ctrl_buffer.encode('ascii'))

        elif proc_pipe.poll():
            # Discard control messages queued before this experiment so it does not see stale commands.
            while ctrl_pipe.poll():
                ctrl_pipe.recv()

            try:
                return_code = proc_pipe.recv().run(ser, ctrl_pipe, data_pipe)
            except serial.SerialException:
                proc_pipe.send("DISCONNECT")
                _serial_shutdown(ser)
                return 0
            ser_logger.info('Return code: %s', str(return_code))

            proc_pipe.send(return_code)

        else:
            time.sleep(poll_s)


class SerialConnection(object):
    """Parent-side handle on the child process that owns the serial port.

    The child runs ``_serial_process``. The two sides talk over three duplex
    pipes; the ``*_p`` ends stay in this process and the ``*_c`` ends are
    handed to the child:

    * proc -- experiment objects out, return codes / status strings back
    * ctrl -- control strings such as 'a' (abort) and "DISCONNECT"
    * data -- items produced by running experiments
    """

    def __init__(self):
        self.connected = False
        self.proc_pipe_p, self.proc_pipe_c = mp.Pipe(duplex=True)
        self.ctrl_pipe_p, self.ctrl_pipe_c = mp.Pipe(duplex=True)
        self.data_pipe_p, self.data_pipe_c = mp.Pipe(duplex=True)
        self.proc: Union[None, mp.Process] = None

    def open(self, ser_port: str):
        """Start the serial child process for ``ser_port`` using fresh pipes.

        Returns True. Raises AlreadyConnectedError if already connected and
        ConnectionError if the child is not alive right after start().
        """
        if self.connected:
            raise AlreadyConnectedError()

        self.proc_pipe_p, self.proc_pipe_c = mp.Pipe(duplex=True)
        self.ctrl_pipe_p, self.ctrl_pipe_c = mp.Pipe(duplex=True)
        self.data_pipe_p, self.data_pipe_c = mp.Pipe(duplex=True)

        child_args = (ser_port, self.proc_pipe_c, self.ctrl_pipe_c, self.data_pipe_c)
        self.proc = mp.Process(target=_serial_process, args=child_args)
        self.proc.start()
        if not self.proc.is_alive():
            raise ConnectionError()
        self.connected = True
        return True

    def assert_connected(self):
        """Raise NotConnectedError unless open() has succeeded."""
        if not self.connected:
            raise NotConnectedError()

    def start_exp(self, exp: BaseExperimentProcess):
        """Send ``exp`` to the child, which calls its ``run`` method."""
        self.assert_connected()
        self.proc_pipe_p.send(exp)

    def stop_exp(self):
        """Request an abort by sending 'a' on the control pipe."""
        self.send_ctrl('a')

    @staticmethod
    def _recv(pipe, block):
        """Receive from ``pipe``; when not blocking, return None if nothing is waiting."""
        if block or pipe.poll():
            return pipe.recv()
        return None

    def get_proc(self, block=False):
        """Next message from the process pipe (None if empty and ``block`` is false)."""
        self.assert_connected()
        return self._recv(self.proc_pipe_p, block)

    def get_ctrl(self, block=False):
        """Next message from the control pipe (None if empty and ``block`` is false)."""
        self.assert_connected()
        return self._recv(self.ctrl_pipe_p, block)

    def get_data(self, block=False) -> Union[None, Tuple[datetime, int, bytes]]:
        """Next item from the data pipe (None if empty and ``block`` is false)."""
        self.assert_connected()
        return self._recv(self.data_pipe_p, block)

    def flush_data(self):
        """Discard everything currently waiting on the data pipe."""
        self.assert_connected()
        while self.data_pipe_p.poll():
            self.data_pipe_p.recv()

    def send_ctrl(self, ctrl: str):
        """Send a control string to the child."""
        self.assert_connected()
        self.ctrl_pipe_p.send(ctrl)

    def close(self):
        """Send "DISCONNECT" and wait for the child process to exit."""
        logger.info('Disconnecting')
        self.send_ctrl('DISCONNECT')
        self.proc.join()
        self.connected = False


def dstat_connect(ser_port) -> bool:
    """Tries to contact DStat and get version. Stores version in state.
    If no response, returns False, otherwise True.

    A new SerialConnection is opened in ``state.ser`` and left there whatever
    the outcome. A VersionCheckProcess is run and two replies are read:
    'V' (hardware and firmware version) and 'X' (mux channel reply).

    On success, ``state.dstat_version`` and ``state.firmware_version`` are
    set. ``state.board_instance`` is also set when ``boards.find_board``
    returns a board class. Every ``False`` return resets both version fields
    to None, so they never describe a device that failed to respond.

    Arguments:
    ser_port -- address of serial port to use
    """
    def fail(message=None):
        # Single failure exit so all failure paths leave the state consistent.
        if message is not None:
            logger.error(message)
        state.dstat_version = None
        state.firmware_version = None
        return False

    state.ser = SerialConnection()

    state.ser.open(ser_port)
    state.ser.start_exp(VersionCheckProcess())
    result = state.ser.get_proc(block=True)
    if result == 'SERIAL_ERROR':
        return fail()

    cmd, buffer = state.ser.get_data(block=True)
    if cmd != 'V':
        return fail('DStat did not respond correctly')
    version, state.firmware_version = buffer
    state.dstat_version = parse_version(version)

    cmd, buffer = state.ser.get_data(block=True)
    if cmd != 'X':
        return fail('DStat did not reply to mux channel request')

    # NOTE(review): presumably the 'X' reply is the mux channel count; >1 means a mux is fitted.
    mux = buffer > 1

    board = boards.find_board(state.dstat_version, mux=mux)
    if board:
        state.board_instance = board()
    else:
        # state.board_instance is left unchanged.
        logger.warning('No board definition found for DStat version %s (mux=%s)',
                       state.dstat_version, mux)
    logger.debug('version_check done')
    time.sleep(.1)

    return True


def read_settings():
    """Tries to contact DStat and get settings. Returns dict of
    settings.

    Requires ``state.ser`` to be connected (see dstat_connect). Stale data is
    flushed first, then a read-mode SettingsProcess is run. The settings item
    it produces is stored in ``state.settings`` and also returned. Blocks
    until the process has sent both its data and its return code.
    """
    # Drop stale items so the next data item is the settings reply.
    state.ser.flush_data()
    state.ser.start_exp(SettingsProcess(task='r'))
    state.settings = state.ser.get_data(block=True)
    result = state.ser.get_proc(block=True)

    logger.info("Read settings from DStat")
    logger.debug("read_settings: %s", result)
    return state.settings


def write_settings():
    """Tries to write settings to DStat from global settings var.

    Sends ``state.settings`` via a write-mode SettingsProcess over
    ``state.ser`` (which must be connected). Blocks until the process reports
    its return code. Completion is logged only after that, and the code
    itself is logged at debug level.
    """
    logger.debug("Settings to write: %s", state.settings)

    state.ser.flush_data()
    state.ser.start_exp(SettingsProcess(task='w', settings=state.settings))
    result = state.ser.get_proc(block=True)
    logger.info("Wrote settings to DStat")
    logger.debug("write_settings: %s", result)


class SerialDevices(object):
    """Retrieves and stores list of serial devices in self.ports"""

    def __init__(self):
        self.ports = []
        self.refresh()

    def refresh(self):
        """Refresh ``self.ports`` with the device names of ports matching "DSTAT".

        ``self.ports`` is always a list of device names (empty when nothing
        matches), so its type does not change between refreshes.
        """
        self.ports = [port_info.device for port_info in list_ports.grep("DSTAT")]
        if not self.ports:
            logger.error("No serial ports found")


if __name__ == '__main__':
    root_logger = logging.getLogger()
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    log_handlers = [logging.StreamHandler()]
    # %-style fields need the 's' conversion; '%(name)u' raises TypeError on str values.
    log_formatter = logging.Formatter(
        fmt='%(asctime)s %(levelname)s: [%(name)s] %(message)s',
        datefmt='%H:%M:%S',
    )
    for handler in log_handlers:
        handler.setFormatter(log_formatter)
        root_logger.addHandler(handler)
    root_logger.setLevel(logging.DEBUG)

    ports = SerialDevices()

    # Wait for a DStat port to appear. This is kept separate from the connection
    # attempt so errors raised while talking to the device are not mistaken for
    # "no device found".
    while not ports.ports:
        logger.info('No DStat Found')
        time.sleep(5)
        ports.refresh()

    try:
        # To test without hardware, pass 'simulator' as the port instead.
        if dstat_connect(ports.ports[0]):
            logger.info('DStat ver: %s Firmware ver: %s', state.dstat_version, state.firmware_version)
            read_settings()
            logger.info('Settings:\n\t%s', '\n\t'.join([f'{key}: {value}' for key, value in state.settings.items()]))
        else:
            logger.error('Could not connect to DStat on %s', ports.ports[0])
    finally:
        ser = getattr(state, 'ser', None)
        if ser is not None and ser.connected:
            ser.close()
