"""This module contains code for collecting experimental data.
It uses the pygame module to give us a reasonably nice GUI.

This software is copyright Greg Kochanski (2010) and is
available under the Lesser Gnu Public License, version 3 or higher.
It was funded by the UK's Economic and Social Research
Council under project RES-062-23-1323.  This is available from
http://sourceforge.org/projects/speechresearch,
http://kochanski.org/gpk/papers/2010/aesop_data_collect, and
http://www.phon.ox.ac.uk/files/releases/2008aesopus2_data_collect.tar
"""

# import os
# import datetime
import math
import wave
import numpy
from gmisclib import die
from gmisclib import gpkmisc
from gmisclib import Numeric_gpk as NG

import pygtk
pygtk.require('2.0')
import gtk
# import cairo
import pango

class ProgressBar(gtk.DrawingArea):
	"""This is a general-purpose progress bar for GTK that
	extends horizontally.

	The completed fraction (L{set_fraction}) is painted in L{RGB1} from
	the left edge; the rest of the bar is painted in L{RGB2}.
	"""
	# __gsignals__ = {"expose-event": "override"}
	
	RGB1 = (0.1, 0.2, 0.7)
	RGB2 = (0.9, 0.4, 0.1)
	
	def __init__(self):
		gtk.DrawingArea.__init__(self)
		# Not read anywhere in this class; kept for compatibility.
		self.fraction = None
		# GObject treats "expose-event" and "expose_event" as the same
		# signal; connecting both names registered two handlers and drew
		# the bar twice per expose, so connect only once.
		self.connect("expose-event", self.do_expose_event, None)
		# Not read anywhere in this class; kept for compatibility.
		self.gc1 = None
		self.gc2 = None
		# Completed fraction used by draw(); normally between 0 and 1.
		self.f = 0.0


	def set_fraction(self, f):
		"""Set the completed fraction and redraw at once if the widget
		has a window.
		@param f: completed fraction, normally between 0 and 1;
			None leaves the bar unchanged.
		"""
		if f is None:
			return
		f = float(f)
		# Check before storing so a rejected value does not linger in self.f.
		assert 0.0 <= f
		self.f = f
		if self.window is not None:
			cr = self.window.cairo_create()
			self.draw(cr, *self.window.get_size())


	def do_expose_event(self, w, event, d):
		"""Handler for "expose-event": repaint, clipped to the exposed area."""
		cr = self.window.cairo_create()
		cr.rectangle(event.area.x, event.area.y, event.area.width, event.area.height)
		cr.clip()
		self.draw(cr, *self.window.get_size())


	def draw(self, cr, width, height):
		"""Paint the bar onto the cairo context C{cr}.
		@param width: width of the drawing area in pixels.
		@param height: height of the drawing area in pixels.
		@return: False
		"""
		cr.set_source_rgb(0.5, 0.5, 0.5)
		cr.rectangle(0, 0, width, height)
		cr.fill()
		
		cr.set_source_rgb(*self.RGB1)
		iw = int(round(width*self.f))
		cr.rectangle(0, 0, iw, height)
		cr.fill()
		
		cr.set_source_rgb(*self.RGB2)
		# The remainder spans from iw to the right edge, i.e. width-iw wide.
		cr.rectangle(iw, 0, max(width - iw, 0), height)
		cr.fill()
		return False


def drop_blanks(s):
	"""Return a new list holding the items of C{s} whose length is nonzero."""
	return list(filter(len, s))



class experiment_base(object):
	"""This class is intended to be the basis of a finite-state machine
	description of the experiment.
	It helps you iterate through a list of stimuli with L{get_current_stimulus}()
	and L{next_stimulus}(), it manages state transitions with L{event}(),
	and displays prompts on the screen.

	A state is a callable C{handler(e, log)}; the initial state is
	L{S_initial}, which subclasses must override.  A handler returns the
	next state (another handler), C{True} to remain in the current state,
	C{None} (treated like C{True}), or C{False} to shut the GUI down.
	"""

	def __init__(self, stimlist, hdr, gui):
		"""@param stimlist: the stimuli; zero-length entries are dropped.
		Each stimulus is indexed by key and merged over the header.
		@param hdr: dict-like header of values shared by all stimuli.
		@param gui: GUI object; this class calls its C{progress()} and
			C{destroy()} methods.
		"""
		self.stimlist = drop_blanks(stimlist)
		self.current_handler = self.S_initial
		self.last_handler = None
		# -1 means "before the first stimulus"; next_stimulus() advances it.
		self.i = -1
		self.hdr = hdr
		self.gui = gui


	def close(self):
		"""Hook for subclasses; does nothing here."""
		pass


	def get_hdrs(self):
		"""Return a copy of the header."""
		return self.hdr.copy()


	def _has_current(self):
		"""True if self.i points at an actual stimulus."""
		return 0 <= self.i < len(self.stimlist)


	def get_current_stimulus(self):
		"""Return a copy of the header updated with the current stimulus,
		or None when there is no current stimulus (before the first call
		to L{next_stimulus}, or after the last stimulus).
		"""
		# A negative index used to wrap around to the last stimulus.
		if not self._has_current():
			return None
		tmp = self.hdr.copy()
		tmp.update(self.stimlist[self.i])
		return tmp


	def get(self, *kd):
		"""Look up a value for the current stimulus, falling back to the header.
		C{get(key)} warns and raises KeyError if C{key} is found in neither;
		C{get(key, default)} returns C{default} instead.
		When there is no current stimulus, only the header is consulted.
		"""
		key = kd[0]
		if self._has_current():
			try:
				return self.stimlist[self.i][key]
			except KeyError:
				pass
		if len(kd) == 1:
			try:
				return self.hdr[key]
			except KeyError:
				die.warn('Please define "%s" in the header or as a column in the input file.' % key)
				raise
		return self.hdr.get(key, kd[1])


	def next_stimulus(self):
		"""Advance to the next stimulus and update the GUI's progress bar.
		@return: True if there is a stimulus at the new position.
		"""
		self.i += 1
		self.gui.progress(self.i, len(self.stimlist))
		return self.i < len(self.stimlist)


	def is_last_stimulus(self):
		"""True if the current stimulus is the last one (or beyond it)."""
		return self.i >= len(self.stimlist)-1


	def event(self, widget, ev, log):
		"""Feed a GUI event to the state machine.
		For a key press the state handler receives the key's string;
		otherwise it receives C{widget}.  Whenever a handler returns another
		handler, that becomes the current state and is called immediately
		with C{None} as its event.
		@return: False after asking the GUI to shut down; otherwise True.
		"""
		handler = self.current_handler
		if isinstance(ev, gtk.gdk.Event) and ev.type == gtk.gdk.KEY_PRESS:
			e = ev.string
		else:
			e = widget
		try:
			while True:
				handler = handler(e, log)
				if handler is False:
					self.gui.destroy(None, ev)
					return handler
				elif handler is True:
					return handler
				elif handler is None:
					return True
				# A new state: make it current and run it at once, with no event.
				# _needs_show is presumably read by subclasses to trigger a
				# redisplay -- NOTE(review): confirm against subclasses.
				self._needs_show = True
				self.current_handler = handler
				e = None
		except:
			# Deliberately bare: any failure inside a state handler is
			# reported, the GUI is shut down, and die.die() is called.
			die.catch("ERROR")
			self.gui.destroy(None)
			die.die("ERROR")
		return True
	

	def S_initial(self, ev, log):
		"""Initial state; subclasses must override this."""
		# NotImplementedError subclasses RuntimeError, so existing handlers still match.
		raise NotImplementedError("Virtual Function")
	
	
	def first_entry(self):
		"""Return True the first time it is called in a new state, else False."""
		if self.last_handler is self.current_handler:
			return False
		self.last_handler = self.current_handler
		return True




class GUI_base(object):
	"""Base class for the data-collection graphical user interface.

	The window holds, top to bottom: a row with a status bar, a progress
	bar, a peak-amplitude display and two buttons; an instruction area;
	and a stimulus area.
	"""

	def delete_event(self, widget, event, data=None):
		"""Handler for the window's "delete-event".
		Returning False lets GTK go on to destroy the window.
		"""
		return False

	def destroy(self, widget, data=None):
		"""Shut down the GUI."""
		try:
			gtk.main_quit()
		except RuntimeError as x:
			die.die('RuntimeError "%s" while shutting down.' % str(x))
	
	def connect_experiment(self,  expcall,  log):
		"""Connect the class that describes the experiment to the GUI.
		@param expcall: called as C{expcall(widget, event, log)} for every
			key press, and once immediately as C{expcall(None, None, log)}
			to start the experiment.
		@param log: passed through to every call of C{expcall}.
		"""
		# "key_press_event" and "key-press-event" name the same GTK signal;
		# connecting both registered expcall twice, so a key press could be
		# processed twice (e.g. whenever the first call returned False).
		self.window.connect("key-press-event", expcall,  log)
		expcall(None, None, log)
	

	def make_TextView(self, extra_line_space, extra_para_space, font_name):
		"""Create a read-only, word-wrapped, left-justified text view.
		@param extra_line_space: pixels between wrapped lines.
		@param extra_para_space: pixels above and below each paragraph.
		@param font_name: Pango font description string, e.g. "Serif 24".
		"""
		box = gtk.TextView()
		box.set_editable(False)
		box.set_cursor_visible(False)
		box.set_wrap_mode(gtk.WRAP_WORD)
		box.set_justification(gtk.JUSTIFY_LEFT)
		box.set_pixels_inside_wrap(extra_line_space)
		box.set_pixels_below_lines(extra_para_space)
		box.set_pixels_above_lines(extra_para_space)
		box.set_indent(10)
		print('Set font to [%s]' % font_name)
		pangoFont = pango.FontDescription(font_name)
		box.modify_font(pangoFont)
		return box


	def make_buttons(self,  boxtop):
		"""Add the "<-" and "->" buttons to C{boxtop} and return it."""
		self._repeat = gtk.Button('<-')
		self._repeat.show()
		boxtop.pack_start(self._repeat,  False,  False)
		self._next = gtk.Button('->')
		boxtop.pack_start(self._next,  False,  False)
		self._next.show()
		return  boxtop


	def make_boxtop(self,  font_name):
		"""Build the top row: status bar, progress bar, peak display, buttons.
		@param font_name: font for the peak display.
		"""
		boxtop = gtk.HBox(False, 10)

		self._sbar = gtk.Statusbar()
		self._sbar.show()
		boxtop.pack_start(self._sbar, True, True)
		
		self._pbar = ProgressBar()
		self._pbar.set_size_request(300, 50)
		self.progress(0, 1)
		boxtop.pack_start(self._pbar, False, False)
		self._pbar.show()

		self._peak = self.make_TextView(0,  0, font_name)
		self.set_peaks((-0.1, -0.1))
		self._peak.show()
		boxtop.pack_start(self._peak,  False,  False)

		self.make_buttons(boxtop)
		boxtop.show()
		return boxtop




	def __init__(self, extra_line_space=5, extra_para_space=5,
				stim_font=None, top_font=None):
		"""Build (but do not show) the main window.
		@param extra_line_space: pixels between wrapped lines in the text areas.
		@param extra_para_space: pixels above and below each paragraph in
			the text areas.
		@param stim_font: font for the instruction and stimulus areas
			(default "Serif 24").
		@param top_font: font for the peak display (defaults to C{stim_font}).
		"""
		self.window = gtk.Window(gtk.WINDOW_TOPLEVEL)
		self.window.set_title('Experiment Aesop 1')
		# "delete_event" and "delete-event" are the same signal; connect once.
		self.window.connect("delete-event", self.delete_event)
		self.window.connect("destroy", self.destroy)
		self.window.set_border_width(10)
		self.window.maximize()
		self.window.set_focus_on_map(True)

		box1 = gtk.VBox(False, 10)
		self.window.add(box1)

		if stim_font is None:
			stim_font = "Serif 24"
		if top_font is None:
			top_font = stim_font
		
		box1.pack_start(self.make_boxtop(top_font), False, False, 10)

		self._instr = self.make_TextView(extra_line_space, extra_para_space,
						stim_font)
		self._instr.get_buffer().set_text("Instructions")
		self._instr.show()
		box1.pack_start(self._instr, False, True, 5)

		self._stim = self.make_TextView(extra_line_space, extra_para_space,
						stim_font)
		self._stim.get_buffer().set_text(".")
		self._stim.show()
		box1.pack_start(self._stim, True, True, 5)
		box1.show()
		self.status_push(1, "-")
		# Defer the self.window.show() to allow tweaks
		# of window properties, such as self.window.add_events()

	def main(self):
		"""Call this to start the GUI going."""
		self.window.show()
		gtk.main()

	def progress(self, i, n):
		"""Set the position of the progress bar.
		The idea is that i out of n items have been completed; the bar is
		drawn half an item ahead so the current item shows as in progress.
		Nothing happens when n is not positive.
		@param i: How much has been done.
		@type i: int
		@param n: What's the total amount of work.
		@type n: int
		"""
		if n > 0:
			self._pbar.set_fraction(min(i+0.5, n)/float(n))

	def status_push(self, context_id, message):
		"""Push some text onto the status bar.
		@param context_id:  What category of the thing?
		@type context_id: int
		@param message: The thing to display in the "status" area.
		@type message: str
		"""
		self._sbar.push(context_id, message)

	def status_pop(self, context_id):
		"""Pop some text off the status bar.  The most recent text of
		the specified category is removed.
		@param context_id:  What category of the thing?
		@type context_id: int
		"""
		self._sbar.pop(context_id)
		
	def stim_win(self):
		"""Return a pointer to the stimulus area."""
		return self._stim
	
	def peak_win(self):
		"""Return a pointer to the area that displays the peak amplitude."""
		return self._peak
	
	def instr_win(self):
		"""Return a pointer to the instruction area."""
		return self._instr
	
	def set_peaks(self,  peaks):
		"""Set the displayed peak amplitude(s), one decimal place each,
		separated by commas."""
		text = ', '.join(['%.1f' % p for p in peaks])
		self.peak_win().get_buffer().set_text(text)


_sizes = {
        2: (numpy.int16, 16, 32767),
        4: (numpy.int32, 32, 2**31-1)
        }


def wav_peak(fn, tstart=None, t_end=None):
	"""This function checks that a WAV file is legal, and
	returns the peak amplitude for each column,
	relative to the largest possible amplitude.
	@param fn: filename
	@type fn: str
	@param tstart: where to start in the file (seconds); None means the beginning.
	@param t_end: where to end in the file (seconds); None means the end.
	@type tstart: float
	@type t_end: float
	@rtype: list
	@return: list of peak amplitudes (floats), one for each column.
		An empty time range gives 0.0 for every column.
	@raise ValueError: if the header looks implausible, the sample width
		is unsupported, or the time range yields (nearly) no data.
	"""
	print('Checking file %s' % (fn,))
	w = wave.open(fn, 'r')
	# Close the file on every path, including the validation errors.
	try:
		rate = w.getframerate()
		if not 1000 < rate < 100000:
			raise ValueError("Silly frame rate in WAV file: %d" % rate)
		nc = w.getnchannels()
		if not 1 <= nc < 10:
			raise ValueError("Silly number of channels in WAV file: %d" % nc)
		sampwidth = w.getsampwidth()
		if not 1 < sampwidth <= 16:
			raise ValueError("Silly sample width in WAV file: %d" % sampwidth)
		nf = w.getnframes()
		if 100 >= nf:
			raise ValueError("Absurdly short WAV file: %d frames" % nf)
		# Default start is frame 0, so t_end alone works.
		istart = 0
		if tstart is not None:
			istart = int(round(tstart*rate))
			w.setpos(istart)
		if t_end is not None:
			nf = int(round(t_end*rate)) - istart
		if nf <= 0:
			return [0.0] * nc
		if sampwidth not in _sizes:
			raise ValueError("Unsupported sample width in WAV file: %d" % sampwidth)
		numtype, _, scale = _sizes[sampwidth]
		# WAV samples are little-endian whatever the host byte order.
		dtype = numpy.dtype(numtype).newbyteorder('<')
		data = numpy.frombuffer(w.readframes(nf), dtype)
	finally:
		w.close()
	data = numpy.reshape(data, (data.shape[0]//nc, nc))
	assert len(data.shape)==2 and data.shape[1]==nc
	if data.shape[0] <= 1:
		raise ValueError("Time range yield (nearly?) no data: %s to %s" % (tstart, t_end))
	# Use float64: negating the most negative integer sample (e.g. -32768
	# in int16) overflows in the native integer type.
	peaks = numpy.abs(data.astype(numpy.float64)).max(axis=0) / float(scale)
	return [float(p) for p in peaks]


def check_wav(fn,  gui):
	"""This checks a wav file and displays its peak amplitude.
	Each column's peak (relative to full scale) is shown as 10*log10(peak).
	A column that is entirely silent is shown as -inf instead of raising
	a math domain error.
	@param fn: filename of the WAV file.
	@param gui: object whose C{set_peaks()} receives the values.
	"""
	# NOTE(review): 10*log10 of an amplitude ratio is half the usual dBFS
	# value (20*log10); kept as-is so the displayed numbers do not change.
	gui.set_peaks([10.0*math.log10(pk) if pk > 0 else float('-inf')
			for pk in wav_peak(fn)])



_get_text_cache = {}
def get_text(s):
	"""Pull text from a named file.
	Leading and trailing whitespace is stripped from every line, and the
	lines are re-joined with newlines.  Results are cached by filename, so
	later changes to the file are not seen.
	@param s: file name
	@type s: str
	@rtype: str
	@raise IOError: if the file cannot be read (a warning is issued first).
	"""
	try:
		return _get_text_cache[s]
	except KeyError:
		pass
	try:
		# 'with' closes the handle promptly, even if reading fails.
		with open(s, 'r') as fd:
			lines = [line.strip() for line in fd]
	except IOError as x:
		die.warn("Read failed on file %s inside get_text() cause: %s" % (s, str(x)))
		raise
	text = '\n'.join(lines)
	_get_text_cache[s] = text
	return text


def set_defaults(d, **kv):
	"""Fill in keyword values into C{d} for keys that C{d} does not yet hold.
	Keys already present in C{d} keep their existing values.
	@param d: dictionary to be updated in place
	@type d: dict
	@param kv: default values, given as keyword arguments
	@type kv: dict
	"""
	absent = [name for name in kv if name not in d]
	for name in absent:
		d[name] = kv[name]
