#include "iterator.h"
#include "const.h"
#include <iostream>

// How to iterate: run through your states. If you run out of states,
// go to your breadth_next and step to its next state. Every time you
// step, re-initialize everyone who is breadth-previous to you (who,
// hopefully, have run out of states if they are talking to you).
#include <stdio.h>

using namespace std;
extern Stack allnucs[];

// Non-null node pointer whose tflag is set. NOTE: the argument is not
// parenthesised, so only pass simple pointer expressions.
#define ACTIVE(a) (a && a->tflag)
// Non-zero when bit b of a is set (arguments are not parenthesised either).
#define IS_SET(a,b) (a & (1 << b))
// Base rate matrix (LET^3 x LET^3) read by load_matrices; its diagonal is
// overwritten with minus the off-diagonal row sum.
Matrix<double> mo;
// Matrix cache used by build_matrices, indexed by slot (int)(time*10+0.5).
// mf[i] != 0 marks a slot whose mhash/m24hash/m33hash/m33chash/m31hash
// entries are all populated; m_uses/m_lastuse drive eviction in clearhash.
Matrix<double> *mhash[5000];
int mf[5000];
int m_uses[5000];
int m_lastuse[5000];
// Counters of newly cached entries; clearhash resets mfcount and m24count.
int mfcount;
int m24count;
int m33count;
// Interpolation tables filled by init_interpolation (one entry per grid time).
Matrix<double> *mihash[5000];
Matrix<double> *m24ihash[5000];

// Per-slot companions of mhash, filled by build_matrices.
Matrix<double> *m24hash[5000];
Matrix<double> *m33hash[5000];
Matrix<double> *m33chash[5000];
Matrix<double> *m31hash[5000];
// Background frequencies built by load_matrices: f3 is read from the
// frequency file and normalised; f2 and f1 are marginals obtained by summing
// over the highest-stride index; f0[0] = 1. f3c holds cumulative frequencies
// over the middle letter for each (first, third) pair.
Row<double> f0;
Row<double> f1;
Row<double> f2;
Row<double> f3;
Row<double> f3c;
// Insertion / deletion rate vectors (5 entries each) read by load_matrices.
Row<double> in;
Row<double> de;
// Logical clock advanced by build_matrices; clearhash keeps entries used
// within the last 50 ticks.
int timer=0;


extern int computing;
extern int initialize;
extern int gapless;
extern int runmode;
extern int footprint;
extern double interval;
extern int rebuilding;
// Approximate byte count of cached matrices; in this file it is only
// reported in debug output.
int memuse=0;
// Defined elsewhere; used by golden_max to rotate bracket points/values
// (presumably the usual "shift down" helpers — confirm).
void shift2(double &a, double &b, const double c);
void shift3(double &a, double &b, double &c, const double d);


Iterator::~Iterator() {
	// Children are owned by this node and are destroyed recursively.
	delete right;
	delete left;
	// Scratch tables allocated with calloc in the constructors.
	free(gm);
	free(gv);
	free(gn);
	free(gl);
	// m33/m24/m3 may point into the global matrix caches. Those are marked
	// `hashed` and must outlive this node, so only private copies are freed.
	Matrix<double> *maybe_owned[3] = { m33, m24, m3 };
	for (int k = 0; k < 3; k++) {
		Matrix<double> *mat = maybe_owned[k];
		if (mat && !mat->hashed) delete mat;
	}
	// m2/m1 are allocated fresh for each node in build_matrices.
	delete m2;
	delete m1;
	delete record;
}

int usethresh = 10;

// Evicts cached transition matrices from the global hash tables.
// A populated slot (mf[i] != 0) survives if it was used within the last 50
// ticks of `timer` or has been used more than `usethresh` times. If eviction
// was ineffective (few entries cleared relative to those counted, nothing
// cleared at all, or more than 500 entries cached), usethresh grows by 10 so
// later passes keep fewer entries.
// NOTE(review): Iterators may still hold m3/m24/m33 pointers to evicted
// entries; callers are presumably expected to rebuild before using them.
void clearhash() {
	int cleared = 0, counted = 0;
	if (debug>=2) cerr<<"Memuse at "<<memuse<<"\n";
	for (int i=0;i<5000;i++) {
		if (!mf[i]) continue;
		counted++;
		if (m_lastuse[i] > timer - 50) continue;
		if (m_uses[i] > usethresh) continue;
		cleared++;
		mf[i] = 0; // do NOT reset uses, something might become important later!
		delete mhash[i];
		delete m24hash[i];
		delete m33hash[i];
		// build_matrices allocates m33c/m31 for every cached slot regardless of
		// runmode, so they must always be released (they used to leak outside
		// PARSIMONY/RATE mode).
		delete m33chash[i];
		delete m31hash[i];
		mhash[i] = m24hash[i] = m33hash[i] = m33chash[i] = m31hash[i] = 0;
		memuse -= GLET*GLET*GLET*GLET*GLET*GLET*sizeof(double)*3;
	}
	// Nothing cleared counts as "ratio too high"; this also avoids dividing by zero.
	if (cleared == 0 || counted / cleared > 5 || counted > 500) // become more of a hardass -
		// we're probably late in the game.
		usethresh += 10;
	mfcount = 0;
	m24count = 0;
	if (debug>=2) cerr<<"Memuse down to "<<memuse<<"\n";
}

// Opens `path` for reading, or aborts with a message naming the file (an
// unchecked NULL from fopen would otherwise crash inside the loaders).
static FILE *open_input_or_die(const string &path, const char *what)
{
	FILE *fp = fopen(path.c_str(), "r");
	if (!fp) {
		cerr<<"Cannot open "<<what<<" file '"<<path<<"'.\n";
		exit(1);
	}
	return fp;
}

// Loads the model inputs and derives the background frequency tables:
//   matrixFile        - LET^3 x LET^3 rate matrix; its diagonal is replaced
//                       by minus the off-diagonal row sum,
//   freqFile          - one triplet frequency per line, normalised to sum to 1,
//   insFile / delFile - 5-entry insertion / deletion vectors.
// Also clears the matrix-cache bookkeeping arrays. Exits on unreadable or
// malformed input.
void load_matrices(string matrixFile, string freqFile, string insFile, string delFile)
{
	const int ntrip = LET*LET*LET;
	int i,j;
	char line[1500];
	double q;
	FILE *fp;

	bzero(mf,sizeof(int)*5000);
	bzero(m_uses,sizeof(int)*5000);
	bzero(m_lastuse,sizeof(int)*5000);
	mo.SetSize(ntrip,ntrip);
	f0.SetRowSize(1);
	f1.SetRowSize(LET);
	f2.SetRowSize(LET*LET);
	f3.SetRowSize(ntrip);
	f3c.SetRowSize(ntrip);

	f0[0] = 1;

	// load initial transition frequency matrix Q
	fp = open_input_or_die(matrixFile, "rate matrix");
	mo.load(fp);
	fclose(fp);
	for (i=0;i<ntrip;i++) {
		double diag = 0;
		for (j=0;j<ntrip;j++) if (i!=j) diag-=mo[i][j];
		mo[i][i] = diag;
	}

	// Triplet frequencies. Stop after ntrip values so extra lines (e.g. a
	// trailing blank line) cannot write past the end of f3.
	fp = open_input_or_die(freqFile, "frequency");
	i=0;
	q = 0;
	while(i < ntrip && fgets(line,1500,fp)) {
		f3[i] = atof(line);
		q+=f3[i++];
	}
	fclose(fp);
	if (!(q > 0)) {
		cerr<<"Frequency file must contain a positive total frequency.\n";
		exit(1);
	}
	f3 = f3 * (1.0/q);

	// f3c: running (cumulative) frequency over the middle letter for each
	// (first, third) pair, normalised by that pair's total.
	int i1,i2,i3;
	for (i1=0;i1<LET;i1++)
		for (i3=0;i3<LET;i3++) {
			q = 0;
			double l = 0;
			for (i2=0;i2<LET;i2++) {
				q += f3[IN2(i1,i2,i3)];
			}
			for (i2=0;i2<LET;i2++) {
				l += f3[IN2(i1,i2,i3)];
				f3c[IN2(i1,i2,i3)] = l / q;
			}
		}

	// Marginals: sum out the highest-stride index of f3 (for f2), then of f2 (for f1).
	for (i=0;i<LET*LET;i++) {
		for (j=0;j<LET;j++) {
			f2[i] += f3[i+j*LET*LET];
		}
	}
	for (i=0;i<LET;i++) {
		for (j=0;j<LET;j++) {
			f1[i] += f2[i+j*LET];
		}
	}

	fp = open_input_or_die(insFile, "insertion");
	in.load(fp);
	fclose(fp);
	if (in.size() != 5) {
		cerr<<"Wrong length for insertions vector (must be 5).\n";
		exit(1);
	}
	fp = open_input_or_die(delFile, "deletion");
	de.load(fp);
	fclose(fp);
	if (de.size() != 5) {
		cerr<<"Wrong length for deletions vector (must be 5).\n";
		exit(1);
	}
}

// Creates an empty, unlinked node with no record and no matrices. The
// scratch tables gm/gv/gn/gl are zero-allocated with GLET*GLET*GLET entries each.
Iterator::Iterator()
{
	right = 0;
	left = 0;
	next = 0;
	parent = 0;
	tail = 0;
	tflag = 0;
	record = 0;
	ref = 0;
	sis = 0;
	mod = 1;
	m3 = 0;
	m24 = 0;
	m33 = 0;
	m1 = 0;
	m2 = 0;
	// These were previously left indeterminate by this constructor.
	m33c = 0;
	m31 = 0;
	len = 0;
	rprob = 0;
	nprob = 0;
	p1 = p2 = p3 = 0;
	ins_state = 0;
	gm = (short int*)calloc(GLET*GLET*GLET,sizeof(short int));
	gv = (double *)calloc(GLET*GLET*GLET,sizeof(double));
	gn = (int *)calloc(GLET*GLET*GLET,sizeof(int));
	gl = (double *)calloc(GLET*GLET*GLET,sizeof(double));
}


// Copies n characters of src into dst (capacity cap bytes, including the
// terminator) and NUL-terminates it. Aborts on a negative or oversized length
// instead of overflowing dst.
static void copy_tree_span(char *dst, size_t cap, const char *src, long n)
{
	if (n < 0 || (size_t)n >= cap) {
		cerr<<"Malformed branch! (bad substring length)\n";
		exit(1);
	}
	strncpy(dst,src,n);
	dst[n] = 0;
}

// Builds a (sub)tree from Newick-style text.
//   tree - either a leaf name (no '(' ')' ':') or "(LEFT:len,RIGHT:len)"
//   par  - parent node (0 for the root)
//   sca  - scale copied to every node
// Children's branch lengths are parsed here and stored in child->len; the
// root's len stays 0. Malformed input terminates the program.
Iterator::Iterator(char *tree, Iterator *par, double sca)
{
	// NH format defines trees with arbitrary numbers of nodes, as such:
	// (A:W.0,(B:Y.0,C:Z.0),D:X.0)
	// We're going to assume they have actual binary trees that they're
	// feeding us - otherwise we'll just barf somehow

	scale = sca;
	right = 0;
	left = 0;
	next = 0;
	parent = par;
	tail = 0;
	tflag = 0;
	record = 0;
	ref = 0;
	sis = 0;
	mod = 1;
	m24 =0;
	m33 = 0;
	m1 = 0;
	m2 = 0;
	m3 = 0;
	// These were previously left indeterminate by this constructor.
	m33c = 0;
	m31 = 0;
	len = 0;
	rprob = 0;
	nprob = 0;
	ins_state = 0;
	p1 = p2 = p3 = 0;
	gm = (short int*)calloc(GLET*GLET*GLET,sizeof(short int));
	gv = (double *)calloc(GLET*GLET*GLET,sizeof(double));
	gn = (int *)calloc(GLET*GLET*GLET,sizeof(int));
	gl = (double *)calloc(GLET*GLET*GLET,sizeof(double));

	if (!tree) {
		cerr<<"Malformed branch! (No tree)\n";
		exit(1);
	}
	char *p;
	if (tree[0] != '(') {
		for (p=tree;*p;p++) {
			if (*p == '(' || *p == ')' || *p == ':')
			{
				cerr<<"Malformed branch! (improper name)\n";
				exit(1);
			}
		}
		species = tree;
		return;
	}

	// LEFT NODE //
	// Find the ':' at nesting depth 0 that ends the left subtree, stopping at
	// the end of the string (the old `while (p)` never did).
	int depth = 0;
	p = tree+1;
	while (*p) {
		if (!depth && *p == ':') break;
		if (*p == '(') depth++;
		if (*p == ')') depth--;
		if (depth < 0) {
			cerr<<"Malformed branch! (out of depth in left node)\n";
			cerr<<tree;
			cerr<<".\n";
			exit(1);
		}
		p++;
	}
	if (!*p) {
		cerr<<"Malformed branch! (no colon in left node)\n";
		exit(1);
	}
	char buf[1024];
	copy_tree_span(buf,sizeof(buf),tree+1,p - tree - 1);
	left = new Iterator(buf,this,sca);
	// Left branch length runs from after the ':' up to the next ','.
	char *comma = p;
	while (*comma && *comma != ',') comma++;
	if (!*comma) {
		cerr<<"Malformed branch! (no comma in left node)\n";
		exit(1);
	}
	copy_tree_span(buf,sizeof(buf),p+1,comma - p - 1);
	left->len = atof(buf);
	left->scale = sca;
	// END LEFT NODE //

	// RIGHT NODE //
	p = comma+1;
	while (*p) {
		if (!depth && *p == ':') break;
		if (*p == '(') depth++;
		if (*p == ')') depth--;
		if (depth < 0) {
			cerr<<"Malformed branch! (out of depth in right node)\n";
			cerr<<tree;
			cerr<<".\n";
			exit(1);
		}
		p++;
	}
	if (!*p) {
		cerr<<"Malformed branch! (no colon in right node)\n";
		exit(1);
	}
	copy_tree_span(buf,sizeof(buf),comma+1,p - comma - 1);
	right = new Iterator(buf,this,sca);
	// Right branch length runs from after the ':' up to the closing ')'.
	char *paren = p;
	while (*paren && *paren != ')') paren++;
	if (!*paren) {
		cerr<<"Malformed branch! (no paren in right node)\n";
		exit(1);
	}
	copy_tree_span(buf,sizeof(buf),p+1,paren - p - 1);
	right->len = atof(buf);
	right->scale = sca;
	// END RIGHT NODE //
}

void Iterator::reset()
{
 	memset(gm,0,sizeof(short int)*GLET);
 	memset(gl,0,sizeof(double)*GLET);
 	memset(gn,0,sizeof(int)*GLET);
	rprob = 0;
	tflag = 0;
	if (right) right->reset();
	if (left) left->reset();
}

void Iterator::set_base_state(int a, int b, int c)
{
	// Store the three letter indices of this node's base state.
	b1 = a;
	b2 = b;
	b3 = c;
}

double Iterator::substitute(int sub1, int sub2, int let)
{
	if (let == LET) return (*m1)[sub1][sub2];
	if (let == LET*LET) return (*m2)[sub1][sub2];
	if (let == LET*LET*LET) return (*m3)[sub1][sub2];
	return 1.0;
}

double Iterator::del(int d, int isize)
{
	if (d != 0) {
		// Weight of a class-d deletion for an element of size isize; class 4
		// uses twice its class number as the length term.
		const int extent = (d == 4) ? 2*d : d;
		return (isize+extent-1)*deletion[d];
	}
	// d == 0: one minus the weighted class 1-4 deletion mass.
	// NOTE(review): class 4 is weighted with (isize+7.47-1) here but with
	// (isize+8-1) above, so the five values do not sum to 1. Confirm which
	// factor is intended.
	return 1 - ((isize+1-1)*deletion[1]+(isize+2-1)*deletion[2]+(isize+3-1)*deletion[3]+(isize+7.47-1)*deletion[4]);
}

double Iterator::ins(int i)
{
	// insertion[] is set in build_matrices: entries 1..4 are time-scaled
	// rates and entry 0 is their complement.
	const double weight = insertion[i];
	return weight;
}

double Iterator::tree_len()
{
	double l = len;
	if (right) l += right->tree_len();
	if (left) l += left->tree_len();
	return l;
}

double Iterator::longest_branch()
{
	// Largest single branch length in this subtree. A missing right child
	// contributes 0; a missing left child reuses the right child's value.
	double best = len;
	const double from_right = right ? right->longest_branch() : 0;
	if (from_right > best) best = from_right;
	const double from_left = left ? left->longest_branch() : from_right;
	if (from_left > best) best = from_left;
	return best;
}

double Iterator::find_node(string s)
{
	// return length to node s
	double l;
	if (species == s) {
		return len;
	}
	if (right && (l = right->find_node(s))) {
		return l+len;
	}
	if (left && (l = left->find_node(s))) {
		return l+len;
	}
	return 0;
}

// Path length between leaves s1 and s2 within this subtree. The search
// descends to the lowest node whose two children separate s1 and s2.
// find_node returns 0 for "not found", so a leaf at distance 0 is treated as
// absent. Nodes without two children return 0.
double Iterator::distance(string s1, string s2)
{
	if (!right || !left) return 0;
	double dist1 = right->find_node(s1);
	if (dist1) {
		// s1 is under the right child.
		if (right->find_node(s2)) return right->distance(s1,s2);
		return dist1 + left->find_node(s2);
	}
	// s1 is not under the right child, so it must be under the left one.
	if (left->find_node(s2)) return left->distance(s1,s2);
	// s1 on the left and s2 on the right. The old code forgot to look up
	// s1's depth here and returned only s2's.
	dist1 = left->find_node(s1);
	return dist1 + right->find_node(s2);
}

Matrix<double> *interpolate_matrix(double time, Matrix<double> **input)
{
	// Linearly interpolate between the two tabulated matrices that bracket
	// time/interval on a uniform grid; the caller owns the returned matrix.
	const double pos = time/interval;
	const int low = (int)pos;
	const int high = (int)(pos+1);
	return new Matrix<double>(input[high]->interpolate(*input[low], pos-low));
}

// Decade layout of the interpolation grid built by init_interpolation.
// Decade k starts at time scales[k]; scalec[k] is its reciprocal.
double scales[6] = {0.01,0.1,1,10,100,1000};
double scalec[6] = {100, 10, 1, 0.1, 0.01, 0.001};
// df[k] is the decade index k. lf[k] = 9*k - 1 is the offset such that a
// time t in decade k sits at grid index (t*scalec[k] + lf[k]) * scalefd + 1,
// which matches init_interpolation's spacing. lf[0] was 0, which shifted
// every lookup in [0.01, 0.1) by scalef/9 grid points and made the mapping
// jump backwards at t = 0.1.
double df[6] = {0, 1.0, 2.0, 3.0, 4.0, 5.0};
double lf[6] = {-1.0, 8.0, 17.0, 26.0, 35.0, 44.0};
// Grid points per decade, and scalef/9.0; both are set by init_interpolation.
int scalef;
double scalefd;
// log10(2) and ln(10).
double lratio = 0.301029995663981;
double rratio = 2.30258509299405;
// Integer-part sink for modf in interpolate_probability.
double junk;
// x87-only conversion of a double to int via `fistp` (uses the current FPU
// rounding mode, presumably round-to-nearest). Not portable to non-x86 targets.
#define FLOAT_TO_INT(in,out)  \
                    __asm__ __volatile__ ("fistpl %0" : "=m" (out) : "t" (in) : "st") ;

double interpolate_probability(double time, Matrix<double> **input, int i, int j)
{
	
	if (!time) return input[0]->raw(i,j);
	int low,high;
	double c,x;
	if (time < 0.01) {
		c = time/0.01;
		low = 0;
		high = 1;
		x=c;
	} else {
	int f;
	double cf=floor(log(time)/rratio)+2.0;
	
	for (f=0;f<6;f++) {
		if (df[f] == cf) { break; }
	}
	c = (time*scalec[f] + lf[f]) * scalefd + 1.0;
	x = modf(c,&junk);
	
	FLOAT_TO_INT(c,low);
	//low = (int)c;
	if (x > 0.5) low--;
	high = low+1;
	}

	// use raw method for this, no operator bullshit.
	double s = input[low]->raw(i,j);
	
	return (input[high]->raw(i,j)-s)*(x) + s;
}

// Fills it->m33 and it->m24 for every flanking context (a1,a3 -> b1,b3) of
// a central position, using transition function p (called with size 3):
//   m33[a1+i*GLET+a3*GLET^2][b1+j*GLET+b3*GLET^2] = p(a1,i,a3,b1,j,b3,3,it)
//   m24[i+j*GLET][ctx], with ctx = a1+a3*GLET+b1*GLET^2+b3*GLET^3, is:
//     - i,j < LET : the m33 entry divided by its sum over the LET target
//                   letters (a zero sum is replaced by 1e-20);
//     - i >= G1   : it->ins(i-G1+1) for j < LET, and 1 for j >= G1;
//     - i < LET, j >= G1 : it->del(j-G1+1,3).
// Exits with status 1 if a normalised entry is negative.
void compute_matrix(Iterator *it, pfunc p) {
	Matrix<double> &m33 = *it->m33;
	Matrix<double> &m24 = *it->m24;
	int a1,a3,b1,b3,i,j;
	for (a1=0;a1<GLET;a1++)
	for (a3=0;a3<GLET;a3++)
	for (b1=0;b1<GLET;b1++)
	for (b3=0;b3<GLET;b3++) {
		// m24 column shared by every (i,j) pair in this flanking context.
		const int ctx = a1+a3*GLET+b1*GLET*GLET+b3*GLET*GLET*GLET;
		for (i=0;i<LET;i++) {
			const int row = a1+i*GLET+a3*GLET*GLET;
			double b = 0;
			for (j=0;j<LET;j++) {
				b += (m33[row][b1+j*GLET+b3*GLET*GLET] = p(a1,i,a3,b1,j,b3,3,it));
			}
			if (!b) // uh-oh - just give it an infinitesimal value
				b = 1e-20;
			for (j=0;j<LET;j++) {
				m24[i+j*GLET][ctx] = m33[row][b1+j*GLET+b3*GLET*GLET]/b;
				if (m24[i+j*GLET][ctx] < 0) {
					// Report the offending normalised value and fail with a
					// non-zero status. The old code exited with 0 and printed
					// an m3 entry whose index could fall outside m3.
					cerr<<"Less than zero!"<<let[a1]<<let[i]<<let[a3]<<" "<<let[b1]<<let[j]<<let[b3]<<m24[i+j*GLET][ctx]<<"\n";
					exit(1);
				}
			}
		}

		// Gap source letters.
		for (i=G1;i<GLET;i++) {
			const int row = a1+i*GLET+a3*GLET*GLET;
			for (j=0;j<LET;j++) {
				m33[row][b1+j*GLET+b3*GLET*GLET] = p(a1,i,a3,b1,j,b3,3,it);
				m24[i+j*GLET][ctx] = it->ins(i-G1+1);
			}
			for (j=G1;j<GLET;j++) {
				m33[row][b1+j*GLET+b3*GLET*GLET] = p(a1,i,a3,b1,j,b3,3,it);
				m24[i+j*GLET][ctx] = 1;
			}
		}
		// Letter source, gap target.
		for (i=0;i<LET;i++) {
			const int row = a1+i*GLET+a3*GLET*GLET;
			for (j=G1;j<GLET;j++) {
				m33[row][b1+j*GLET+b3*GLET*GLET] = p(a1,i,a3,b1,j,b3,3,it);
				m24[i+j*GLET][ctx] = it->del(j-G1+1,3);
			}
		}
	}
}


void init_interpolation()
{
	if (!interval) return;
	
	int a, b, i, j, num = (int)(footprint*0.3)+1;
	scalef = (num-1)/5;
	scalefd = scalef/9.0;
	Iterator *it = new Iterator();
	if (debug) {
		cerr<<"Initializing matrices for interpolation... ";
		if (debug >= 2) cerr<<"("<<num<<" matrices)";
		cerr<<"\n";
	}
	pfunc p = _prob;
	// okay, there's five phases:
	// 0.01-0.1; 0.1-1.0; 1.0-10; 10-100; 100-1000;
	// we cut num into five intervals, in each of which the above factor applies.
	num = scalef * 5+2;
	
	for (a = 0; a < num; a++) {
		// we want to be able to map until 0.1-fold of, let's say, 0.5*scale.
		// this is a time of 0.05. 
		// compute m3 matrix
		if (debug >=2) cerr<<".";
		double time = 0;
		if (a) {
			int t = (a-1)/scalef;
			time = (a-1) * 9.0/scalef * scales[t] - (t*9-1)*scales[t];
		}
		it->m3 = new Matrix<double>(mo.pade(time));
		memuse += LET*LET*LET*LET*LET*LET*sizeof(double);
		mihash[a] = it->m3;
		
		int a1,a2,a3,b1,b2,b3;
		it->m24 = new Matrix<double>();
		it->m33 = new Matrix<double>();
		it->m24->SetSize(GLET*GLET,GLET*GLET*GLET*GLET);
		it->m33->SetSize(GLET*GLET*GLET,GLET*GLET*GLET);
		memuse += GLET*GLET*GLET*GLET*GLET*GLET*sizeof(double)*2;
			
		if (it->m1) delete it->m1;
		if (it->m2) delete it->m2;
		it->m1 = new Matrix<double>();
		it->m2 = new Matrix<double>();
		it->m1->SetSize(LET,LET);
		it->m2->SetSize(LET*LET,LET*LET);

		for (i=0;i<LET*LET;i++)
			for (j=0;j<LET*LET;j++) {
			double q=0;
			int n,m;
			for (n=0;n<LET;n++)
				for (m=0;m<LET;m++) {
				q += f3[i+n*LET*LET]*(*it->m3)[i+n*LET*LET][j+m*LET*LET];
				}
				q /= f2[i];
				(*it->m2)[i][j] = q;
			}
			for (i=0;i<LET;i++)
				for (j=0;j<LET;j++) {
				double q=0;
				int n,m;
				for (n=0;n<LET;n++)
					for (m=0;m<LET;m++) {
					q += f2[i+n*LET]*(*it->m2)[i+n*LET][j+m*LET];
					}
					q /= f1[i];
					(*it->m1)[i][j] = q;
				}
		// build indels while you're at it
		it->insertion = in * (time);
		it->deletion = de * (time);
		it->insertion[0] = 1 - (it->insertion[1]+it->insertion[2]+it->insertion[3]+it->insertion[4]);
		
		compute_matrix(it,p);

		m24ihash[a] = it->m24;
	}
	it->m3 = 0;
	it->m24 = 0;
	delete it;
}

void Iterator::build_matrices(double time, pfunc p)
{
	int i,j;

	timer++;
	char buf[500];
	FILE *fp;
	Matrix<double> *t;
	actual_time = time;
	int s = (int)(time*10+0.5);
	if (rebuilding) return;
	if (mf[s]) {
		t = mhash[s];
		m_uses[s]++;
		m_lastuse[s] = timer;
	} else {
		// calculate it, then cache it.
		t = new Matrix<double>(mo.pade(time));
		memuse += GLET*GLET*GLET*GLET*GLET*GLET*sizeof(double);
		mhash[s] = t;
		t->hashed=1;
		m_uses[s]++;
		m_lastuse[s] = timer;
		mfcount++;
	}
	
	if (m2) delete m2;
	if (m1) delete m1;

	if (m33 && !m33->hashed) delete m33;
	if (m24 && !m24->hashed) delete m24;
	if (m3 && !m3->hashed) delete m3;
	
	m3 = t;
	m2 = new Matrix<double>();
	m2->SetSize(LET*LET,LET*LET);
	m1 = new Matrix<double>();
	m1->SetSize(LET,LET);
	for (i=0;i<LET*LET;i++)
		for (j=0;j<LET*LET;j++) {
			double q=0;
			int n,m;
			for (n=0;n<LET;n++)
			for (m=0;m<LET;m++) {
				q += f3[i+n*LET*LET]*(*m3)[i+n*LET*LET][j+m*LET*LET];
			}
			q /= f2[i];
			(*m2)[i][j] = q;
		}
	for (i=0;i<LET;i++)
		for (j=0;j<LET;j++) {
			double q=0;
			int n,m;
			for (n=0;n<LET;n++)
			for (m=0;m<LET;m++) {
				q += f2[i+n*LET]*(*m2)[i+n*LET][j+m*LET];
			}
			q /= f1[i];
			(*m1)[i][j] = q;
		}
	// build indels while you're at it
	insertion = in * (time);
	deletion = de * (time);
	insertion[0] = 1 - (insertion[1]+insertion[2]+insertion[3]+insertion[4]);
	//deletion[0] = 1 - ((3+1-1)*deletion[1]+(3+2-1)*deletion[2]+(3+3-1)*deletion[3]+(3+4.84-1)*deletion[4]);
	// two by four!

	if (mf[s]) {
		m24 = m24hash[s];
		m33 = m33hash[s];
		m33c = m33chash[s];
		m31 = m31hash[s];
	} else {
		int a1,a2,a3,b1,b2,b3;
		m24 = new Matrix<double>();
		m24->SetSize(GLET*GLET,GLET*GLET*GLET*GLET);
		m33 = new Matrix<double>();
		m33->SetSize(GLET*GLET*GLET,GLET*GLET*GLET);
		memuse += GLET*GLET*GLET*GLET*GLET*GLET*sizeof(double)*2;
		m33c = new Matrix<double>();
		m33c->SetSize(LET*LET*LET,LET*LET*LET);
		m31 = new Matrix<double>();
		m31->SetSize(LET*LET*LET,1);
		
		// FILL m24 and m33
		compute_matrix(this,p);
		// end fill m24 and m33

		// fill m33c and m31
		int size = LET;
		for (a1=0;a1<size;a1++)
		for (a2=0;a2<size;a2++)
		for (a3=0;a3<size;a3++) {
			double total = 0;
			for (b3=0;b3<size;b3++)
			for (b2=0;b2<size;b2++)
			for (b1=0;b1<size;b1++) {
				(*m33c)[IN2(a1,a2,a3)][IN2(b1,b2,b3)] = total + (*m3)[IN2(a1,a2,a3)][IN2(b1,b2,b3)];
				total = (*m33c)[IN2(a1,a2,a3)][IN2(b1,b2,b3)];
			}
			(*m31)[IN2(a1,a2,a3)][0] = total;
		}
		// end fill m33c and m31

		m33chash[s] = m33c;
		m31hash[s] = m31;
		m31->hashed=1;
		m33c->hashed=1;
	
		m24hash[s] = m24;
		m33hash[s] = m33;
		m24->hashed=1;
		m33->hashed=1;

		mf[s] = 1;
		m24count++;
		m33count++;
	}
}

void Iterator::record_state(int a, int b, int c, double p, double n)
{
	if (!record) record = new Stack;
	record->one = a;
	record->two = b;
	record->thr = c;
	rprob = p;
	nprob = n;
}

void Iterator::print_state(int depth)
{
	if (right) right->print_state(depth+1);
	if (left) left->print_state(depth+1);
	if (!record) return;

	int i;
	for (i=0;i<depth;i++) { printf(" ");}
	if (is_leaf()) printf("%s ",species.c_str());
	printf("%c%c%c (%d %d %d) %e %e\n", let[record->one], let[record->two], let[record->thr] , record->one, record->two, record->thr, rprob, len);
}

void Iterator::print_root()
{
	// Print this node's recorded triplet as letters, if it has a record.
	if (record) {
		printf("%c%c%c\n", let[record->one], let[record->two], let[record->thr]);
	}
}

// Top-down traceback: each child's record->one is set from the child's gn
// table, indexed by this node's recorded triplet, and the child then recurses.
// Only `one` is written; two/thr are left unchanged.
// NOTE(review): the right child is always processed, even when it is a
// leaf, while the left child is skipped when it is a leaf. Confirm this
// asymmetry is intended. Also assumes the child's record is already allocated.
void Iterator::make_record()
{
	if (!record) return;
	if (right) {
		right->record->one = right->gn[IN(record->one,record->two,record->thr)];
		right->make_record();
	}
	if (left && !left->is_leaf()) {
		
		left->record->one = left->gn[IN(record->one,record->two,record->thr)];
		left->make_record();
	}
}

void Iterator::flash_record(int l)
{
	// Save record->one into p1, p2 or p3 (chosen by l = 0, 1 or 2) on every
	// node reachable through nodes that have a record.
	if (!record) return;
	switch (l) {
	case 0: p1 = record->one; break;
	case 1: p2 = record->one; break;
	case 2: p3 = record->one; break;
	default: break;
	}
	if (right) right->flash_record(l);
	if (left) left->flash_record(l);
}

void Iterator::record_from_flash()
{
	if (!record) record = new Stack();
	record->one = p1;
	record->two = p2;
	record->thr = p3;
	if (right) {
		right->record_from_flash();
	}
	if (left) {
		left->record_from_flash();
	}
}


// alternate score - rate computation
void Iterator::rate_branch_score(double &c, vector<double> &mr1, vector<double> &mr2)
{
	int sub = 0;
	if (parent && parent->record->two != record->two) sub = 1;
	if (sub) {
		mr1.push_back(tree_prob(this,this->record) - 1);
		mr2.push_back(tree_prob(parent->record->one,parent->record->two,parent->record->thr,
			      parent->record->one,parent->record->two,parent->record->thr,this) - 1);
	} else {
		c += tree_prob(this,parent->record) - 1;
	}
	if (right) right->rate_branch_score(c,mr1,mr2);
	if (left) left->rate_branch_score(c,mr1,mr2);
}

double exp_sum(vector<double> total, vector<double> signs, double x)
{
	// Returns -sum_k signs[k] * exp(x * total[k]); golden_max maximises this
	// over x.
	double acc = 0;
	for (vector<double>::size_type k = 0; k < total.size(); ++k)
		acc -= signs[k] * exp(x * total[k]);
	return acc;
}

double Iterator::golden_max(vector<double> total, vector<double> signs)
{
	// Golden-section search on [0, 40], starting from the interior guess 24.7,
	// for the x that maximises exp_sum(total, signs, x). Stops once the two
	// inner probes are within 0.01 and returns the better of them.
	const double lo = 0.0, guess = 24.7, hi = 40;
	const double R = 0.61803399, C = 1.0-R;
	const double tol = 0.01;
	double x0 = lo, x3 = hi;

	// Put the new probe in the larger of the two sub-intervals around the guess.
	const bool right_wider = fabs(hi-guess) > fabs(guess-lo);
	double x1 = right_wider ? guess : guess-C*(guess-lo);
	double x2 = right_wider ? guess+C*(hi-guess) : guess;

	double f1 = exp_sum(total,signs,x1);
	double f2 = exp_sum(total,signs,x2);
	while (fabs(x2-x1) > tol) {
		// Keep the matrix cache bounded during long searches when no
		// interpolation tables are in use.
		if (!interval && (m24count > 100 || mfcount > 100)) clearhash();
		if (f2 > f1) {
			shift3(x0,x1,x2,R*x2+C*x3);
			shift2(f1,f2,exp_sum(total,signs,x2));
		} else {
			shift3(x3,x2,x1,R*x1+C*x0);
			shift2(f2,f1,exp_sum(total,signs,x1));
		}
	}
	return (f1 > f2) ? x1 : x2;
}

double Iterator::rate_score()
{
	// Collect per-branch terms from both subtrees. c is the sum over
	// conserved branches; mr1/mr2 are the alternative terms for each branch
	// whose middle letter changed.
	double c = 0;
	vector<double> mr1, mr2, total, signs;
	if (right) right->rate_branch_score(c,mr1,mr2);
	if (left) left->rate_branch_score(c,mr1,mr2);
	if (mr1.empty()) return 0;

	// Global sign: -1, flipped once per branch where mr2 exceeds mr1.
	const int n = mr1.size();
	int is = -1;
	for (int k = 0; k < n; k++)
		if (mr2[k] > mr1[k]) is = -is;

	// One exponent per subset of changed branches: bit set -> mr1 term,
	// bit clear -> mr2 term, which also flips the sign.
	const int combos = 1 << n;
	for (int mask = 0; mask < combos; mask++) {
		double s = 0;
		int sign = 1;
		for (int bit = 0; bit < n; bit++) {
			if (IS_SET(mask,bit)) {
				s += mr1[bit];
			} else {
				s += mr2[bit];
				sign = -sign;
			}
		}
		total.push_back(s + c);
		signs.push_back(sign*is);
	}
	// Maximise the resulting sum of exponentials by golden-section search.
	return golden_max(total,signs);
}

double Iterator::cons_score(int l)
{
	double v = 0;
	int sub = 0;
	if (parent) {
		if (parent->record->two != record->two && l) {
			v += tree_prob(parent->record,parent->record,this);
		} else if (!l) {
			v += 1-tree_prob(this,parent->record);
		}
	}
	if (right) v += right->cons_score(l);
	if (left) v += left->cons_score(l);
	return v;
}

TreeP Iterator::tree_score(int black, int center, double (*f)(Iterator*, Stack*)) 
{
	// Accumulates f(node, parent record) over the branches of this subtree
	// (product in p.prob, number of factors in p.count).
	// A branch counts as substituted if the middle letter changed (center
	// != 0) or if any of the three letters changed (center == 0). black != 0
	// scores only substituted branches; black == 0 scores only unchanged ones
	// and traces them to stderr at debug level 3+ when not computing.
	TreeP p;
	if (parent) {
		Stack *up = parent->record;
		bool sub = up->two != record->two;
		if (!center)
			sub = up->one != record->one || sub || up->thr != record->thr;
		if (sub == (black != 0)) {
			const double v = f(this,parent->record);
			p.prob *= v;
			p.count++;
			if (!black && debug>=3 && !computing)
				fprintf(stderr, "%c%c%c - %c%c%c %e %e\n",
						let[up->one],
						let[up->two],
						let[up->thr],
						let[record->one],
						let[record->two],
						let[record->thr],
						v, len);
		}
	}
	if (right) p += right->tree_score(black,center,f);
	if (left) p += left->tree_score(black,center,f);
	return p;
}

double Iterator::record_likelihood(double (*f)(Iterator*, Stack*))
{
	// Product of f(child, this node's record) over the existing children.
	double p = 1;
	Iterator *kids[2] = { right, left };
	for (int k = 0; k < 2; k++)
		if (kids[k]) p *= f(kids[k],record);
	return p;
}

double Iterator::tree_likelihood(double (*f)(Iterator*, Stack*))
{
	// Product of f(node, parent record) over every non-root node of the subtree.
	double p = parent ? f(this,parent->record) : 1;
	if (right) p *= right->tree_likelihood(f);
	if (left) p *= left->tree_likelihood(f);
	return p;
}

int Iterator::count_subs(int center)
{
	int s = 0;
	if (center) {
	if (parent && (parent->record->two != record->two)) s++;
	} else {
	if (parent && (parent->record->one != record->one)) s++;
	if (parent && (parent->record->two != record->two)) s++;
	if (parent && (parent->record->thr != record->thr)) s++;
	}
	if (right) s += right->count_subs(center);
	if (left) s += left->count_subs(center);
	return s;
}

int Iterator::count_nodes()
{
	// Number of nodes in this subtree, including this one.
	const int below_right = right ? right->count_nodes() : 0;
	const int below_left = left ? left->count_nodes() : 0;
	return below_right + below_left + 1;
}

int Iterator::count_empty()
{
	// Number of nodes in this subtree whose recorded middle state is above G0.
	const int here = (record && record->two > G0) ? 1 : 0;
	const int below_right = right ? right->count_empty() : 0;
	const int below_left = left ? left->count_empty() : 0;
	return here + below_right + below_left;
}

// Deletion: a letter state (< G1) becomes a gap state (>= G1);
// insertion is the reverse.
#define ISDEL(a,b) (a < G1 && b >= G1)
#define ISINS(a,b) (a >= G1 && b < G1)

int Iterator::count_gaps(int in)
{
	// Counts branches in this subtree whose middle state is an insertion
	// (in != 0) or a deletion (in == 0) relative to the parent.
	int s = 0;
	if (parent) {
		const int from = parent->record->two;
		const int to = record->two;
		s += in ? ISINS(from,to) : ISDEL(from,to);
	}
	if (right) s += right->count_gaps(in);
	if (left) s += left->count_gaps(in);
	return s;
}

int Iterator::flag(string s)
{
	// Sets tflag on node `s` and on every node along the path down to it.
	// The search stops at the first hit, checking this node, then the right
	// subtree, then the left. Returns 1 if `s` was found.
	const bool hit = species == s
		|| (right && right->flag(s))
		|| (left && left->flag(s));
	if (hit) tflag = 1;
	return hit ? 1 : 0;
}

void Iterator::reset_flags()
{
	tflag = 0;
	if (right) right->reset_flags();
	if (left) left->reset_flags();
}

void Iterator::reflag()
{
	// Post-order: restore each leaf's base state b1..b3 from its saved g1..g3.
	if (right) right->reflag();
	if (left) left->reflag();
	if (!is_leaf()) return;
	b1 = g1;
	b2 = g2;
	b3 = g3;
}

void Iterator::linked_list(Iterator **current)
{
	// Post-order threading of `next` pointers: each visited node is appended
	// after *current, which then advances to it.
	if (right) right->linked_list(current);
	if (left) left->linked_list(current);
	// don't link to the root, so the list has an end.
	if (!parent) return;
	(*current)->next = this;
	*current = this;
}

// Builds a new tree that is the minimal subtree spanning all ACTIVE
// (tflag-set) nodes.
//   a - 0 while the new root is still being searched for; non-zero once
//       below it.
//   l - branch length collected from collapsed single-child ancestors, added
//       to the next node that is kept.
//   p - transition function passed to build_matrices for every kept non-root
//       node (time = len / scale).
// The new root gets len 0, no matrices, and a `next` list built by
// linked_list. Returns 0 when nothing under this node is active.
// NOTE(review): callers assign ->parent on the result without a null check;
// this relies on recursing only into ACTIVE children.
Iterator *Iterator::flag_tree(int a, double l, pfunc p)
{
	// this builds a NEW tree which is the minimal tree containing all flagged species.
	if (!a) {
		// am I the root?
		if (ACTIVE(right) && ACTIVE(left)) {
			// I am the root.
			Iterator *root = new Iterator();
			root->len = 0;
			root->scale = scale;
			root->ref = ref;
			root->sis = sis;
			root->right = right->flag_tree(1,0,p);
			root->right->parent = root;
			root->left = left->flag_tree(1,0,p);
			root->left->parent = root;

			// build the linked list for this tree.
			Iterator *curr = root;
			root->linked_list(&curr);
			return root;
		}
		// Only one side is active: the new root lies further down. Lengths
		// above the new root are discarded (l stays 0).
		if (ACTIVE(right)) return right->flag_tree(0,0,p);
		if (ACTIVE(left)) return left->flag_tree(0,0,p);
		return 0;
	}
	// I am NOT the root! Find out if I'm a fork.
	// Single active child: collapse this node, passing its length down.
	if (ACTIVE(right) && !ACTIVE(left)) {
		return right->flag_tree(1,len+l,p);
	}
	if (ACTIVE(left) && !ACTIVE(right)) {
		
		return left->flag_tree(1,len+l,p);
	}
	if (!left && !right) {
		// wups, I'm a leaf! I BETTER return something.
		// The copy keeps the leaf's base state in both b* and the saved g* copy.
		Iterator *leaf = new Iterator();
		leaf->len = len + l;
		leaf->scale = scale;
		leaf->ref = ref;
		leaf->sis = sis;
		leaf->species = species;
		leaf->g1 = b1;
		leaf->b1 = b1;
		leaf->g2 = b2;
		leaf->b2 = b2;
		leaf->g3 = b3;
		leaf->b3 = b3;
		leaf->build_matrices(leaf->len / leaf->scale,p);
		return leaf;
	}
	// Both children active: keep this node as a fork.
	if (ACTIVE(right) && ACTIVE(left)) {
		Iterator *node = new Iterator();
		node->len = len + l;
		node->scale = scale;
		node->ref = ref;
		node->sis = sis;
		node->right = right->flag_tree(1,0,p);
		node->right->parent = node;
		node->left = left->flag_tree(1,0,p);
		node->left->parent = node;
		node->build_matrices(node->len / node->scale,p);
		return node;
	}
	return 0;
}

void Iterator::rebuild(double s, pfunc p)
{
	scale = s;
	build_matrices(len / scale, p);
	if (right) right->rebuild(scale,p);
	if (left) left->rebuild(scale,p);
}

// Deep-copies the tree shape together with each node's len, scale and record
// contents. Other per-node data (species, flags, matrices) is not copied.
Iterator *Iterator::clone_tree() 
{
	Iterator *m = new Iterator();
	m->len = len;
	m->scale = scale;
	if (record) {
		m->record = new Stack;
		*m->record = *record;
	}
	// The clone keeps the scratch buffers (gm/gv/gn/gl) its constructor
	// allocated. ~Iterator frees them, so freeing them here too (as before)
	// caused a double free when the clone was destroyed.
	if (right) {
		m->right = right->clone_tree();
		m->right->parent = m;
	}
	if (left) {
		m->left = left->clone_tree();
		m->left->parent = m;
	}
	return m;
}

void Iterator::get_nucs(string s)
{
	// Pre-order search: print the recorded triplet of every node named s.
	if (s == species) print_root();
	if (right) right->get_nucs(s);
	if (left) left->get_nucs(s);
}

// Marks the path to the reference species `s`.
// - unflag != 0: leaf `s`, and every node with a child reporting a set ref,
//   get ref = unflag; the return value is this node's ref.
// - unflag == 0: every call returns 1, so ref is cleared across the subtree.
// The recursive calls now pass `unflag` through. They used to fall back to
// the default argument, so an unflag request never reached the children.
int Iterator::flag_ref(string s, int unflag)
{
	if (species == s) ref = unflag;
	if (right && right->flag_ref(s, unflag)) ref = unflag;
	if (left && left->flag_ref(s, unflag)) ref = unflag;
	return unflag ? ref : 1;
}

int Iterator::flag_sis(string s)
{
	int r = 0;
	if (species == s) { if (debug) cerr<<"Flagging sister "<<s<<"\n";return 1; }
	if (right) {
		if (right->flag_sis(s)) {
			if (ref) {
				sis = 1;
				return 0;
			}
			else r = 1;
		}
	}
	if (left) {
		if (left->flag_sis(s)) {
			if (ref) {
				sis = 1;
				return 0;
			}
			else r = 1;
		}
	}
	return r;
}

// Maps a state character to its index:
// - A/G/C/T (any case) -> 0..3
// - gap classes '1'..'4' -> 4..7
// - N/n -> 8
// Every other character (e.g. ' ' or '-') also maps to 8. The old lookup
// table left such entries uninitialised and indexed it with a plain,
// possibly negative, char.
static int nuc_state_index(char ch)
{
	switch (ch) {
	case 'A': case 'a': return 0;
	case 'G': case 'g': return 1;
	case 'C': case 'c': return 2;
	case 'T': case 't': return 3;
	case '1': return 4;
	case '2': return 5;
	case '3': return 6;
	case '4': return 7;
	default: return 8;
	}
}

// Sets the base state of every node named `s` from the first three
// characters of `nuc`. A match is marked active unless nuc[1] is a space
// (or missing). Ancestors of an activated node also get tflag = 1. Returns
// the number of activated matches in this subtree.
int Iterator::set_nuc(string s, string nuc)
{
	int c = 0;
	if (species == s) {
		// Missing characters are treated like ' '.
		const char c0 = nuc.size() > 0 ? nuc[0] : ' ';
		const char c1 = nuc.size() > 1 ? nuc[1] : ' ';
		const char c2 = nuc.size() > 2 ? nuc[2] : ' ';
		set_base_state(nuc_state_index(c0), nuc_state_index(c1), nuc_state_index(c2));
		if (c1 == ' ') {
			tflag = 0;
			return 0;
		}
		tflag = 1;
		return 1;
	}
	if (right) c += right->set_nuc(s,nuc);
	if (left) c += left->set_nuc(s,nuc);
	if (c) tflag = 1;
	return c;
}

void Iterator::print_nucs()
{
	// Write each leaf's base triplet and species name to stderr, right
	// subtree before left.
	if (!right && !left) {
		cerr<<let[b1]<<let[b2]<<let[b3]<<" "<<species<<" ";
		return;
	}
	if (right) right->print_nucs();
	if (left) left->print_nucs();
}

void Iterator::print_rec_nucs()
{
	// Write each leaf's recorded triplet and species name to stdout, right
	// subtree before left.
	if (!right && !left) {
		cout<<":"<<let[record->one]<<let[record->two]<<let[record->thr]<<" "<<species<<" ";
		return;
	}
	if (right) right->print_rec_nucs();
	if (left) left->print_rec_nucs();
}

vector<string> Iterator::keystring()
{
	// Non-empty species names in this subtree, in pre-order (this node,
	// then the right subtree, then the left).
	vector<string> names;
	if (!species.empty()) names.push_back(species);
	if (right) {
		const vector<string> sub = right->keystring();
		names.insert(names.end(), sub.begin(), sub.end());
	}
	if (left) {
		const vector<string> sub = left->keystring();
		names.insert(names.end(), sub.begin(), sub.end());
	}
	return names;
}

string Iterator::key()
{
	// Memoised identity of the subtree: its species names, sorted and
	// concatenated. An empty result is not cached.
	if (!keycache.empty()) return keycache;
	vector<string> names = keystring();
	sort(names.begin(), names.end());
	string joined;
	for (size_t n = 0; n < names.size(); ++n)
		joined += names[n];
	keycache = joined;
	return joined;
}

int Iterator::subs_in(string spec)
{
	// Counts leaves named `spec` whose recorded middle letter differs from
	// their parent's (assumes such leaves have a parent).
	int count = 0;
	if (right) count += right->subs_in(spec);
	if (left) count += left->subs_in(spec);
	const bool leaf = !right && !left;
	if (leaf && species == spec)
		count += (record->two != parent->record->two);
	return count;
}
