/*
 * This program takes two input sequences, one from human and one from mouse,
 * and calculates a conservation score for them.
 */

#include <stdio.h>
#include <strings.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <ctype.h>
#include "matrix.h"
#include "iterator.h"
#include "probability.h"
#include <iostream>

extern Matrix<double> *m24ihash[];
extern Matrix<double> *m33hash[];
extern int computing;
double interpolate_probability(double time, Matrix<double> **input, int i, int j);

// Fixed numeric constants exposed as functions (callers are not visible in
// this chunk).

// high(): a very large positive value, 1e20.
double high()
{
	return 1.0e20;
}

// zero(): exactly 0.
double zero()
{
	return 0.0;
}

// low(): a fixed negative value, -1000.
double low()
{
	return -1000.0;
}

// transition probability of a -> b
double _prob(int i1,int i2, int i3, int j1, int j2, int j3, int isize, Iterator *i)
{
	int sub1 = 0, sub2 = 0, dsub = 0, slet = 1, dlet = 1, ilet = LET*LET*LET, in=0, de=0;
	Row<double> *flet, *filet;
	
	if (!is_legal_pair(i1,i2,i3,j1,j2,j3)) return 0;

	if (i1 < G1 && j1 < G1) { sub1 += i1*slet; sub2 += j1*slet; slet *= LET; ilet /= LET; }
	if (i2 < G1 && j2 < G1) { sub1 += i2*slet; sub2 += j2*slet; slet *= LET; ilet /= LET; }
	if (i3 < G1 && j3 < G1) { sub1 += i3*slet; sub2 += j3*slet; slet *= LET; ilet /= LET; }
	if (i1 < G1 && j1 >= G1) { dsub += i1*dlet; dlet *= LET; de = j1-3; }
	if (i2 < G1 && j2 >= G1) { dsub += i2*dlet; dlet *= LET; de = j2-3; }
	if (i3 < G1 && j3 >= G1) { dsub += i3*dlet; dlet *= LET; de = j3-3; }

	if (i1 >= G1 && j1 < G1) in = i1-3;
	if (i2 >= G1 && j2 < G1) in = i2-3;
	if (i3 >= G1 && j3 < G1) in = i3-3;

	if (in) return i->substitute(sub1,sub2,slet)/ilet*isize*i->ins(in);
	if (de) {
		if (de == 4)
			return i->substitute(sub1,sub2,slet)/ilet*i->del(de,isize);//(isize+4.84-1);
		else
			return i->substitute(sub1,sub2,slet)/ilet*i->del(de,isize);//*(isize+de-1);
	}
	
	double p = i->substitute(sub1,sub2,slet)/ilet*i->del(0)*i->ins(0);

	return p;
}

// Count the positions (0..3) at which triple (i1,i2,i3) differs from
// (j1,j2,j3).
//   - i1 == -1 is treated as a sentinel (presumably "no source triple" --
//     confirm with callers) and yields 0 differences.
//   - If the pair is not legal (is_legal_pair), r is returned unchanged.
// The Iterator argument is not used.
double subst(int i1,int i2,int i3, int j1,int j2,int j3, Iterator *i, double r)
{
	// Test the sentinel BEFORE the legality lookup: is_legal_pair uses i1
	// directly in a table index, so -1 would select a wrong (possibly
	// negative, out-of-bounds) entry.
	if (i1 == -1) return 0;

	if (!is_legal_pair(i1,i2,i3,j1,j2,j3)) return r;

	double sb = 0;
	if (i1 != j1) sb += 1.0;
	if (i2 != j2) sb += 1.0;
	if (i3 != j3) sb += 1.0;

	return sb;
}

// Single-position difference indicator.
// Returns 0 when i == j. Otherwise returns 1, with one exception: i == -1
// returns 0 unless both i and j are gap codes (>= G1).
// The r argument is accepted but not used.
double subst1(int i,int j, double r)
{
	(void)r;

	if (i == j)
		return 0;
	if (i >= G1 && j >= G1)
		return 1;
	return (i == -1) ? 0 : 1;
}

// Decide whether a single triple of symbol codes (i1,i2,i3) is well formed.
// Codes >= G1 are treated as gap symbols. Returns 0 (illegal) when:
//   - two adjacent positions are both gaps but carry different codes;
//   - two adjacent positions are both exactly G1;
//   - the middle code is >= G2 and either a gapped neighbour carries a
//     different code, or neither neighbour is a gap.
// Returns 1 otherwise.
int is_legal_nuc_base(int i1, int i2, int i3)
{
	const int gap1 = (i1 >= G1);
	const int gap2 = (i2 >= G1);
	const int gap3 = (i3 >= G1);

	// Adjacent gap symbols must agree.
	if (gap1 && gap2 && i1 != i2) return 0;
	if (gap2 && gap3 && i2 != i3) return 0;

	// Two G1 symbols may not sit side by side.
	if (i2 == G1 && (i1 == G1 || i3 == G1)) return 0;

	// A middle code of G2 or above needs at least one gapped neighbour, and
	// every gapped neighbour must carry the same code.
	if (i2 >= G2) {
		if (gap1 && i1 != i2) return 0;
		if (gap3 && i3 != i2) return 0;
		if (!gap1 && !gap3) return 0;
	}

	return 1;
}

// INDEL(i,j): classify one aligned position of a pair of triples.
//   i is a gap code (>= G1) and j is not  -> i
//   j is a gap code and i is not          -> -j
//   otherwise                             -> 0
// Every argument is parenthesized so compound expressions expand correctly;
// arguments may still be evaluated more than once, so avoid side effects.
#define INDEL(i,j) (((i)>=G1&&(j)<G1)?(i):(((j)>=G1&&(i)<G1)?(0-(j)):0))

// Decide whether the transition from triple (i1,i2,i3) to (j1,j2,j3) is
// allowed. Returns 1 if legal, 0 otherwise. cache_nuc() stores the result
// for every pair in legalpair; is_legal_pair() reads that cache.
//
// Only one indel event per branch is allowed; mismatched gap codes are
// rejected too. NOTE: this calls is_legal_nuc(), which reads the legalnuc
// table, so legalnuc must already be filled.
int is_legal_pair_base(int i1, int i2, int i3,  int j1, int j2, int j3)
{
	int gap1, gap2, gap3;

	// Both triples must be individually well formed.
	if (!is_legal_nuc(i1,i2,i3) || !is_legal_nuc(j1,j2,j3)) return 0;

	// A position that is a gap on both sides must keep the same gap code.
	if (j1 >= G1 && i1 >= G1 && i1 != j1) return 0;
	if (j2 >= G1 && i2 >= G1 && i2 != j2) return 0;
	if (j3 >= G1 && i3 >= G1 && i3 != j3) return 0;

	// Per-position indel marker: i's gap code if only i is a gap, minus j's
	// gap code if only j is a gap, 0 otherwise.
	gap1 = INDEL(i1,j1);
	gap2 = INDEL(i2,j2);
	gap3 = INDEL(i3,j3);

	// Indel positions must be contiguous and carry one common marker.
	if (!gap2 && gap1 && gap3) return 0;
	if (gap1 && gap2 && gap1 != gap2) return 0;
	if (gap2 && gap3 && gap2 != gap3) return 0;

	// Reject a gap in j that sits next to the indel but is not part of it,
	// e.g. A33 => AA3.
	if (gap1 && !gap2 && j2 >= G1) return 0;
	if (gap2 && !gap1 && j1 >= G1) return 0;
	if (gap2 && !gap3 && j3 >= G1) return 0;
	if (gap3 && !gap2 && j2 >= G1) return 0;

	return 1;
}

// Lookup tables filled by cache_nuc().

// legalnuc[IN(i1,i2,i3)]: is_legal_nuc_base() result for each triple.
int legalnuc[GLET*GLET*GLET];
// legalpair[IN(i...)][IN(j...)]: is_legal_pair_base() result for each pair.
int legalpair[GLET*GLET*GLET][GLET*GLET*GLET];

// allnucs[IN(i1,i2,i3)]: every triple, stored at its own index.
Stack allnucs[GLET*GLET*GLET];
// gapnucs[IN2(i1,i2,i3)]: triples whose three positions are all < G1.
Stack gapnucs[LET*LET*LET];
// lnucs[0 .. numlegal-1]: the legal triples, packed in enumeration order.
Stack lnucs[GLET*GLET*GLET];
int numlegal = 0;

// Printable symbol for each code, set in cache_nuc(): A G C T 1 2 3 4 N.
// let[9] stays 0, so the array is also a NUL-terminated string.
char let[10];

// Build the lookup tables used by the rest of this file:
//   let[]     - printable symbol for each code (A G C T 1 2 3 4 N);
//   legalnuc  - is_legal_nuc_base() for every triple;
//   allnucs   - every triple at its IN() index;
//   gapnucs   - triples with every position < G1, at their IN2() index;
//   lnucs     - legal triples packed into [0, numlegal);
//   legalpair - is_legal_pair_base() for every pair of triples.
// Safe to call more than once: numlegal is reset, so lnucs is rebuilt from
// index 0 instead of being appended past its end.
void cache_nuc()
{
	int i1,i2,i3;
	int j1,j2,j3;

	let[0] = 'A';
	let[1] = 'G';
	let[2] = 'C';
	let[3] = 'T';
	let[4] = '1';
	let[5] = '2';
	let[6] = '3';
	let[7] = '4';
	let[8] = 'N';

	numlegal = 0;

	// Pass 1: per-triple tables.
	for (i1=0;i1<GLET;i1++)
	for (i2=0;i2<GLET;i2++)
	for (i3=0;i3<GLET;i3++) {
		const int idx = IN(i1,i2,i3);

		legalnuc[idx] = is_legal_nuc_base(i1,i2,i3);
		allnucs[idx].one = i1;
		allnucs[idx].two = i2;
		allnucs[idx].thr = i3;

		if (i1 < G1 && i2 < G1 && i3 < G1) {
			const int gidx = IN2(i1,i2,i3);
			gapnucs[gidx].one = i1;
			gapnucs[gidx].two = i2;
			gapnucs[gidx].thr = i3;
		}

		if (legalnuc[idx]) {
			lnucs[numlegal].one = i1;
			lnucs[numlegal].two = i2;
			lnucs[numlegal].thr = i3;
			numlegal++;
		}
	}

	// Pass 2 must follow pass 1: is_legal_pair_base() checks each triple via
	// is_legal_nuc(), which reads the legalnuc table filled above.
	for (i1=0;i1<GLET;i1++)
	for (i2=0;i2<GLET;i2++)
	for (i3=0;i3<GLET;i3++)
		for (j1=0;j1<GLET;j1++)
		for (j2=0;j2<GLET;j2++)
		for (j3=0;j3<GLET;j3++)
			legalpair[IN(i1,i2,i3)][IN(j1,j2,j3)] = is_legal_pair_base(i1,i2,i3,j1,j2,j3);
}

// Cached legality of a single triple (filled by cache_nuc()).
// The triple is packed base-GLET, i1 least significant -- the same packing
// used by is_legal_pair() and flat_prob().
int is_legal_nuc(int i1, int i2, int i3)
{
	const int index = i1 + GLET * (i2 + GLET * i3);
	return legalnuc[index];
}

// Cached legality of the transition (i1,i2,i3) -> (j1,j2,j3), read from the
// legalpair table built by cache_nuc(). Both triples are packed base-GLET,
// first position least significant.
int is_legal_pair(int i1, int i2, int i3,  int j1, int j2, int j3)
{
	const int from = i1 + GLET * (i2 + GLET * i3);
	const int to   = j1 + GLET * (j2 + GLET * j3);
	return legalpair[from][to];
}

// tree_prob for two Stack triples: unpack both and defer to the coordinate
// overload.
double tree_prob(Stack *s, Stack *t, Iterator *i)
{
	const int a1 = s->one, a2 = s->two, a3 = s->thr;
	const int b1 = t->one, b2 = t->two, b3 = t->thr;
	return tree_prob(a1, a2, a3, b1, b2, b3, i);
}

// Timed transition probability from triple s to the iterator's current
// record (i->record). This is exactly the coordinate overload of tree_prob
// applied to (s, i->record), so delegate to it -- as the Stack/Stack overload
// does -- instead of duplicating the actual_time / m24ihash logic.
double tree_prob(Iterator *i, Stack *s)
{
	return tree_prob(s->one, s->two, s->thr,
	                 i->record->one, i->record->two, i->record->thr, i);
}

// Timed probability of the middle symbol changing i2 -> j2, with (i1,i3) as
// flanking context, interpolated from m24ihash at i->actual_time.
// At time zero the result is the identity: 1 if i2 == j2, else 0.
// j1 and j3 are not used.
double tree_prob(int i1, int i2, int i3, int j1, int j2, int j3, Iterator *i)
{
	if (i->actual_time == 0)
		return (i2 == j2) ? 1 : 0;

	return interpolate_probability(i->actual_time, m24ihash,
	                               MI(i2,j2), FL(i1,i3,i1,i3));
}

// Probability computed from the first column only: s->one versus
// i->record->one.
//   both gap codes (>= G1)   -> 1 if equal, else 0
//   only s->one is a gap     -> i->ins(code - 3)
//   only record->one is gap  -> i->del(code - 3)
//   neither is a gap         -> i->substitute(s->one, record->one, LET)
double single_prob(Iterator *i, Stack *s)
{
	const int s_sym = s->one;
	const int rec_sym = i->record->one;
	const int s_gap = (s_sym >= G1);
	const int rec_gap = (rec_sym >= G1);

	if (s_gap && rec_gap)
		return (s_sym == rec_sym) ? 1 : 0;
	if (s_gap)
		return i->ins(s_sym - 3);
	if (rec_gap)
		return i->del(rec_sym - 3);
	return i->substitute(s_sym, rec_sym, LET);
}

// Untimed transition probability from triple a to triple b, read from the
// iterator's m33 table. When a is null, falls back to freq(b).
double flat_prob(Stack *a, Stack *b,Iterator *i)
{
	if (a == NULL)
		return freq(b);

	return flat_prob(a->one, a->two, a->thr, b->one, b->two, b->thr, i);
}

// Untimed transition probability from triple s to the iterator's current
// record.
double flat_prob(Iterator *i, Stack *s)
{
	Stack *target = i->record;
	return flat_prob(s, target, i);
}

// Entry of the iterator's m33 matrix for the transition
// (i1,i2,i3) -> (j1,j2,j3). Row and column are the base-GLET packed triples,
// first position least significant.
double flat_prob(int i1, int i2, int i3, int j1, int j2, int j3, Iterator *i)
{
	const int row = i1 + GLET * (i2 + GLET * i3);
	const int col = j1 + GLET * (j2 + GLET * j3);
	return (*i->m33)[row][col];
}

// Table lookup for the non-gap positions of a triple.
// Positions with code < G1 are packed, in order, into a base-LET index.
// The final stride (1, LET, LET^2 or LET^3) selects the result: 1, f1, f2
// or f3 at that index.
double freq(int i1, int i2, int i3)
{
	const int codes[3] = { i1, i2, i3 };
	int index = 0;
	int stride = 1;   // named 'stride' to avoid shadowing the global let[]

	for (int k = 0; k < 3; k++) {
		if (codes[k] >= G1)
			continue;   // gap position: not part of the index
		index += codes[k] * stride;
		stride *= LET;
	}

	if (stride == 1)           return 1;
	if (stride == LET)         return f1[index];
	if (stride == LET*LET)     return f2[index];
	if (stride == LET*LET*LET) return f3[index];
	return 1;
}

// Like freq(), but each gap position (code >= G1) is spread uniformly over
// the LET residue symbols. The result is the f1/f2/f3 entry for the non-gap
// positions divided by LET once per gap position; an all-gap triple gives
// 1/LET^3.
double cfreq(int i1, int i2, int i3)
{
	int sub = 0, stride = 1;   // 'stride' avoids shadowing the global let[]
	if (i1 < G1) { sub += i1;        stride *= LET; }
	if (i2 < G1) { sub += i2*stride; stride *= LET; }
	if (i3 < G1) { sub += i3*stride; stride *= LET; }

	// Divide in floating point: 1/(LET*LET*LET) in int arithmetic truncates
	// to 0.
	if (stride == 1)           return 1.0 / (LET*LET*LET);
	if (stride == LET)         return f1[sub] / (double)(LET*LET);
	if (stride == LET*LET)     return f2[sub] / (double)LET;
	if (stride == LET*LET*LET) return f3[sub];
	return 1;
}

// Convenience overload: freq() of the triple held in a Stack.
double freq(Stack *a)
{
	const int first = a->one, second = a->two, third = a->thr;
	return freq(first, second, third);
}
