#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""A linear or quadratic classifier,
but the test and training sets are defined by the
C{-group} flag.
The idea is that you can define groups and a group
is treated as a unit when the data is split into the
test and training set.

Why would you want to do this?    For instance,
if you are building classifiers to separate languages,
and you have several subjects per language, it is possible
that your classifier is learning peculiarities of the individual
subjects, rather than properties of the language.

Without grouping, half of subject A's data might be in the
training set and half in the test set.    The classifier may
learn A's properties and then use that knowledge in the test.
Thus, without grouping, the classifier can succeed without generalizing
from subject to subject.

With grouping, the classifier needs to learn the language
from (e.g.) subjects A, B, C and then extrapolate that knowledge
onto other subjects from the same language (e.g. D, E, F).
Thus, the classifier is forced to learn general properties of the
language, not specific properties of the individual.

To use this, give the
C{-group PATTERN NUM}
switch.  PATTERN is a L{regular expression<re>}, possibly including parenthesized
regions (L{re.match.group}), and NUM is an integer that selects which region is
used as the group name.    NUM=0 means the text matched by the entire regular
expression, NUM=1 means the first parenthesized region, etc.

See the usage notes and flags for C{l_classifier}.

@note: Classifying 2300 entries into 5 groups, based on a 3-dimensional feature
	vector takes about 40 minutes (in 2010, on a single processor).

@note: This code was described in an appendix to
	"Dimensions of durational variation in speech",
	by Anastassia Loukina, Greg Kochanski, Burton Rosner, Chilin Shih, and Elinor Keane,
	submitted 2010 to J. Acoustical Society of America.
"""


from gmisclib import die
from gmisclib import fiatio
from gmisclib import g_closure
from g_classifiers import q_classifier_r as Q
from g_classifiers import qd_classifier_guts as QC
from g_classifiers import l_classifier_guts as LC


def go_group_q(fd, n_per_dim=QC.N_PER_DIM, ftrim=None,
		grouper=None, coverage=QC.COVERAGE, ftest=QC.FTEST,
		verbose=True, modify_class=None):
	d = Q.read_data(fd)
	print '# classes: %s' % (' '.join(Q.list_classes(d)))
	print '# groups: %s' % (' '.join(Q.list_groups(d, grouper)))

	modelchoice = g_closure.Closure(QC.qd_classifier_desc, g_closure.NotYet,
					ftrim=ftrim)
	classout = fiatio.writer(open('classified.fiat', 'w'))
	classout.header('ftrim', ftrim)

	summary, out, wrong = Q.compute_group_class(
						d,
						n_per_dim=n_per_dim,
						modelchoice=modelchoice,
						builder=Q.forest_build,
						classout=classout, coverage=coverage,
						ftest=ftest, verbose=verbose,
						grouper=grouper, modify_class=modify_class
						)
	Q.default_writer(summary, out, classout, wrong)



def go_group_l(fd, n_per_dim=QC.N_PER_DIM, ftrim=None,
		grouper=None, coverage=QC.COVERAGE, ftest=QC.FTEST,
		verbose=True, modify_class=None):
	d = Q.read_data(fd)
	print '# classes: %s' % (' '.join(Q.list_classes(d)))
	print '# groups: %s' % (' '.join(Q.list_groups(d, grouper)))

	classout = fiatio.writer(open('classified.fiat', 'w'))

	modelchoice = g_closure.Closure(LC.l_classifier_desc, g_closure.NotYet,
					ftrim=ftrim)
	classout.header('ftrim', ftrim)

	summary, out, wrong = Q.compute_group_class(
						d,
						modelchoice=modelchoice,
						n_per_dim=n_per_dim,
						builder=Q.forest_build,
						classout=classout, coverage=coverage,
						ftest=ftest, verbose=verbose,
						grouper=grouper, modify_class=modify_class
						)

	Q.default_writer(summary, out, classout, wrong)




if __name__ == '__main__':
	# Command-line driver.  Flags (all must precede any '--'):
	#   -coverage F   set QC.COVERAGE
	#   -nperdim N    set QC.N_PER_DIM
	#   -Q / -L       quadratic (default) or linear classifier
	#   -ftest F      set QC.FTEST; must satisfy 0 < F < 1
	#   -c F          ftrim; must satisfy 0 <= F < 0.5
	#   -quiet        pass verbose=0
	#   -verbose      pass verbose=-1
	#   -flatten      use Q.default_modify_class as modify_class
	#   -i FILE       read data from FILE instead of stdin
	#   -group P W    grouping regex P and match-group index W (required)
	#   -D            increment Q.D
	import sys

	Modify = None
	Stdin = sys.stdin
	ftrim = None
	Verbose = 1
	arglist = sys.argv[1:]
	ct = 'Q'
	Gpat = None
	Gwhich = None

	def _next_value(flag):
		"""Pop and return the value that must follow C{flag}; die if absent."""
		if not arglist:
			die.die('Missing value after %s' % flag)
		return arglist.pop(0)

	while arglist and arglist[0].startswith('-'):
		arg = arglist.pop(0)
		if arg == '--':
			break
		elif arg == '-coverage':
			QC.COVERAGE = float(_next_value(arg))
		elif arg == '-nperdim':
			QC.N_PER_DIM = int(_next_value(arg))
		elif arg == '-Q':
			ct = 'Q'
		elif arg == '-L':
			ct = 'L'
		elif arg == '-ftest':
			QC.FTEST = float(_next_value(arg))
			# Explicit check rather than assert, so it survives python -O.
			if not 0.0 < QC.FTEST < 1.0:
				die.die('-ftest must be strictly between 0 and 1, not %s' % QC.FTEST)
		elif arg == '-c':
			ftrim = float(_next_value(arg))
			if not 0.0 <= ftrim < 0.5:
				die.die('-c must satisfy 0 <= ftrim < 0.5, not %s' % ftrim)
		elif arg == '-quiet':
			# Must assign Verbose: that is the name passed to go_group_*.
			Verbose = 0
		elif arg == '-verbose':
			Verbose = -1
		elif arg == '-flatten':
			Modify = Q.default_modify_class
		elif arg == '-i':
			Stdin = open(_next_value(arg), 'r')
		elif arg == '-group':
			Gpat = _next_value(arg)
			Gwhich = int(_next_value(arg))
		elif arg == '-D':
			Q.D += 1
		else:
			die.die('Unrecognized argument: %s' % arg)
	if len(arglist) != 0:
		die.die("Extra arguments: %s" % str(arglist))
	if Gpat is None:
		die.die("You must specify a pattern for grouping with -group P w")

	grouper = Q.grouper_c(Gpat, Gwhich)
	if ct == 'Q':
		go_group_q(Stdin, coverage=QC.COVERAGE, n_per_dim=QC.N_PER_DIM,
					ftrim=ftrim, ftest=QC.FTEST, verbose=Verbose,
					grouper=grouper, modify_class=Modify)
	else:
		go_group_l(Stdin, coverage=QC.COVERAGE, n_per_dim=QC.N_PER_DIM,
					ftrim=ftrim, ftest=QC.FTEST, verbose=Verbose,
					grouper=grouper, modify_class=Modify)
