/*
 * This program takes two input sequences, one from human and one from mouse,
 * and calculates a conservation score for them.
 */

#include <stdio.h>
#include <strings.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <ctype.h>
#include "matrix.h"
#include "iterator.h"
#include "probability.h"
#include "const.h"

/*
 * Bq: singly linked list node carrying an Iterator pointer.
 * NOTE(review): not used by the code visible here; presumably a queue
 * node for tree traversal — confirm before relying on it.
 */
struct _bq;
typedef struct _bq Bq;
struct _bq {
	Iterator *i;	/* payload */
	Bq *next;	/* following node in the list */
};


/* Divisor applied to every branch length parsed by get_node() (len /= scale).
 * main() sets it to 1, or to the optional second command-line argument. */
double scale;

/* Debug flag. NOTE(review): not read anywhere in the code visible here. */
int debug=0;

// State index order: A G C T 1 2 3 n (indices 0-7).
// NOTE(review): get_node() spells the last state as '4' in leaf patterns, not 'n' — confirm.


/*
 * Map one character of a leaf's base-state pattern to its state index:
 * A/a=0, G/g=1, C/c=2, T/t=3, '1'=4, '2'=5, '3'=6, '4'=7.
 * Returns -1 for any other character so the caller can reject it.
 */
static int parse_state_char(char ch)
{
	switch (ch) {
	case 'A': case 'a': return 0;
	case 'G': case 'g': return 1;
	case 'C': case 'c': return 2;
	case 'T': case 't': return 3;
	case '1': return 4;
	case '2': return 5;
	case '3': return 6;
	case '4': return 7;
	default:  return -1;
	}
}

/*
 * Find the first balanced "( ... )" group at or after `start` and store
 * pointers to its opening and closing parentheses in *open and *close.
 * Exits the program if the parentheses do not balance.
 */
static void find_paren_group(char *start, char **open, char **close)
{
	int depth = 0;
	for (char *p = start; ; p++) {
		if (!*p) { fprintf(stderr,"No matching paren.\n"); exit(1); }
		if (depth < 0) { fprintf(stderr,"Unmatched paren.\n"); exit(1); }
		if (*p == '(') { if (!depth) *open = p; depth++; }
		if (*p == ')') { depth--; if (!depth) { *close = p; return; } }
	}
}

/*
 * Return a NUL-terminated heap copy of the inclusive span [first, last].
 * The caller frees it. Exits on allocation failure.
 */
static char *dup_span(const char *first, const char *last)
{
	size_t n = (size_t)(last - first) + 1;
	char *copy = (char *)malloc(n + 1);
	if (!copy) { fprintf(stderr, "Out of memory.\n"); exit(1); }
	memcpy(copy, first, n);
	copy[n] = 0;
	return copy;
}

/*
 * Parse one node of a parenthesized tree string and build its Iterator.
 *
 * A node is "(" + branch length + either
 *   - two child groups "(...)(...)"  (internal node), or
 *   - a pattern after the last space before the first ')' (leaf): its first
 *     three characters are base states (see parse_state_char) and the rest
 *     is passed to set_species(),
 * followed by ")".
 *
 * The first child group becomes n->right (id val+1); the second becomes
 * n->left (id val+30). The branch length is divided by the global `scale`
 * and used to build the node's matrices.
 *
 * Side effect: writes a NUL into `tree` where the branch length ends.
 * Exits the program on malformed input.
 */
Iterator *get_node(char *tree, int val)
{
	if (!tree || tree[0] != '(') {
		fprintf(stderr,"Malformed branch!\n");
		exit(1);
	}

	Iterator *n = new Iterator();
	n->val = val;
	n->rprob = high();
	n->nprob = zero();

	// A node is a leaf when no nested group follows its opening paren.
	int leaf = strchr(tree + 1, '(') == NULL;

	if (!leaf) {
		char *open = 0, *close = 0, *sub;

		// First group -> right child.
		find_paren_group(tree + 1, &open, &close);
		sub = dup_span(open, close);
		n->right = get_node(sub, val + 1);
		free(sub);
		n->right->parent = n;

		// Second group, searched after the first one -> left child.
		find_paren_group(close + 1, &open, &close);
		sub = dup_span(open, close);
		n->left = get_node(sub, val + 30);
		free(sub);
		n->left->parent = n;
	} else {
		char *a = 0, *b = 0, bf[500];
		for (char *c = tree + 1; *c; c++) {
			if (*c == ')') { b = c; break; }
			if (*c == ' ') a = c;
		}
		// Validate before dereferencing: `a` is NULL when no space precedes ')'.
		if (!b || !a) { fprintf(stderr, "No branch pattern.\n"); exit(1); }
		while (isspace((unsigned char)*a)) a++;

		size_t ln = (size_t)(b - a);
		if (ln >= sizeof bf) { fprintf(stderr, "Branch pattern too long.\n"); exit(1); }
		memcpy(bf, a, ln);
		bf[ln] = 0;

		if (ln < 3) { fprintf(stderr, "Bad base state.\n"); exit(1); }
		int s1 = parse_state_char(bf[0]);
		int s2 = parse_state_char(bf[1]);
		int s3 = parse_state_char(bf[2]);
		if (s1 < 0 || s2 < 0 || s3 < 0) { fprintf(stderr, "Bad base state.\n"); exit(1); }
		n->set_base_state(s1, s2, s3);
		n->set_species(bf + 3);
	}

	// Branch length: text after '(' up to the first space or paren.
	char *end = tree + 1 + strcspn(tree + 1, " ()");
	*end = 0;
	n->len = atof(tree + 1);
	n->len /= scale;

	n->build_matrices(m3, f3, f2, f1, in, de, n->len, _prob);
	return n;
}

/*
 * Parse a whole tree string; the root node gets id 1.
 * See get_node() for the accepted format. Exits on malformed input.
 */
Iterator *load_tree(char *tree)
{
	const int root_id = 1;
	return get_node(tree, root_id);
}

// A -> T
// A -> C

/*
 * Table lookup for the state triple (i1, i2, i3).
 *
 * Positions whose value is >= G1 are skipped (presumably gap/non-base
 * states — confirm against const.h). The k kept positions are packed into
 * a base-LET index, with the first kept position least significant, and
 * looked up in f1, f2 or f3 for k = 1, 2 or 3. Returns 1 when no position
 * is kept.
 */
double freq(int i1, int i2, int i3)
{
	const int triple[3] = { i1, i2, i3 };
	int index = 0;
	int radix = 1;

	for (int k = 0; k < 3; k++) {
		if (triple[k] >= G1)
			continue;
		index += triple[k] * radix;
		radix *= LET;
	}

	if (radix == 1)
		return 1;
	if (radix == LET)
		return f1[index];
	if (radix == LET * LET)
		return f2[index];
	if (radix == LET * LET * LET)
		return f3[index];
	return 1;
}

/* Same lookup as freq(int, int, int), applied to the triple stored in `a`. */
double freq(Stack *a)
{
	int first = a->one;
	int second = a->two;
	int third = a->thr;
	return freq(first, second, third);
}

// Forward declaration: gs() and gval() are mutually recursive.
double gval(Iterator *i, int i1, int i2, int i3);

/*
 * Probability of state b given state a on branch i, via _prob().
 * When there is no source state (a == NULL), falls back to freq(b).
 * `b` must not be NULL.
 */
double probability(Stack *a, Stack *b,Iterator *i)
{
	return a != NULL
		? _prob(a->one, a->two, a->thr, b->one, b->two, b->thr, i)
		: freq(b);
}

/*
 * Look up node i's m24 table for state s against i's recorded state.
 *
 *   row = s->two + record->two * GLET
 *   col = s->one + s->thr * GLET
 *       + record->one * GLET^2 + record->thr * GLET^3
 *
 * Passed to Iterator::tree_score() as the scoring callback (see run_states).
 * The unused character table the original built here has been removed.
 */
double tree_prob(Iterator *i, Stack *s)
{
	int row = s->two + i->record->two * GLET;
	int col = s->one
		+ s->thr * GLET
		+ i->record->one * GLET * GLET
		+ i->record->thr * GLET * GLET * GLET;
	return i->m24[row][col];
}

/*
 * Probability of moving from state s to node i's recorded state, via
 * _prob() evaluated on node i. Used as the callback for
 * record_likelihood().
 */
double flat_prob(Iterator *i, Stack *s)
{
	return _prob(s->one, s->two, s->thr,
		     i->record->one, i->record->two, i->record->thr,
		     i);
}


/*
 * Best (lowest) score for the subtree rooted at `i`, given that its parent
 * is in state (i1, i2, i3).
 *
 * - i == NULL: returns 1.
 * - Leaf: records its fixed states (b1, b2, b3) and returns the subst()
 *   score from the parent state to them.
 * - Internal node: for every legal state j, the candidate score is
 *   subst(parent -> j) + gval(i, j), and the minimum is kept. The winning
 *   j is stored in i->gn[IN(i1,i2,i3)]. On an exact tie, the new j replaces
 *   the old one only if record_likelihood(flat_prob) exceeds
 *   i->gl[IN(i1,i2,i3)]; in that case gl is updated too.
 *
 * NOTE(review): subst() also receives the current best `r`; its role
 * (e.g. a cutoff) is not visible here.
 */
double gs(int i1, int i2, int i3, Iterator *i)
{
	if (!i) return 1;

	double r = high();

	if (i->is_leaf()) {
		i->record_state(i->b1, i->b2, i->b3, 0, 0);
		return subst(i1, i2, i3, i->b1, i->b2, i->b3, i, r, 0);
	}

	const int here = IN(i1, i2, i3);
	int j1, j2, j3;
	for (j1 = 0; j1 < GLET; j1++)
	for (j2 = 0; j2 < GLET; j2++)
	for (j3 = 0; j3 < GLET; j3++) {
		if (!is_legal_nuc(j1, j2, j3)) continue;

		// Non-leaf is guaranteed here, so the subtree score is always needed.
		double below = gval(i, j1, j2, j3);
		double s = subst(i1, i2, i3, j1, j2, j3, i, r, below);
		double total = s + below;

		// Written as !(<=) so a NaN total is skipped exactly as before.
		if (!(total <= r)) continue;

		if (total == r) {
			// Tie: tentatively take j, keep it only if it is more likely.
			int prev = i->gn[here];
			i->gn[here] = IN(j1, j2, j3);
			i->record_state(i1, i2, i3, 0, 0);
			i->make_record();
			double lik = i->record_likelihood(flat_prob);
			if (lik > i->gl[here]) {
				i->gl[here] = lik;
			} else {
				i->gn[here] = prev;
			}
		} else {
			i->gn[here] = IN(j1, j2, j3);
		}
		r = total;
	}
	return r;
}

/*
 * Memoized combined score of node i's two subtrees when i is in state
 * (i1, i2, i3): sum(gs(right), gs(left)).
 *
 * The value is cached in i->gv[IN(i1,i2,i3)], with i->gm[...] set to 1
 * once it has been computed. A missing child contributes gs(NULL) == 1.
 */
double gval(Iterator *i, int i1, int i2, int i3)
{
	const int slot = IN(i1, i2, i3);
	if (i->gm[slot]) return i->gv[slot];

	double a = gs(i1, i2, i3, i->right);
	double b = gs(i1, i2, i3, i->left);
	i->gv[slot] = sum(a, b);
	i->gm[slot] = 1;
	return i->gv[slot];
}

/*
 * Pick the root state (i1, i2, i3) with the lowest gval() and record it on
 * `i` via record_state(). Returns that lowest score, or high() if no state
 * is legal.
 *
 * On an exact tie, the candidate is recorded and its record_likelihood
 * (flat_prob) is compared with i->gl[IN(candidate)]. If it is not larger,
 * the previously recorded state is restored.
 *
 * NOTE(review):
 * - A tie dereferences i->record, so it presumably must already be set
 *   (init_state?). Confirm.
 * - A tie winner is recorded with score 0, not s. Confirm this is intended.
 *
 * The original allocated a Stack here that was never read and never freed
 * (one leak per call); it has been removed along with an unused local.
 */
double compute(Iterator *i)
{
	double val = high();
	int i1, i2, i3;
	for (i1 = 0; i1 < GLET; i1++)
	for (i2 = 0; i2 < GLET; i2++)
	for (i3 = 0; i3 < GLET; i3++) {
		if (!is_legal_nuc(i1, i2, i3)) continue;

		double s = gval(i, i1, i2, i3);
		// Written as !(<=) so a NaN score is skipped exactly as before.
		if (!(s <= val)) continue;

		if (s == val) {
			Stack temp = *(i->record);
			i->record_state(i1, i2, i3, 0, 0);
			i->make_record();
			double l = i->record_likelihood(flat_prob);
			if (l > i->gl[IN(i1, i2, i3)]) {
				// take the more likely scenario.
				i->gl[IN(i1, i2, i3)] = l;
			} else {
				i->record_state(temp.one, temp.two, temp.thr, s, 0);
			}
		} else {
			i->record_state(i1, i2, i3, s, 0);
		}
		val = s;
	}

	return val;
}

/*
 * Return 1 if node i and every non-leaf node below it (via right/left)
 * has states equal to (a, b, c); return 0 otherwise. Leaves always pass.
 * Both subtrees are visited even when the first one fails.
 */
int check_states( Iterator *i, int a, int b, int c)
{
	if (i->is_leaf())
		return 1;

	int same = i->states->one == a
		&& i->states->two == b
		&& i->states->thr == c;
	if (!same)
		return 0;

	int all_match = 1;
	if (i->right != NULL && check_states(i->right, a, b, c) == 0)
		all_match = 0;
	if (i->left != NULL && check_states(i->left, a, b, c) == 0)
		all_match = 0;
	return all_match;
}

/* Return 1 if node i's states entry equals (a, b, c), else 0. */
int has_state( Iterator *i, int a, int b, int c)
{
	return i->states->one == a
		&& i->states->two == b
		&& i->states->thr == c;
}


/*
 * Run the state assignment on the tree rooted at `i` and print:
 *   1. the best score from compute(), on its own line;
 *   2. tree_score(0, 1, tree_prob) "vs" tree_score(1, 1, tree_prob).
 * The unused locals and commented-out debug output of the original have
 * been removed.
 */
void run_states(Iterator *i)
{
	double q = compute(i);
	i->make_record();
	printf("%e\n", q);
	printf("%e vs %e\n", i->tree_score(0, 1, tree_prob), i->tree_score(1, 1, tree_prob));
}

int main(int argc, char **argv)
{
	Iterator *i, *p;
	int j;


	scale = 1;
	load_matrices();
	if (argc > 2) scale = atof(argv[2]);
	i = load_tree(argv[1]);
	for (p=i->tail;p;p=p->breadth_prev)
		p->init_state();
	run_states(i);
	i->print_state();
}
