#! python

"""
	-draw module/name arg_to_draw_fn
"""

import re
import math
import sys
import cPickle
import numpy
from gmisclib import die
from gmisclib import dictops
from gmisclib import load_mod
import gmisclib.Numeric_gpk as NG
from gmisclib import weighted_percentile as WP
from gmisclib import mcmc_logtools as LT
from gmisclib import mcmc_indexclass as IC

# Fraction cut from each end of the log-probability values when plot_logp()
# asks WP.wp() for the [TRIM, 1-TRIM] points that set the y-axis range
# (presumably weighted percentiles -- confirm against weighted_percentile).
TRIM = 0.05
# Margin added below and above that range by plot_logp(), expressed as a
# fraction of the range's width.
STRETCH = 0.15





def convergence(per_fn, argv, m, arg, selector, hdr):
	"""Drive the convergence display of a model module.

	An object is built with m.pd_factory(argv, hdr=hdr).  Its convergence()
	method is called once for every log entry produced by selector(per_fn),
	receiving the entry's indexer, arg, a list shared by all the calls,
	the pylab module and the entry's iteration number.  One last call, with
	None in place of the indexer and the iteration number, follows the
	loop; then the pylab figure is shown.
	"""
	import pylab
	helper = m.pd_factory(argv, hdr=hdr)
	shared = []	# The very same list object is handed to every call.
	for entry in selector(per_fn):
		helper.convergence(entry.idxr(), arg, shared, pylab, entry.iter)
	# Closing call, made after every selected entry has been seen.
	helper.convergence(None, arg, shared, pylab, None)
	pylab.show()



def do_pd_plot(per_fn, argv, m, arg, selector, hdr):
	"""Let a model module plot each selected log entry.

	An object is built with m.pd_factory(argv, hdr=hdr).  Its plot() method
	is called for every entry produced by selector(per_fn), with the
	entry's indexer, arg, the entry's iteration number and the pylab
	module.  A final call with None for the indexer and the iteration
	follows, and then the pylab figure is shown.
	"""
	import pylab
	plotter = m.pd_factory(argv, hdr=hdr)
	for entry in selector(per_fn):
		plotter.plot(entry.idxr(), arg, entry.iter, pylab)
	# The call with None allows any post-processing after all the points
	# are computed.
	plotter.plot(None, arg, None, pylab)
	pylab.show()


def truncate(s, n):
	"""Return s shortened so that it is never longer than n characters.

	A string that already fits is returned unchanged.  A longer string is
	cut and finished with '...' to show that text was dropped.  When n is
	too small to hold the ellipsis (n < 3) the string is simply cut to n
	characters, so the result never exceeds the requested length; a
	negative n is treated as zero.
	"""
	ellipsis = '...'
	if len(s) <= n:
		return s
	if n < len(ellipsis):
		# No room for the marker: a hard cut is the only way to honour n.
		return s[:max(0, n)]
	return s[:n - len(ellipsis)] + ellipsis


def plot_adjusted_logp(lll, colorspec):
	"""Plot logp adjusted back to T=1, assuming a parabolic maximum.

	For each entry of lll the plotted value is
	logp + 0.5*ndim*(T - 1), where ndim is taken once, from the indexer of
	the first entry.  Entries for which the computation raises KeyError
	are left out.  Nothing is drawn (and pylab is not imported) when no
	entry yields a value.

	NOTE: colorspec is accepted but not used; the line style is fixed
	at "y:".
	"""
	iterations = []
	adjusted = []
	ndim = None
	for sample in lll:
		if ndim is None:
			ndim = sample.idxr().n()
		try:
			adjusted.append(sample.logp() + 0.5*ndim*(sample.T()-1.0))
			iterations.append(sample.iter)
		except KeyError:
			continue
	if not adjusted:
		return
	import pylab
	pylab.plot(iterations, adjusted, "y:")


def plot_logp(per_fn, argv, m, arg, selector, hdr):
	"""Plot log probability against iteration for every entry of per_fn.

	Each value of per_fn is a sequence of log entries.  For each one the
	temperature-adjusted curve is drawn by plot_adjusted_logp(), followed
	by the raw logp curve.  The y-axis is then limited to the range
	spanned by the [TRIM, 1-TRIM] points that WP.wp() reports for the
	individual curves, widened on both sides by STRETCH times its width.
	The figure title is the (truncated) argv.

	m, arg, selector and hdr are not used here; they are part of the
	signature that ModelEvaluator calls.
	"""
	import pylab
	lows = []
	highs = []
	for (_key, samples) in per_fn.items():
		plot_adjusted_logp(samples, 'k:')
		logps = [s.logp() for s in samples]
		iterations = [s.iter for s in samples]
		pylab.plot(iterations, logps)
		pylab.xlabel('Iterations')
		pylab.ylabel('Log probability that model predicts data')
		pylab.title(truncate(' '.join(argv), 40))
		low, high = WP.wp(logps, None, [TRIM, 1.0-TRIM])
		lows.append(low)
		highs.append(high)
	bottom = min(lows)
	top = max(highs)
	margin = STRETCH*(top-bottom)
	pylab.ylim(bottom - margin, top + margin)
	pylab.show()


def do_pd_print(per_fn, argv, m, arg, selector, hdr):
	"""Let a model module print each selected log entry.

	An object is built with m.pd_factory(argv, hdr).  Its do_print() method
	is called for every entry produced by selector(per_fn), with the
	entry's indexer, arg and the entry's iteration number.  A final call
	with None for the indexer and the iteration follows, and standard
	output is flushed.
	"""
	printer = m.pd_factory(argv, hdr)
	for entry in selector(per_fn):
		printer.do_print(entry.idxr(), arg, entry.iter)
	# The call with None allows any post-processing after all the points
	# are computed.
	printer.do_print(None, arg, None)
	sys.stdout.flush()


class ModelEvaluator(object):
	"""Callable that ties a processing function to a module and an argument.

	Calling the instance recovers the argv stored in the log headers,
	appends any extra arguments, and passes everything on to the function.
	"""

	def __init__(self, mod, fcn, arg=None):
		"""Remember the module, the function to call and its extra argument."""
		self.mod, self.fcn, self.arg = mod, fcn, arg


	def __call__(self, per_fn, hdrs, selector, xargs):
		"""Call fcn(per_fn, argv, mod, arg, selector, hdrs) and return its result.

		argv comes from the 'Argv' header (a pickled list) when it is
		present, otherwise from the whitespace-separated 'ARGV' header;
		xargs is appended to it.
		"""
		if 'Argv' not in hdrs:
			argv = hdrs['ARGV'].split()
		else:
			# NOTE(review): this unpickles text read from a log file.
			# Unpickling can execute arbitrary code, so logs must come
			# from a trusted source.
			argv = cPickle.loads(hdrs['Argv'])
		argv.extend(xargs)
		die.info("Args=%s" % ' '.join(argv))
		return self.fcn(per_fn, argv, self.mod, self.arg, selector, hdrs)




def print_correlations(per_fn, selector):
	"""Print the dominant eigenvectors of the parameter covariance.

	The covariance comes from LT.indexer_covar(); nothing is printed when
	it is None.  Only the last ten eigenvalues returned by
	numpy.linalg.eigh() are examined, and one is reported only if it
	exceeds both three times the median eigenvalue and 3% of the last
	eigenvalue.  For each reported eigenvalue, the components whose
	magnitude passes a threshold relative to the largest component are
	listed (at most the 15 largest), as '#'-prefixed lines on standard
	output.
	"""
	_mean, covar, _n, idxr_map = LT.indexer_covar(per_fn, selector)
	if covar is None:
		return
	names = dictops.rev1to1(idxr_map)
	evals, evecs = numpy.linalg.eigh(covar)
	dim = evals.shape[0]
	median_eval = numpy.median(evals)
	for i in range(max(0, dim-10), dim):
		if not (evals[i] > max(3*median_eval, evals[-1]*0.03)):
			continue
		print('# eigenvalue %d eval= %g (median eval= %g )' % (i, evals[i], median_eval))
		vec = evecs[:,i]
		largest = NG.N_maximum(numpy.absolute(vec))
		factor = math.sqrt(evals[-1]/evals[i]) * 0.15
		cutoff = factor*largest
		# Tuples start with the magnitude so that sorting orders by it.
		components = [
			(abs(vec[j]), IC.index._fmt(names[j]), vec[j])
			for j in range(dim)
			if abs(vec[j]) > cutoff
			]
		components.sort()
		for (_magnitude, label, coefficient) in components[-15:]:
			print('#\t%.3f * %s' % (coefficient, label))



def process_logs(per_fn, hdr, selector=None, ProbStuff=[], xargs=[], latex=False):
	"""Print a text (or LaTeX) summary of the selected log entries.

	Output, in order, on standard output:
	  - one line per source file with the number of samples used,
	  - the entries returned by LT.logp_stdev(),
	  - the (name, value) pairs produced by each callable in ProbStuff,
	    called as ps(per_fn, hdr, selector, xargs),
	  - the number of parameters and then one line per parameter from
	    LT.indexer_stdev(), sorted with LT.key_cmp on the names,
	  - the correlation report of print_correlations().

	selector is required despite its default of None.  With latex true,
	rows are written as LaTeX table rows and the informational lines
	become '%' comments instead of '#' comments.

	ProbStuff and xargs are only read, never modified, here.
	"""
	# Each format pair is (with sigma, without sigma).  Both members are
	# formatted with (average, sigma, name), so the name is always field 2;
	# using field 1 in the "without sigma" form would print the None sigma
	# in place of the name.
	if latex:
		fmtps = '{1} & {0:.2f} \\\\'
		fmtn = '% n= {0:d}'
		fmtsu = '%  samples used = {0:d}; filename = {1}'
		fmtP = ('{2} & ${0:.2f} \\pm {1:.1f}$ \\\\',
			'{2} & {0:.2f} \\\\'
			)
		fmtV = ('{2} & ${0:7g} \\pm {1:6g}$ \\\\',
			'{2} & {0:7g} \\\\'
			)
	else:
		fmtps = '{0:.2f} {1}'
		fmtn = '# n= {0:d}'
		fmtsu = '#  samples used = {0:d}; filename = {1}'
		fmtP = ('{0:.2f} +- {1:.1f} {2}',
			'{0:.2f} {2}'
			)
		fmtV = ('{0:7g} +- {1:6g} {2}',
			'{0:7g} {2}'
			)

	def stat_line(fmts, nm, avg, sig):
		"""Format one (name, average, sigma) entry; sigma may be None."""
		with_sigma, without_sigma = fmts
		fmt = without_sigma if sig is None else with_sigma
		return fmt.format(avg, sig, IC.index._fmt(nm))

	assert selector is not None
	source = dictops.dict_of_accums()
	# The selector is a generator: it has to be run to completion so that
	# the per-file counts accumulate in source.
	list(selector(per_fn, source))
	for (fn, nsamp) in sorted(source.items()):
		print(fmtsu.format(nsamp, fn))

	infolist = LT.logp_stdev(per_fn, selector)
	for (nm, avg, sig) in infolist:
		print(stat_line(fmtP, nm, avg, sig))
	for ps in ProbStuff:
		for (nm, pml) in ps(per_fn, hdr, selector, xargs):
			print(fmtps.format(pml, IC.index._fmt(nm)))

	avglist = LT.indexer_stdev(per_fn, selector, weight_by_T=True)
	print(fmtn.format(len(avglist)))
	avglist.sort(lambda a, b: LT.key_cmp(a[0], b[0]))
	for (nm, avg, sig) in avglist:
		print(stat_line(fmtV, nm, avg, sig))

	print_correlations(per_fn, selector)



def run(arglist):
	"""Parse the command line, read the log files and produce the output.

	Leading arguments that start with '-' are consumed as flags (arglist
	is modified in place); parsing stops at the first other argument or
	after a literal '--'.  What remains is the list of log files.

	Flags:
	  -best -eachbest -good -eachgood -last -all
	        choose the selector from mcmc_logtools (default: all).
	  -uid UID          passed to the log reader.
	  -xarg ARG         extra argument handed to every evaluator; repeatable.
	  -plot             plot log probability against iteration.
	  -convergence MOD ARG
	  -draw MOD ARG
	  -print MOD ARG    load module MOD and process the entries with
	                    convergence(), do_pd_plot() or do_pd_print().
	  -modelcompare MOD ARG
	                    load MOD and add LT.P_bayes to the summary.
	  -prms             list the parameter samples.
	  -latex            write the summary as LaTeX.
	  -fromstart        read the logs with trigger=None.
	The spellings -Convergence, -Draw, -Print and -ModelCompare do the
	same but load the module with use_sys_path false.

	The summary of process_logs() is printed when -latex is given or when
	no other flag produced output.  A flag that lacks its value(s), an
	unknown flag, an empty file list or an empty data set is reported
	through die.die().
	"""
	# Flag -> attribute of mcmc_logtools.  Resolved lazily so that only the
	# selector actually requested is looked up.
	selector_names = {
		'-best': 'overall_best',
		'-eachbest': 'each_best',
		'-good': 'some_after_convergence',
		'-eachgood': 'near_each_max',
		'-last': 'last',
		'-all': 'all',
		}
	Selector = LT.all
	ProbStuff = []
	Trigger = LT.TRIGGER
	uid = None
	Draw = []
	xargs = []
	list_prms = False
	latex = False

	def next_value(flag, what):
		"""Pop the next argument, reporting a usage error if there is none."""
		if not arglist:
			die.die('Flag %s requires an argument: %s' % (flag, what))
		return arglist.pop(0)

	def load_evaluator(flag, lower_flag, fcn):
		"""Read 'MODULE ARG' after flag and wrap them with fcn."""
		modname = next_value(flag, 'module name')
		fcn_arg = next_value(flag, 'argument for the module')
		# Only the all-lowercase spelling searches sys.path.
		mod = load_mod.load_named(modname, flag == lower_flag)
		return ModelEvaluator(mod, fcn, fcn_arg)

	while arglist and arglist[0].startswith('-'):
		arg = arglist.pop(0)
		if arg in selector_names:
			Selector = getattr(LT, selector_names[arg])
		elif arg == '-uid':
			uid = next_value(arg, 'uid')
		elif arg == '-xarg':
			xargs.append(next_value(arg, 'extra argument'))
		elif arg == '-plot':
			Draw.append(ModelEvaluator(None, plot_logp))
		elif arg in ('-convergence', '-Convergence'):
			Draw.append(load_evaluator(arg, '-convergence', convergence))
		elif arg in ('-draw', '-Draw'):
			Draw.append(load_evaluator(arg, '-draw', do_pd_plot))
		elif arg in ('-print', '-Print'):
			Draw.append(load_evaluator(arg, '-print', do_pd_print))
		elif arg in ('-ModelCompare', '-modelcompare'):
			ProbStuff.append(load_evaluator(arg, '-modelcompare', LT.P_bayes))
		elif arg == '-prms':
			list_prms = True
		elif arg == '-latex':
			latex = True
		elif arg == '-fromstart':
			Trigger = None
		elif arg == '--':
			break
		else:
			die.die('Unrecognized flag: %s' % arg)

	if len(arglist) == 0:
		print(__doc__)
		die.die("Empty argument list!")
	per_fn, hdr = LT.read_uid_many_files(arglist, uid, Nsamp=1000, tail=0.0, trigger=Trigger)
	per_fn = LT.drop_files(per_fn, LT.FILE_DROP_FAC)

	if len(per_fn)==0:
		die.die("No data has been read in from %s" % arglist)
	made_output = False
	if list_prms:
		LT.list_prm_samples(per_fn, Selector, sys.stdout)
		made_output = True
	for drw in Draw:
		drw(per_fn, hdr, selector=Selector, xargs=xargs)
		made_output = True
	if latex or not made_output:
		process_logs(per_fn, hdr, selector=Selector, ProbStuff=ProbStuff, xargs=xargs, latex=latex)


if __name__ == "__main__":
	# Run as a script: everything after the program name goes to run().
	cli_args = sys.argv[1:]
	run(cli_args)
