/*
 * Alignment.cpp
 *
 */

#include "constants.h"
#include "Alignment.h"
using namespace std;

// Creates an empty alignment with no positions and a zero row total.
// The count and log-probability tables stay empty until they are filled
// (e.g. by read_alignment()). n_rows / n_cols are zeroed explicitly so that
// nrows(), ncols() and any loop bounded by n_cols are well-defined on a
// default-constructed object.
Alignment::Alignment() {
  n_rows = 0;
  n_cols = 0;
}

// Destructor. The members used in this file (strings, numeric fields and
// nested resizable containers) release their own storage, so there is no
// explicit cleanup to do here.
Alignment::~Alignment() {
}

// Builds an alignment by parsing the pairwise-count file at `file` with
// read_alignment(), then records the path in `filename`.
Alignment::Alignment(string file) {
  read_alignment(file);
  this->filename = file;
}

// Copy constructor: makes an independent deep copy of `other`.
// The count / log-probability tables are nested resizable containers (see
// initialize_n_i and friends), whose copy assignment already copies every
// element. Assigning them wholesale is therefore both simpler than rebuilding
// them by hand and safe when `other`'s tables are not exactly
// ALPH_NUM x n_cols (e.g. a default-constructed `other`).
// NOTE(review): the parameter should be `const Alignment&`; that needs a
// matching change in Alignment.h.
Alignment::Alignment(Alignment &other) {
  n_rows = other.n_rows;
  n_cols = other.n_cols;
  TF = other.TF;
  filename = other.filename;

  n_i = other.n_i;
  lp_i = other.lp_i;
  n_ij = other.n_ij;
  lp_ij = other.lp_ij;
}

// Reads pairwise nucleotide counts from `file` and derives every table held
// by this object (n_ij, n_i, n_rows, n_cols; lp_i / lp_ij are only allocated).
//
// What the parser relies on:
//   * a line starting with "NA" carries the TF name in its second field;
//   * lines starting with a backslash or with "PO" are ignored;
//   * every other line with exactly ALPH_NUM*ALPH_NUM + 2 fields (split on
//     tab/space) is a pair record "i j c(0,0) c(0,1) ...": positions i and j
//     are 1-based, and the count for letters (a,b) is field 2 + a*ALPH_NUM + b.
// The number of positions L is recovered from the number of pair records
// P = L(L-1)/2.
//
// The process exits with status 1 in two cases, matching the existing error
// handling: the file cannot be opened, or it contains no pair records (fewer
// than two positions). Pair records whose positions are out of range are
// skipped with a warning instead of writing out of bounds.
void Alignment::read_alignment(string file){
  const size_t pair_fields = ALPH_NUM * ALPH_NUM + 2;

  ifstream File(file.c_str(), ios::in);
  if (!File) {
    cerr << "There is no such a file or directory: " << file << endl;
    exit(1);
  }
  string line;
  vector<string> elems;

  // Pass 1: pick up the TF name and count the pair records.
  int line_number = 0;
  while (getline(File, line)) {
    // Tolerate CRLF line endings (otherwise '\r' ends up in TF).
    if (!line.empty() && line[line.size() - 1] == '\r') {
      line.erase(line.size() - 1);
    }
    if (line.find("NA") == 0) {
      boost::split(elems, line, boost::is_any_of("\t "));
      if (elems.size() > 1) {
        TF = elems[1];
      }
      // Never count the header as a pair record; pass 2 skips it as well.
      continue;
    }
    if ((line.find("\\") == 0) || (line.find("PO") == 0)) {
      continue;
    }
    boost::split(elems, line, boost::is_any_of("\t "));
    if (elems.size() == pair_fields) {
      line_number++;
    }
  }

  // Invert P = L(L-1)/2 to obtain the number of positions L.
  int num_positions = int((1 + sqrt(8.0 * line_number + 1)) / 2);
  if (num_positions < 2) {
    // Marginals below are taken against a second position, so at least two
    // positions are required.
    cerr << "No pairwise count records found in: " << file << endl;
    exit(1);
  }
  if (num_positions * (num_positions - 1) / 2 != line_number) {
    cerr << "Warning: " << line_number << " pair records in " << file
         << " do not form a complete set for " << num_positions
         << " positions." << endl;
  }

  initialize_n_i(num_positions);
  initialize_n_ij(num_positions);

  // Pass 2: rewind (clearing the EOF state first) and read the joint counts,
  // storing each record both as (i,j) and, transposed, as (j,i).
  File.clear();
  File.seekg(0, ios::beg);
  while (getline(File, line)) {
    if (!line.empty() && line[line.size() - 1] == '\r') {
      line.erase(line.size() - 1);
    }
    if ((line.find("NA") == 0) || (line.find("\\") == 0) || (line.find("PO") == 0)) {
      continue;
    }
    boost::split(elems, line, boost::is_any_of("\t "));
    if (elems.size() != pair_fields) {
      continue;
    }
    int i = atoi(elems[0].c_str()) - 1;
    int j = atoi(elems[1].c_str()) - 1;
    if (i < 0 || j < 0 || i >= num_positions || j >= num_positions) {
      cerr << "Warning: skipping pair record with out-of-range positions in "
           << file << ": " << line << endl;
      continue;
    }
    size_t index = 2;
    for (int a = 0; a < ALPH_NUM; a++) {
      for (int b = 0; b < ALPH_NUM; b++) {
        n_ij[a][b][i][j] = n_ij[b][a][j][i] = atof(elems[index].c_str());
        index++;
      }
    }
  }
  File.close();

  // Single-position counts are the marginals of the pair table against one
  // other position: position 1 when i == 0, position 0 otherwise.
  n_rows = 0.0;
  for (int a = 0; a < ALPH_NUM; a++) {
    for (int i = 0; i < num_positions; i++) {
      n_i[a][i] = 0.0;
      int j = (i == 0) ? 1 : 0;
      for (int b = 0; b < ALPH_NUM; b++) {
        n_i[a][i] += n_ij[a][b][i][j];
      }
      n_rows += n_i[a][i];
    }
  }
  // n_rows is the total count averaged over positions.
  n_rows /= static_cast<float>(num_positions);
  n_cols = num_positions;

  // Allocate the zero-filled log-probability tables; set_logprobs() fills them.
  initialize_lp_i(num_positions);
  initialize_lp_ij(num_positions);
}

// Sets lp_i / lp_ij from the current counts using a pseudocount LAMBDA.
// With N = n_rows and A = ALPH_NUM:
//   lp_i[a][i]        = log(n_i[a][i] + A*LAMBDA)      - log(N + A*A*LAMBDA)
//   lp_ij[a][b][i][j] = log(n_ij[a][b][i][j] + LAMBDA) - log(N + A*A*LAMBDA)
// Each pair cell receives LAMBDA, so a single-position cell, which sums A pair
// cells, receives A*LAMBDA, and both tables share one normaliser. The factors
// were previously hard-coded as 4 and 16; the result is unchanged for A == 4.
// Only off-diagonal pairs (j < i, mirrored to i < j) are written, which keeps
// lp_ij symmetric.
void Alignment::set_logprobs(double LAMBDA){
  const double alph = ALPH_NUM;
  const double ltot = log(((double) n_rows) + alph * alph * LAMBDA);
  for(int a = 0; a < ALPH_NUM; a++){
    for(int i = 0; i < n_cols; i++) {
      lp_i[a][i] = log(n_i[a][i] + alph * LAMBDA) - ltot;
      for(int j = 0; j < i; j++){
        for(int b = 0; b < ALPH_NUM; b++){
          lp_ij[a][b][i][j] = lp_ij[b][a][j][i] = log(n_ij[a][b][i][j] + LAMBDA) - ltot;
        }
      }
    }
  }
}


// Zeroes the row total and every single-position and pairwise count for the
// first n_cols positions. The table dimensions are left unchanged.
void Alignment::reset_counts() {
  n_rows = 0.0;

  for (int letter = 0; letter < ALPH_NUM; ++letter) {
    for (int col = 0; col < n_cols; ++col) {
      n_i[letter][col] = 0.0;
    }
  }

  for (int letter = 0; letter < ALPH_NUM; ++letter) {
    for (int partner = 0; partner < ALPH_NUM; ++partner) {
      for (int col = 0; col < n_cols; ++col) {
        for (int other_col = 0; other_col < n_cols; ++other_col) {
          n_ij[letter][partner][col][other_col] = 0.0;
        }
      }
    }
  }
}

// Rescales the counts so that, for every position i, the single-position
// counts n_i[.][i] sum to n_rows, and for every position pair (i,j) the pair
// counts n_ij[.][.][i][j] sum to n_rows.
// Slices whose current sum is not positive are left untouched. Scaling them
// would divide by zero and fill them with NaN. This always affects the
// diagonal n_ij[.][.][i][i], which is never filled in, and also any position
// that received no counts.
void Alignment::renormalize_counts() {
  for(int i = 0; i < n_cols; i++) {
    double single_sum = 0.0;
    for(int alpha = 0; alpha < ALPH_NUM; alpha++) {
      single_sum += n_i[alpha][i];
    }
    if (single_sum > 0.0) {
      const double renor = n_rows / single_sum;
      for(int alpha = 0; alpha < ALPH_NUM; alpha++) {
        n_i[alpha][i] = renor * n_i[alpha][i];
      }
    }

    for(int j = 0; j < n_cols; j++) {
      double pair_sum = 0.0;
      for(int alpha = 0; alpha < ALPH_NUM; alpha++) {
        for(int beta = 0; beta < ALPH_NUM; beta++) {
          pair_sum += n_ij[alpha][beta][i][j];
        }
      }
      if (!(pair_sum > 0.0)) {
        continue;
      }
      const double renor = n_rows / pair_sum;
      for(int alpha = 0; alpha < ALPH_NUM; alpha++) {
        for(int beta = 0; beta < ALPH_NUM; beta++) {
          n_ij[alpha][beta][i][j] = renor * n_ij[alpha][beta][i][j];
        }
      }
    }
  }
}

// Adds posterior-weighted counts for one sequence and its reverse complement.
//
// sequence / revcomp         : the two strands. The window of n_cols letters
//                              starting at offset i is scored by
//                              sequence_scores[i] / rev_scores[i].
// sequence_scores/rev_scores : per-window scores; exp(score) is the window's
//                              unnormalised weight.
// lam                        : each window start adds 2*lam (one per strand)
//                              to the normaliser of the sequence weight w.
//
// The posterior of a window is w*exp(score), where
// w = 1 / (sum of exp(score) over both strands + 2*curlen*lam).
// Windows with a posterior <= 0.01 are ignored. Every other window adds its
// posterior to n_rows, to n_i at each position and to n_ij (both orientations)
// for each pair of positions.
// Only window starts scored on both strands are used. Windows that would run
// past the end of their strand are skipped rather than read out of bounds.
void Alignment::add_foreground_counts(string sequence, string revcomp, vector<double> sequence_scores, vector<double> rev_scores,double lam){
  size_t curlen = sequence_scores.size();
  if (rev_scores.size() < curlen) {
    curlen = rev_scores.size();
  }

  // Total weight of this sequence over both strands.
  double free_ener = 0.0;
  for (size_t i = 0; i < curlen; ++i) {
    free_ener += exp(sequence_scores[i]);
    free_ener += exp(rev_scores[i]);
  }
  // Foreground weight of this sequence.
  double w = 1.0 / (free_ener + 2.0 * ((double) curlen) * lam);

  // Strand 0 is the forward sequence and strand 1 its reverse complement.
  // For each window start the forward strand is processed first.
  const string* strand_seq[2] = { &sequence, &revcomp };
  const vector<double>* strand_score[2] = { &sequence_scores, &rev_scores };

  for (size_t i = 0; i < curlen; ++i) {
    for (int s = 0; s < 2; ++s) {
      double curpost = w * exp((*strand_score[s])[i]);
      // Written as !(x > t) so that NaN posteriors are skipped too.
      if (!(curpost > 0.01)) {
        continue;
      }
      const string& seq = *strand_seq[s];
      if (i + n_cols > seq.size()) {
        continue;
      }
      n_rows += curpost;
      for (int p1 = 0; p1 < n_cols; ++p1) {
        // Marginal count of the base at position p1.
        int alpha = convert(toupper(static_cast<unsigned char>(seq[i + p1])));
        n_i[alpha][p1] += curpost;
        // Pair counts with every earlier position p2 < p1, stored both ways.
        for (int p2 = 0; p2 < p1; ++p2) {
          int beta = convert(toupper(static_cast<unsigned char>(seq[i + p2])));
          n_ij[alpha][beta][p1][p2] += curpost;
          n_ij[beta][alpha][p2][p1] += curpost;
        }
      }
    }
  }
}


  




// Name of the transcription factor, taken from the second field of the
// file's "NA" line (empty if no such line was read).
string Alignment::get_TF(){
	return this->TF;
}

float Alignment::nrows(){
	return n_rows;
}

unsigned short int Alignment::ncols(){
	return n_cols;
}

// Single-position count of nucleotide `alpha` (case-insensitive) at position i.
// Bounds-checked: throws std::out_of_range if convert() yields an index outside
// the table or `i` is not a valid position.
float Alignment::get_n_i(char alpha, unsigned short int i){
	return n_i.at(convert(toupper(static_cast<unsigned char>(alpha)))).at(i);
}

// Log-probability (as set by set_logprobs()) of nucleotide `alpha`
// (case-insensitive) at position i.
// Bounds-checked: throws std::out_of_range if convert() yields an index outside
// the table or `i` is not a valid position.
float Alignment::get_lp_i(char alpha, unsigned short int i){
	return lp_i.at(convert(toupper(static_cast<unsigned char>(alpha)))).at(i);
}


float Alignment::get_n_i(unsigned short int alpha, unsigned short int i){
	return n_i[alpha][i];
}

float Alignment::get_lp_i(unsigned short int alpha, unsigned short int i){
	return lp_i[alpha][i];
}


// Pair count of nucleotides `alpha` at position i and `beta` at position j
// (both case-insensitive).
// Bounds-checked: throws std::out_of_range if convert() yields an index outside
// the table or i / j is not a valid position.
float Alignment::get_n_ij(char alpha, char beta, unsigned short int i, unsigned short int j){
	return n_ij.at(convert(toupper(static_cast<unsigned char>(alpha))))
	           .at(convert(toupper(static_cast<unsigned char>(beta))))
	           .at(i).at(j);
}

// Pair log-probability (as set by set_logprobs()) of nucleotides `alpha` at
// position i and `beta` at position j (both case-insensitive).
// Bounds-checked: throws std::out_of_range if convert() yields an index outside
// the table or i / j is not a valid position.
float Alignment::get_lp_ij(char alpha, char beta, unsigned short int i, unsigned short int j){
	return lp_ij.at(convert(toupper(static_cast<unsigned char>(alpha))))
	            .at(convert(toupper(static_cast<unsigned char>(beta))))
	            .at(i).at(j);
}


float Alignment::get_n_ij(unsigned short int alpha, unsigned short int beta, unsigned short int i, unsigned short int j){
	return n_ij[alpha][beta][i][j];
}

float Alignment::get_lp_ij(unsigned short int alpha, unsigned short int beta, unsigned short int i, unsigned short int j){
	return lp_ij[alpha][beta][i][j];
}

// Sizes n_i to ALPH_NUM rows of `cols` entries each. Entries added by the
// resize are 0; entries that already existed keep their values.
void Alignment::initialize_n_i(unsigned short int cols){
	n_i.resize(ALPH_NUM);
	unsigned int row = 0;
	while (row < ALPH_NUM) {
		n_i[row].resize(cols, 0);
		++row;
	}
}

// Sizes lp_i to ALPH_NUM rows of `cols` entries each. Entries added by the
// resize are 0; entries that already existed keep their values.
void Alignment::initialize_lp_i(unsigned short int cols){
  lp_i.resize(ALPH_NUM);
  unsigned int row = 0;
  while (row < ALPH_NUM) {
    lp_i[row].resize(cols, 0);
    ++row;
  }
}



// Sizes n_ij to ALPH_NUM x ALPH_NUM x cols x cols. Entries added by the
// resize are 0; entries that already existed keep their values.
void Alignment::initialize_n_ij(unsigned short int cols){
	n_ij.resize(ALPH_NUM);
	for (unsigned int a = 0; a < ALPH_NUM; ++a) {
		n_ij[a].resize(ALPH_NUM);
	}
	for (unsigned int a = 0; a < ALPH_NUM; ++a) {
		for (unsigned int b = 0; b < ALPH_NUM; ++b) {
			n_ij[a][b].resize(cols);
			for (unsigned int row = 0; row < cols; ++row) {
				n_ij[a][b][row].resize(cols, 0);
			}
		}
	}
}


// Sizes lp_ij to ALPH_NUM x ALPH_NUM x cols x cols. Entries added by the
// resize are 0; entries that already existed keep their values.
void Alignment::initialize_lp_ij(unsigned short int cols){
  lp_ij.resize(ALPH_NUM);
  for (unsigned int a = 0; a < ALPH_NUM; ++a) {
    lp_ij[a].resize(ALPH_NUM);
  }
  for (unsigned int a = 0; a < ALPH_NUM; ++a) {
    for (unsigned int b = 0; b < ALPH_NUM; ++b) {
      lp_ij[a][b].resize(cols);
      for (unsigned int row = 0; row < cols; ++row) {
        lp_ij[a][b][row].resize(cols, 0);
      }
    }
  }
}




