import logging
import pickle
from math import ceil

import pynng
import trio

from dstat_interface.core.dstat.comm import SerialDevices, dstat_connect
from dstat_interface.core.dstat.comm import state as dstat_state
from dstat_interface.core.experiments.experiment_container import ExperimentContainer
from dstat_interface.core.experiments.lsv import LSVExperimentContainer, CVExperimentContainer
from dstat_interface.core.experiments.swv import SWVExperimentContainer, DPVExperimentContainer
from dstat_interface.core.experiments.chronoamp import CAExperimentContainer
from dstat_interface.core.tasks import ExperimentBaseTasks
from dstat_interface.utils import runtime_dir

logger = logging.getLogger(__name__)

# Experiment container classes the daemon knows how to run.
e = [
    LSVExperimentContainer,
    CVExperimentContainer,
    SWVExperimentContainer,
    DPVExperimentContainer,
    CAExperimentContainer,
]
# Registry mapping each container's ``experiment_id`` to the container class itself.
experiments: dict[str, type[ExperimentContainer]] = {
    container_cls.experiment_id: container_cls for container_cls in e
}


class DaemonExperimentTasks(ExperimentBaseTasks):
    """Experiment tasks that stream progress and data over pynng sockets.

    Adds two tasks to those set up by :class:`ExperimentBaseTasks`:
    :meth:`send_data`, which publishes newly acquired data and cancels the
    task group when the experiment handler is done, and
    :meth:`check_interrupts`, which stops the experiment on an ``ABORT``
    message. :meth:`update_progress` publishes per-scan progress.
    """

    def __init__(self, exp_con: ExperimentContainer, socket: pynng.Pub0, socket_int: pynng.Bus0):
        """
        :param exp_con: Container of the experiment being run.
        :param socket: Publisher socket for ``progress:`` and ``data:`` messages.
        :param socket_int: Bus socket on which interrupt messages (``b'ABORT'``) are received.
        """
        super().__init__(exp_con)
        self.socket = socket
        self.socket_int = socket_int
        # Number of data points already published; the next message starts here.
        self.data_index = 0
        self.tasks += [self.send_data, self.check_interrupts]

    async def update_progress(self, cancel_scope: trio.CancelScope):
        """Publish scan progress as ``b'progress: <scan>, <percent>'`` messages.

        Progress is polled every 0.2 s and only published when it increases.
        When ``get_progress()`` raises StopIteration, or progress reaches 100,
        the current scan is finished and polling restarts for the scan reported
        by ``progress_scan``. This task never returns on its own; it is
        presumably stopped by cancellation of the enclosing task group.
        """
        while True:
            scan = self.exp_con.progress_scan
            last_progress = 0

            while True:
                await trio.sleep(0.2)
                try:
                    progress = ceil(self.exp_con.get_progress())

                # End of Scan
                except StopIteration:
                    await self.socket.asend(f'progress: {scan}, 100'.encode('utf-8'))
                    break

                if progress > last_progress:
                    await self.socket.asend(f'progress: {scan}, {progress}'.encode('utf-8'))
                    last_progress = progress
                if progress >= 100:
                    break

    async def send_data(self, cancel_scope: trio.CancelScope):
        """Publish newly acquired data every 0.2 s until the handler is done.

        Each message is ``b'data:'`` followed by a pickled dict mapping every
        key of ``handler_instance.data`` to the values added since the previous
        message. Once the handler reports ``done`` and its final data has been
        sent, ``cancel_scope`` is cancelled and this task returns.
        """
        while True:
            await trio.sleep(0.2)
            handler = self.exp_con.handler_instance
            # Sample ``done`` BEFORE measuring the data: if it is already set,
            # every point the handler produced is included in this final send.
            done = handler.done
            new_index = len(handler.data['timestamp'])
            if new_index != 0:
                data = {key: value[self.data_index:new_index] for key, value in handler.data.items()}
                await self.socket.asend(b'data:' + pickle.dumps(data))
                self.data_index = new_index
            # Checked even when no data has arrived yet, so an experiment that
            # ends without producing any points still ends the task group.
            if done:
                cancel_scope.cancel()
                return

    async def check_interrupts(self, cancel_scope: trio.CancelScope):
        """Stop the running experiment whenever ``b'ABORT'`` arrives on the interrupt socket.

        Other messages are ignored. This task loops until it is cancelled.
        """
        while True:
            msg = await self.socket_int.arecv()
            if msg == b'ABORT':
                logger.info('ABORT received; stopping experiment')
                dstat_state.ser.stop_exp()


class DStatDaemon(object):
    """Daemon exposing DStat control and experiment streaming over pynng sockets.

    Three sockets are opened on construction:

    * ``socket_pub`` (Pub0): experiment progress and data are published here.
    * ``socket_ctrl`` (Rep0): answers pickled ``(command, payload)`` requests;
      the supported commands are the keys of ``command_map``.
    * ``socket_interrupt`` (Bus0): passed to :class:`DaemonExperimentTasks`
      while an experiment runs, to receive interrupt messages.

    Any listen address left as ``None`` defaults to an ``ipc://`` path under
    ``runtime_dir()``.
    """

    def __init__(self, listen_addr_pub=None, listen_addr_ctrl=None, listen_addr_interrupt=None):
        """Open the publisher, control and interrupt sockets.

        :param listen_addr_pub: pynng listen address for the publisher socket.
        :param listen_addr_ctrl: pynng listen address for the control socket.
        :param listen_addr_interrupt: pynng listen address for the interrupt socket.
        """
        # Explicitly supplied addresses are used as-is; only missing ones get defaults.
        if listen_addr_pub is None:
            listen_addr_pub = 'ipc://' + runtime_dir() + '/dstat-interfaced-pub'
        if listen_addr_ctrl is None:
            listen_addr_ctrl = 'ipc://' + runtime_dir() + '/dstat-interfaced-ctrl'
        if listen_addr_interrupt is None:
            listen_addr_interrupt = 'ipc://' + runtime_dir() + '/dstat-interfaced-int'
        self.address_pub = listen_addr_pub
        self.address_ctrl = listen_addr_ctrl
        self.address_int = listen_addr_interrupt

        self.socket_pub = pynng.Pub0(listen=self.address_pub)
        self.socket_ctrl = pynng.Rep0(listen=self.address_ctrl)
        self.socket_interrupt = pynng.Bus0(listen=self.address_int)

        self.device_list = SerialDevices()
        self.connected = False

        self.command_map = {'fetch_devices': self.fetch_devices,
                            'connect_dstat': self.connect_dstat,
                            'disconnect_dstat': self.disconnect_dstat,
                            'fetch_experiments': self.fetch_experiments,
                            'run_experiment': self.run_experiment}

    async def fetch_devices(self, payload: dict = None) -> dict:
        """Refresh the serial device list and return ``{'ports': ...}``."""
        self.device_list.refresh()
        return {'ports': self.device_list.ports}

    async def connect_dstat(self, payload: dict = None) -> dict:
        """Connect to the DStat on ``payload['port']`` unless already connected.

        :returns: Connection status, DStat/firmware versions, board instance and
            a human-readable ``msg``.
        """
        if not self.connected:
            self.connected = dstat_connect(payload['port'])
            msg = f'DStat connected: {self.connected}'
        else:
            msg = "DStat already connected"
        return {'connected': self.connected,
                'dstat_version': dstat_state.dstat_version,
                'firmware_version': dstat_state.firmware_version,
                'board': dstat_state.board_instance,
                'msg': msg}

    async def disconnect_dstat(self, payload: dict = None):
        """Close the serial connection if open; always reports ``connected: False``."""
        if self.connected:
            dstat_state.ser.close()
            self.connected = False
            return {'connected': False, 'msg': 'DStat disconnected'}
        else:
            return {'connected': False, 'msg': 'DStat already disconnected'}

    async def fetch_experiments(self, payload: dict = None):
        """Return the registered experiment container classes."""
        return {'experiments': list(experiments.values())}

    async def run_experiment(self, exp_dict: dict):
        """Run an experiment to completion while streaming its progress and data.

        :param exp_dict: Must contain ``'experiment_id'`` (a key of
            ``experiments``); the remaining entries are passed to the
            experiment container's constructor. The dict is not modified.
        :returns: ``{'status': 'done'}`` once the experiment task loop finishes.
        """
        # Work on a copy so the caller's dict keeps its 'experiment_id'.
        params = dict(exp_dict)
        exp = params.pop('experiment_id')

        experiment_container = experiments[exp](params, mux=dstat_state.board_instance.channels)
        dstat_state.ser.start_exp(experiment_container.get_proc())
        experiment_container.start_handler(dstat_state.ser)
        tasks = DaemonExperimentTasks(experiment_container, self.socket_pub, self.socket_interrupt)
        await tasks.loop()
        return {'status': 'done'}

    async def run(self):
        """Serve control requests forever.

        Each request is a pickled ``(command, payload)`` tuple. The matching
        handler from ``command_map`` is awaited and its result is sent back
        pickled. If decoding, dispatch, the handler or pickling the reply
        fails, the error is logged and ``{'error': '<Type>: <message>'}`` is
        sent instead. The Rep0 socket therefore always answers and the daemon
        keeps running.
        """
        while True:
            raw = await self.socket_ctrl.arecv()
            try:
                # SECURITY: pickle.loads can execute arbitrary code embedded in the
                # request; only trusted local clients must be able to reach this socket.
                command: str
                payload: dict
                command, payload = pickle.loads(raw)
                reply = pickle.dumps(await self.command_map[command](payload))
            except Exception as exc:
                logger.exception('Control request failed')
                reply = pickle.dumps({'error': f'{type(exc).__name__}: {exc}'})
            await self.socket_ctrl.asend(reply)

    def close(self):
        """Close every socket this daemon opened.

        Safe to call when construction failed part-way (e.g. from ``__del__``);
        sockets that were never created are skipped.
        """
        for name in ('socket_pub', 'socket_ctrl', 'socket_interrupt'):
            sock = getattr(self, name, None)
            if sock is not None:
                sock.close()

    def __del__(self):
        self.close()


async def main():
    """Create a :class:`DStatDaemon` and serve control requests until cancelled.

    The daemon's sockets are closed on exit, whether ``run`` ends normally,
    raises, or is interrupted.
    """
    daemon = DStatDaemon()
    try:
        async with trio.open_nursery() as nursery:
            nursery.start_soon(daemon.run)
    finally:
        # Release the sockets deterministically instead of relying on __del__.
        daemon.close()

if __name__ == "__main__":
    # Script entry point: drive the daemon's trio event loop.
    trio.run(main)
