#!/usr/bin/env python
#     DStat Interface - An interface for the open hardware DStat potentiostat
#     Copyright (C) 2014  Michael D. M. Dryden - 
#     Wheeler Microfluidics Laboratory <http://microfluidics.utoronto.ca>
#         
#     
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation, either version 3 of the License, or
#     (at your option) any later version.
#     
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.
#     
#     You should have received a copy of the GNU General Public License
#     along with this program.  If not, see <http://www.gnu.org/licenses/>.

import serial
from serial.tools import list_ports
import time
import struct
import multiprocessing as mp
from collections import OrderedDict
import logging

try:
    import gi
    gi.require_version('Gtk', '3.0')
    from gi.repository import Gtk, GObject
except (ImportError, ValueError):
    # ImportError: PyGObject (gi) itself is missing.
    # ValueError: gi is present but require_version() cannot provide Gtk 3.0.
    # sys is imported here because this failure path is its only user in
    # this part of the module.
    import sys
    print("ERR: GTK not available")
    sys.exit(1)

from errors import InputError, VarError

logger = logging.getLogger("dstat.comm")
dstat_logger = logging.getLogger("dstat.comm.DSTAT")
exp_logger = logging.getLogger("dstat.comm.Experiment")

import state

class AlreadyConnectedError(Exception):
    """Raised by SerialConnection.connect() when already connected."""
    def __init__(self):
        # Only the message is passed on: also passing self would put the
        # exception object into its own args and garble str(exception).
        super(AlreadyConnectedError, self).__init__(
            "Serial instance already connected.")

class NotConnectedError(Exception):
    """Raised when an operation needs a connection that is not open."""
    def __init__(self):
        # Only the message is passed on: also passing self would put the
        # exception object into its own args and garble str(exception).
        super(NotConnectedError, self).__init__(
            "Serial instance not connected.")
            
class ConnectionError(Exception):
    """Raised when the serial worker process could not be brought up.

    NOTE: on Python 3 this name shadows the builtin ConnectionError inside
    this module.
    """
    def __init__(self):
        # Only the message is passed on: also passing self would put the
        # exception object into its own args and garble str(exception).
        super(ConnectionError, self).__init__(
            "Could not connect.")


def _serial_process(ser_port, proc_pipe, ctrl_pipe, data_pipe):
    """Main loop of the worker process that owns the serial port.

    Opens the port, performs the '!' / 'C' handshake and then loops forever,
    servicing control messages and running experiment objects.

    Arguments:
    ser_port -- address of serial port to use
    proc_pipe -- pipe end on which experiment objects arrive; status strings
                 and experiment return codes are sent back on it
    ctrl_pipe -- pipe end carrying control messages ('a' or "DISCONNECT")
    data_pipe -- pipe end handed to each experiment's run() method

    Returns 1 if the port could not be opened and 0 after a "DISCONNECT"
    message; otherwise it does not return.
    """
    ser_logger = logging.getLogger("dstat.comm._serial_process")

    # Open and close the port once before the real connection attempt.
    try:
        ser = delayedSerial(ser_port, baudrate=1000000, timeout=1)
        ser_logger.info("Reattaching DStat udc")
        # ser.write("!R") # Send restart command
        ser.close()
    except serial.SerialException:
        return 1

    # Up to five attempts, one second apart.
    for _ in range(5):
        time.sleep(1)  # Give OS time to enumerate

        try:
            ser = delayedSerial(ser_port, baudrate=1000000, timeout=1)
        except serial.SerialException:
            continue

        ser_logger.info("Connecting")
        break

    # If every attempt failed, ser is still the instance closed above.
    if not ser.isOpen():
        return 1

    ser.write("ck")  # Keep this to support old firmwares

    ser.flushInput()
    ser.write('!')

    # Handshake: resend '!' until the device answers "C" (at most 10 reads).
    for _ in range(10):
        if ser.read() == "C":
            ser.write('V')
            break
        time.sleep(.5)
        ser.write('!')

    while True:
        # These can only be called when no experiment is running
        if ctrl_pipe.poll():
            ctrl_buffer = ctrl_pipe.recv()

            # Membership test: `== ('a' or "DISCONNECT")` would only ever
            # compare against 'a' and make the DISCONNECT branch dead.
            if ctrl_buffer in ('a', "DISCONNECT"):
                proc_pipe.send("ABORT")
                ser.write('a')
                ser_logger.info("ABORT")

                if ctrl_buffer == "DISCONNECT":
                    ser_logger.info("DISCONNECT")
                    ser.close()
                    proc_pipe.send("DISCONNECT")
                    return 0

        elif proc_pipe.poll():
            # Discard control messages queued before this experiment.
            while ctrl_pipe.poll():
                ctrl_pipe.recv()

            return_code = proc_pipe.recv().run(ser, ctrl_pipe, data_pipe)
            ser_logger.info('Return code: %s', str(return_code))

            proc_pipe.send(return_code)

        else:
            time.sleep(.1)
            


class SerialConnection(GObject.Object):
    """Parent-side handle for the serial worker process.

    Runs _serial_process in a child process and talks to it over three
    duplex pipes: proc (experiment objects out, status strings and return
    codes back), ctrl (control messages out) and data (handed to each
    experiment's run() in the worker). Emits the 'connected' and
    'disconnected' signals.
    """
    __gsignals__ = {
        'connected': (GObject.SIGNAL_RUN_FIRST, None, ()),
        'disconnected': (GObject.SIGNAL_RUN_FIRST, None, ())
    }

    def __init__(self):
        super(SerialConnection, self).__init__()
        self.connected = False

    def connect(self, ser_port):
        """Start the worker process for ser_port.

        Returns True on success. Raises AlreadyConnectedError if already
        connected and ConnectionError if the worker is no longer alive half
        a second after being started.

        NOTE(review): this name shadows GObject.Object.connect (signal
        hookup) -- confirm callers attach signal handlers another way.
        """
        if self.connected:
            raise AlreadyConnectedError()

        self.proc_pipe_p, self.proc_pipe_c = mp.Pipe(duplex=True)
        self.ctrl_pipe_p, self.ctrl_pipe_c = mp.Pipe(duplex=True)
        self.data_pipe_p, self.data_pipe_c = mp.Pipe(duplex=True)

        self.proc = mp.Process(target=_serial_process, args=(ser_port,
                                self.proc_pipe_c, self.ctrl_pipe_c,
                                self.data_pipe_c))
        self.proc.start()
        time.sleep(.5)  # Give the worker time to fail on its first open
        if not self.proc.is_alive():
            raise ConnectionError()

        self.connected = True
        self.emit('connected')
        return True

    def assert_connected(self):
        """Raise NotConnectedError unless connected."""
        if not self.connected:
            raise NotConnectedError()

    def start_exp(self, exp):
        """Send the experiment object exp to the worker to be run."""
        self.assert_connected()

        self.proc_pipe_p.send(exp)

    def stop_exp(self):
        """Send the abort control message ('a') to the worker."""
        self.assert_connected()
        self.send_ctrl('a')

    @staticmethod
    def _recv(pipe, block):
        """Return the next object from pipe, or None if nothing is pending
        and block is false."""
        if block or pipe.poll():
            return pipe.recv()
        return None

    def get_proc(self, block=False):
        """Return the next message from the proc pipe.

        With block=False, returns None when nothing is waiting.
        """
        self.assert_connected()

        return self._recv(self.proc_pipe_p, block)

    def get_data(self, block=False):
        """Return the next message from the data pipe.

        With block=False, returns None when nothing is waiting.
        """
        self.assert_connected()

        return self._recv(self.data_pipe_p, block)

    def flush_data(self):
        """Discard everything currently pending on the proc pipe.

        NOTE(review): despite the name, the data pipe is not drained here --
        confirm this is intended.
        """
        self.assert_connected()

        while self.proc_pipe_p.poll():
            self.proc_pipe_p.recv()

    def send_ctrl(self, ctrl):
        """Send the control message ctrl to the worker."""
        self.assert_connected()

        self.ctrl_pipe_p.send(ctrl)

    def disconnect(self):
        """Abort, terminate the worker process and emit 'disconnected'.

        Raises NotConnectedError (via send_ctrl) if not connected.
        """
        self.send_ctrl('a')
        time.sleep(.2)  # Let the worker act on the abort before killing it
        self.proc.terminate()
        self.proc.join(1)  # Reap the child so it does not linger as a zombie
        self.emit('disconnected')
        self.connected = False

class VersionCheck(object):
    """Command object that asks the DStat for its PCB version."""
    def __init__(self):
        pass

    def run(self, ser, ctrl_pipe, data_pipe):
        """Tries to contact DStat and get version.

        On success, sends a tuple of ints (major, minor) on data_pipe and
        returns "DONE". Returns "SERIAL_ERROR", without sending anything,
        if no usable version line was received or a serial error occurred.

        Arguments:
        ser -- open serial instance used to talk to the DStat
        ctrl_pipe -- unused here; part of the common run() signature
        data_pipe -- pipe end on which the version tuple is sent
        """
        version = None  # Stays None if no 'V' line is seen

        try:
            ser.flushInput()
            ser.write('!')

            # Blocks until the device answers "C".
            while not ser.read() == "C":
                ser.flushInput()
                ser.write('!')

            ser.write('V')
            for line in ser:
                if line.startswith('V'):
                    version = line.lstrip('V')
                elif line.startswith("#"):
                    dstat_logger.info(line.lstrip().rstrip())
                elif line.lstrip().startswith("no"):
                    dstat_logger.debug(line.lstrip().rstrip())
                    ser.flushInput()
                    break
        except serial.SerialException as e:
            logger.error('SerialException: %s', e)
            return "SERIAL_ERROR"

        if version is None:
            logger.warning("No version line received")
            return "SERIAL_ERROR"

        version = version.rstrip()
        dstat_logger.info("PCB version: " + str(version))

        parted = version.split('.')
        try:
            result = (int(parted[0]), int(parted[1]))
        except (IndexError, ValueError):
            logger.error('Malformed version string: %r', version)
            return "SERIAL_ERROR"

        data_pipe.send(result)
        return "DONE"

def version_check(ser_port):
    """Connects to the DStat on ser_port and asks for its version.

    A new SerialConnection is created and stored in state.ser, connected,
    and given a VersionCheck to run.

    Returns 1 if the worker reports "SERIAL_ERROR"; otherwise returns the
    next object received on the data pipe (VersionCheck sends a tuple of
    ints (major, minor)).

    Arguments:
    ser_port -- address of serial port to use
    """
    connection = SerialConnection()
    state.ser = connection

    connection.connect(ser_port)
    connection.start_exp(VersionCheck())

    if connection.get_proc(block=True) == "SERIAL_ERROR":
        version = 1
    else:
        version = connection.get_data(block=True)

    logger.debug("version_check done")
    return version

class Settings(object):
    """Command object that reads and/or writes the DStat settings.

    Arguments:
    task -- string containing 'w' to write and/or 'r' to read; when both
            are present the write happens first
    settings -- mapping of setting name to [index, value]; needed for 'w'
    """
    def __init__(self, task, settings=None):
        self.task = task
        self.settings = settings

    def run(self, ser, ctrl_pipe, data_pipe):
        """Performs the requested task(s) over ser.

        When reading, the resulting OrderedDict is sent on data_pipe (it is
        empty if the device sent no settings line). Returns "DONE", or
        "SERIAL_ERROR" if a read produced no settings.
        """
        self.ser = ser
        status = "DONE"

        if 'w' in self.task:
            self.write()

        if 'r' in self.task:
            settings = self.read()
            # Always send something so a caller blocked on the data pipe is
            # released even when the read failed.
            data_pipe.send(settings)
            if not settings:
                status = "SERIAL_ERROR"

        return status

    def _handshake(self):
        """Sends '!' until the device answers "C" (blocks until it does)."""
        self.ser.flushInput()
        self.ser.write('!')

        while not self.ser.read() == "C":
            self.ser.flushInput()
            self.ser.write('!')

    def read(self):
        """Requests the settings ('SR') and parses the reply.

        The reply line starts with 'S' and holds ':'-separated "name.value"
        entries. Returns an OrderedDict of name -> [index, value], where
        index is the entry's position in the reply and value is a string.
        Returns an empty OrderedDict if no settings line was received.
        """
        settings = OrderedDict()
        raw = None  # Stays None if no 'S' line is seen

        self._handshake()

        self.ser.write('SR')
        for line in self.ser:
            stripped = line.lstrip()
            if stripped.startswith('S'):
                raw = stripped.lstrip('S')
            elif stripped.startswith("#"):
                dstat_logger.info(stripped.rstrip())
            elif stripped.startswith("no"):
                dstat_logger.debug(stripped.rstrip())
                self.ser.flushInput()
                break

        if raw is None:
            logger.error("No settings line received")
            return settings

        for index, entry in enumerate(raw.rstrip().split(':')):
            fields = entry.split('.')
            settings[fields[0]] = [index, fields[1]]

        return settings

    def write(self):
        """Sends 'SW' followed by every setting value and a space, ordered
        by the index stored with each setting."""
        self._handshake()

        # A real list is required for item assignment; range() only
        # provides one on Python 2.
        ordered = [None] * len(self.settings)

        for entry in self.settings.values():  # put values in device order
            ordered[entry[0]] = entry[1]

        self.ser.write('SW')
        for value in ordered:
            self.ser.write(value)
            self.ser.write(' ')
        
def read_settings():
    """Reads the settings from the DStat into state.settings.

    Runs a Settings read task on state.ser, stores the object received on
    the data pipe in state.settings and logs the task's status. Returns
    None.
    """
    connection = state.ser

    connection.flush_data()
    connection.start_exp(Settings(task='r'))
    state.settings = connection.get_data(block=True)

    status = connection.get_proc(block=True)
    logger.debug("read_settings: %s", status)
    
def write_settings():
    """Writes state.settings to the DStat.

    Runs a Settings write task on state.ser and logs the task's status.
    Returns None.
    """
    connection = state.ser

    connection.flush_data()
    connection.start_exp(Settings(task='w', settings=state.settings))

    status = connection.get_proc(block=True)
    logger.debug("write_settings: %s", status)
    
class LightSensor(object):
    """Command object that asks the DStat for a light sensor reading."""
    def __init__(self):
        pass

    def run(self, ser, ctrl_pipe, data_pipe):
        """Tries to contact DStat and get light sensor reading.

        On success, sends the part of the reading before the first '.' on
        data_pipe (as a string) and returns "DONE". If no reading line was
        received, sends None on data_pipe, so a caller blocked on it is
        released, and returns "SERIAL_ERROR".

        Arguments:
        ser -- open serial instance used to talk to the DStat
        ctrl_pipe -- unused here; part of the common run() signature
        data_pipe -- pipe end on which the reading is sent
        """
        reading = None  # Stays None if no 'T' line is seen

        ser.flushInput()
        ser.write('!')

        # Blocks until the device answers "C". The serial instance is only
        # available as the ser argument; it is not stored on self.
        while not ser.read() == "C":
            ser.flushInput()
            ser.write('!')

        ser.write('T')
        for line in ser:
            stripped = line.lstrip()
            if stripped.startswith('T'):
                reading = stripped.lstrip('T')
            elif stripped.startswith("#"):
                dstat_logger.info(stripped.rstrip())
            elif stripped.startswith("no"):
                dstat_logger.debug(stripped.rstrip())
                ser.flushInput()
                break

        if reading is None:
            logger.error("No light sensor reading received")
            data_pipe.send(None)
            return "SERIAL_ERROR"

        data_pipe.send(reading.rstrip().split('.')[0])
        return "DONE"

def read_light_sensor():
    """Asks the DStat for a light sensor reading.

    Runs a LightSensor task on state.ser, logs the task's status and
    returns the next object received on the data pipe.
    """
    connection = state.ser

    connection.flush_data()
    connection.start_exp(LightSensor())

    status = connection.get_proc(block=True)
    logger.debug("read_light_sensor: %s", status)

    reading = connection.get_data(block=True)
    return reading
    

class delayedSerial(serial.Serial):
    """serial.Serial whose write() sends its data one element at a time,
    pausing for a millisecond after each.
    """
    def write(self, data):
        """Write each element of data separately via serial.Serial.write,
        sleeping 1 ms after every element. Returns None.
        """
        send_one = serial.Serial.write
        for character in data:
            send_one(self, character)
            time.sleep(.001)

class SerialDevices(object):
    """Retrieves and stores list of serial devices in self.ports"""
    def __init__(self):
        self.refresh()

    def refresh(self):
        """Refreshes list of ports.

        self.ports becomes a tuple holding the first field of every entry
        returned by list_ports.comports(), or an empty list (with an error
        logged) when no ports are found.
        """
        try:
            self.ports, _, _ = zip(*list_ports.comports())
        except ValueError:
            # zip(*[]) yields nothing to unpack when no ports are present.
            self.ports = []
            logger.error("No serial ports found")