#include <iostream>
#include <time.h>
#include <vector>
#include <iomanip>
#include "constants.h"
#include "Alignment.h"
#include "Decomposition.h"
#include "Score.h"
#include "Sequence.h"
#include "ParameterFile.h"

using namespace std;

// Forward declarations of the helpers that are defined below main().

// Read name/sequence records from a FASTA-style file and return them as
// Sequence objects. Every returned sequence is padded on both sides with
// (int)(paddingfrac * min_len) 'N' characters.
// NOTE(review): min_len only determines the amount of padding; the returned
// sequences are not filtered by length -- confirm whether they should be.
std::vector<Sequence> ReadSeqsFromFile(string file, unsigned int min_len, double paddingfrac);
// Optimize the non-specific energy: returns lambda = exp(x), with x located by
// bisection on [-50, 10] using the window scores of the positive sequences
// (0 or exp(10) when the optimum is not bracketed by that interval).
double optimize_nonspec_ener(vector<Sequence> sequences);


// Compare two motifs: returns the mean absolute difference between the pair
// counts n_ij of the two alignments, taken over all column pairs j < i.
float compare(Alignment old_alignment, Alignment new_alignment);
// Output functions.
// Write a tab-separated table with one line for every window, on either
// strand, whose score is above minimal_score.
void write_site_scores_new(string outfile_name, vector<Sequence> sequences, double lam, int motiflen,double minimal_score);
// Write one line per sequence: name, log of the summed exp(score) over all
// windows of both strands, and that sum divided by (sum + lam * 2 * windows).
void write_seq_scores_new(string outfile_name,vector<Sequence> sequences,double lam);
// Write the TF name and the pair counts of every column pair of the alignment.
void write_motif(string filename, Alignment alignment);

int main(int argc, char* argv[]){
	if (argc != 4){
	  cerr << "Usage: " << argv[0] << " <DWT model> <input sequences (FASTA format)> <parameter file>" << endl;
	  cerr << "\t\te.g. " << argv[0] << " CTCF.dwt sequences.fasta param_file.txt" << endl;
		exit(1);
	}

	//Read the motif and initialize counts of dinucleotides
	string dwt_model_file = argv[1];
	Alignment alignment(dwt_model_file); 

	cerr << "Read the DWT motif file" << endl;

	//Read the parameters
	ParameterFile parameters(argv[3]);
	cerr << "Read the parameters." << endl;


	// First read sequences from file
	vector<Sequence> sequences = ReadSeqsFromFile(argv[2], alignment.ncols(),parameters.give_paddingfrac());
	if(sequences.size() <= 0) {
	  cerr << "No sequence to process, quitting..." << endl;
	  exit(1);
	}
	cerr << "Read the sequences." << endl;
	
	// check if we need to set the background frequencies of nucleotides from the sequences themselves
	// and if so, go over all sequences and calculate frequencies of A, C, G, T 
	vector<float> probs;
	probs.resize(ALPH_NUM, 0.0);
	if(parameters.give_background_type() == FROMSEQ) {
	  double sum = 0.0;
	  for(unsigned int i = 0; i < sequences.size(); i++) {
	    vector<unsigned int> counts = sequences[i].give_letter_counts();
	    for(unsigned int j = 0; j < counts.size(); j++) {
	      probs[j] += counts[j];
	      sum += counts[j];
	    }
	  }
	  // transform counts to probabilities
	  for(unsigned int j = 0; j < probs.size(); j++) {
	    probs[j] /= sum;
	  }
	  //make reverse-complement symmetric
	  double tmp_prob = (probs[0]+probs[3])/2.0;
	  probs[0] = probs[3] = tmp_prob;
	  tmp_prob = (probs[1]+probs[2])/2.0;
	  probs[1] = probs[2] = tmp_prob;
	  parameters.change_background(probs);
	}
	probs = parameters.give_background();
	
	cout << "Running in mode " << parameters.give_mode() << endl;
	cout << "Background nucleotide probabilities (A C G T)  " << probs[0] << " " << probs[1] << " " << probs[2] <<  " " << probs[3] << endl;
	cout << "Precision " << parameters.give_precision() << endl;
	cout << "Padding motif length fraction " << parameters.give_paddingfrac() << endl;
	cout << "Minimal score for printing " << parameters.give_lowest_score() << endl;
	
	// define the non-specific energy variable
	double lam;

	bool converged = false;
	int iter = 0;
	int max_iter = 200;
	float convergence_threshold = 0.05;

	while(!converged && iter < max_iter){
	  ++iter;
	  cerr << "Motif training iteration " << iter << endl;
	  //Calculate logR matrix and determinant with current alignment
	  Score S(alignment, alignment.get_TF(), parameters);
	  	  
	  //Calculate the score for all windows in all sequences
	  for(unsigned int i = 0; i < sequences.size(); i++) {
	    sequences[i].get_scores(S); 
	  }
	  
	  // optimize non-specific binding affinity
	  lam = optimize_nonspec_ener(sequences);
	  //cerr << "Optimized non-specific affinity lambda to " << lam << " frac sequences bound is " << ff << endl;
	  
	  // do we have to train? Check parameters flag
	  if(parameters.give_train_DWT()){
	    // yes, we do

	    // get a new alignment
	    // first copy the alignment that we have already initialized from input file
	    Alignment new_alignment(alignment);
	    // set counts back to 0
	    new_alignment.reset_counts();
	    // compute new counts based on the predicted sites in sequences
	    for(unsigned int i = 0; i < sequences.size(); i++) {
	      //only count the sites in the positive sequences
	      if(sequences[i].is_positive()){
		new_alignment.add_foreground_counts(sequences[i].give_seq(), sequences[i].give_rc_seq(), sequences[i].give_scores(),sequences[i].give_rc_scores(),lam);
	      }
	    }
	    cerr << "Made new alignment with total sites " << new_alignment.nrows() << endl << endl;
	    //renormalize the counts of this alignment (to make sure every marginal and pair sums to same nrows)
	    new_alignment.renormalize_counts();


	    //compare with old alignment
	    float diffaln = compare(alignment, new_alignment);
	    //cerr << "Difference with old motif " << diffaln << endl;
	    if(diffaln < convergence_threshold){
	      converged = true;
	    }
	    	    
	    //set alignment to new_alignment
	    for(short unsigned int i = 0; i < alignment.ncols();i++){
	      for(short unsigned int alpha = 0;alpha < ALPH_NUM;++alpha){
	      }
	    }

	    alignment = new_alignment;
	    for(short unsigned int i = 0; i < alignment.ncols();i++){
	      for(short unsigned int alpha = 0;alpha < ALPH_NUM;++alpha){
	      }
	    }
	  }
	  else{
	    converged = true;
	  }
	}

	//When the iteration has finished:
	//Write one file (output_dir/TFname.sites) with segments and their scores (for scores over a cut-off)
        // iterate over sequences
        // iterate over all scores and reverse complement scores
        //    check if score > minimal_score
        //       if so, calculate posterior (need s0) = 1.0/(1.0+exp(s0-s)) where s is score
	//       print line with: seqname start_window end_window strand sequence score posterior
	//       can do first plus strand and then the minus strand    
	string outfile_name;
	if(parameters.give_output_dir().size() > 0){ 
	  outfile_name = parameters.give_output_dir() + "/" + alignment.get_TF() + ".sites";
	}
	else{
	  outfile_name = alignment.get_TF() + ".sites";
	}
	string extension = "";
        if(parameters.give_mode() == DWT)
          extension = ".dwt";
        else if(parameters.give_mode() == ADJ)
          extension = ".adj";
        else if(parameters.give_mode() == PWM)
          extension = ".pwm";
        outfile_name = outfile_name + extension;
	int motiflen = alignment.ncols();
	write_site_scores_new(outfile_name, sequences, lam, motiflen,parameters.give_lowest_score());
	cerr << "Done writing site scores." << endl;

	//Write one file (output_dir/TFname.seq_scores) with overall scores and site-counts for all sequences (I will write how to get this score)
	if(parameters.give_output_dir().size() > 0){
	  outfile_name = parameters.give_output_dir() + "/" + alignment.get_TF() + ".seq_scores";
	}
	else{
	  outfile_name = alignment.get_TF() + ".seq_scores";
	}
        outfile_name = outfile_name + extension;
	write_seq_scores_new(outfile_name,sequences,lam);

	//If train_dwt was 1, Write new DWT file (output_dir/TFname.dwt or output_dir/TFname.adj or output_dir/TFname.pwm depending on 'mode' variable)
	if(parameters.give_train_DWT()){
	  // these are alignment scores
	  if(parameters.give_output_dir().size() >0){
	    outfile_name = parameters.give_output_dir() + "/" + alignment.get_TF() + ".trained";
	  }
	  else{
	    outfile_name = alignment.get_TF() + ".trained";
	  }
	  outfile_name = outfile_name + extension;
	  write_motif(outfile_name,alignment);
	}
	return 0;
}

// Read name/sequence records from a FASTA-style file.
//
// Each record is expected to be a header line followed by one line holding
// the whole sequence. The name is the first whitespace-delimited token of the
// header with its first character (normally '>') dropped; the rest of the
// header line is ignored. Sequences wrapped over several lines are not
// supported.
//
// file        : path of the sequence file
// min_len     : motif length; only used to derive the amount of padding
// paddingfrac : fraction of min_len that is added as 'N' characters to both
//               ends of every sequence (no padding when the product is < 1)
//
// Returns the records in file order. Exits with status 1 when the file cannot
// be opened.
//
// NOTE(review): sequences are not filtered by min_len, so sequences shorter
// than the motif are returned as well -- confirm that Sequence copes with them.
std::vector<Sequence> ReadSeqsFromFile(string file, unsigned int min_len,double paddingfrac) {
  ifstream infile (file.c_str(), ios::in);
  if(!infile.is_open()){
    cerr << "Cannot open sequence file, quitting..." << endl;
    exit(1);
  }

  // Determine how much padding to add and build it once
  const int padding = (int) (paddingfrac * min_len);
  const string pad(padding > 0 ? padding : 0, 'N');

  std::vector<Sequence> sequences;
  string name, sequence, dummy;
  while (infile >> name){
    // skip whatever follows the name on the header line
    getline(infile, dummy);
    name = name.substr(1);
    // A header without a sequence ends the input; checking the extraction
    // keeps the previous record's sequence from being stored a second time
    if(!(infile >> sequence)){
      break;
    }
    sequences.push_back(Sequence(name, pad + sequence + pad));
  }
  infile.close();
  return sequences;
}


double optimize_nonspec_ener(vector<Sequence> sequences) {

  int max_iter = 100;
  double x_min = -50.0;
  double x_max = 10.0;
  int num_fg = 0;
  double lam;
  int bg_win = 0;
  double tot_bg_prob = 0.0;

  //We need to match  <ls/(exp(x) ls + exp(Es)) to 1/(1+exp(x))>

  //First get total energy for every sequence 
  std::vector<double> free_ener;
  free_ener.resize(sequences.size(),0.0);
  std::vector<int> fglen;
  fglen.resize(sequences.size(),0);
  
  for(unsigned int s = 0; s < sequences.size(); ++s) {
    //Get this sequence and its scores                                                                                                              
    Sequence curseq = sequences[s];
    std::vector<double> scores = curseq.give_scores();
    std::vector<double> rc_scores = curseq.give_rc_scores();
    unsigned int curlen = scores.size();
    fglen[s] = 2*curlen;
    free_ener[s] = 0.0;
    
    //If positive, count sitecount for this sequence                                                                                                
    if(curseq.is_positive()){
      for(unsigned int i = 0; i < curlen; ++i) {
	double curscore = scores[i];
	double revscore = rc_scores[i];
	//Add in posteriors of the forward and rev-comp sites                                                                                       
	free_ener[s] += exp(curscore);
	free_ener[s] += exp(revscore);
      }
      ++num_fg;
    }
    else{
      for(unsigned int i = 0; i < curlen; ++i) {
	double curscore = scores[i];
	double revscore = rc_scores[i];
	//Add in posteriors of the forward and rev-comp sites                                                                                       
	tot_bg_prob += exp(curscore);
	tot_bg_prob += exp(revscore);
	bg_win += 2;
      }
    }
  }

  if(bg_win > 0){
    tot_bg_prob /= ((double) bg_win);
  }
  //cerr << "total bg prob from " << bg_win << " windows is " << tot_bg_prob << endl;

  //Check derivative at x_min
  lam = exp(x_min);
  double deriv = 0;
  for(unsigned int s = 0; s < sequences.size(); ++s) {
    Sequence curseq = sequences[s];
    if(curseq.is_positive()){
      deriv += ((double) fglen[s])/(lam * ((double) fglen[s]) + free_ener[s]);
    }
  }
  deriv /= ((double) num_fg);
  deriv -= 1.0/(1.0+lam);
  double deriv_min = deriv;

  //Check derivative at x_max
  deriv = 0;
  lam = exp(x_max);
  for(unsigned int s = 0; s < sequences.size(); ++s) {
    Sequence curseq = sequences[s];
    if(curseq.is_positive()){
      deriv += ((double) fglen[s])/(lam * ((double) fglen[s]) + free_ener[s]);
    }
  }
  deriv /= ((double) num_fg);
  deriv -= 1.0/(1.0+lam);
  double deriv_max = deriv;

  if(deriv_min > 0 && deriv_max < 0){
    double x_mid;
    double x_high = x_max;
    double x_low = x_min;
    int iter = 0;
    while((x_high - x_low) > 0.01 && iter < max_iter){
      x_mid = (x_high+x_low)/2.0;
      lam = exp(x_mid);
      deriv = 0.0;
      for(unsigned int s = 0; s < sequences.size(); ++s) {
	Sequence curseq = sequences[s];
	if(curseq.is_positive()){
	  deriv += ((double) fglen[s])/(lam * ((double) fglen[s]) + free_ener[s]);
	}
      }
      deriv /= ((double) num_fg);
      deriv -= 1.0/(lam  +1.0);
	//cerr << "x_low " << x_low << " x_high " << x_high << " lam " << lam << " deriv " << deriv << endl;
      if(deriv > 0){
	x_low = x_mid;
      }
      else{
	x_high = x_mid;
      }
      ++iter;
    }
    if(iter >= max_iter){
      cerr << "ran out of iterations " << endl;
    }
    lam = exp((x_high+x_low)/2.0);
  }
  else{
    cerr << "derivatives not the right signs at the edges deriv_min " << deriv_min << " deriv_max " << deriv_max << endl;
    if(deriv_min <= 0){
      lam = exp(x_min);
      lam = 0.0;//set really to zero
    }
    else{
      lam = exp(x_max);
    }
  }
  return lam;
}


// Measure how much two alignments differ.
// Returns the absolute difference between the pair counts n_ij of the old and
// the new alignment, averaged over all letter pairs and all column pairs
// (first column index larger than the second).
float compare(Alignment old_alignment, Alignment new_alignment) {

  float total_diff = 0;
  int n_terms = 0;

  for(unsigned short int a = 0; a < ALPH_NUM; ++a){
    for(unsigned short int b = 0; b < ALPH_NUM; ++b){
      for(unsigned short int col = 0; col < old_alignment.ncols(); ++col){
        // Only pairs with other < col are visited because the remaining ones
        // are just copies with the two letters exchanged
        for(unsigned short int other = 0; other < col; ++other){
          const float count_new = new_alignment.get_n_ij(a, b, col, other);
          const float count_old = old_alignment.get_n_ij(a, b, col, other);
          // No division by the sum is needed because absolute counts are large
          const float abs_diff = fabs(count_old-count_new);
          total_diff += abs_diff;
          ++n_terms;
        }
      }
    }
  }

  total_diff /= ((double) n_terms);
  return total_diff;
}


void write_site_scores_new(string outfile_name, vector<Sequence> sequences, double lam, int motiflen,double minimal_score) {
  ofstream outfile (outfile_name.c_str(), ios::out);
  if(!outfile.is_open()) {
    cerr << "Cannot open TFname.sites for writing, Quitting... name " << outfile_name << endl;
    exit(1);
  }
  
  // write a header so that we know what we have in this file
  outfile << "SequenceName\tBeginSite\tEndSite\tStrand\tSiteSequence\tScore\tPosterior" << endl;
  
  for(unsigned int i = 0; i < sequences.size(); i++) {
    string seq = sequences[i].give_seq();
    string seq_rc = sequences[i].give_rc_seq();
    vector<double> seq_scores = sequences[i].give_scores();
    vector<double> seq_rc_scores = sequences[i].give_rc_scores();
    
    double free_ener = 0.0;
    unsigned int curlen = seq_scores.size();
    double ls = 2.0*((double) curlen);
    for(unsigned int ii = 0; ii < curlen; ++ii) {
      double curscore = seq_scores[ii];
      free_ener += exp(curscore);
      curscore = seq_rc_scores[ii];
      free_ener += exp(curscore);
    }
    double w = 1.0/(free_ener+lam * ls);

    string site_seq;
    int begin_site;
    int end_site;
    
    //first plus strand sites
    for(unsigned int ind = 0; ind < seq_scores.size(); ind++) {
      if(seq_scores[ind] > minimal_score) {
	double pos_score = w*exp(seq_scores[ind]);
	site_seq = seq.substr(ind, motiflen);
	begin_site = ind;
	end_site = begin_site + motiflen - 1;
	// print out this site
	outfile << sequences[i].give_name() << "\t" << begin_site << "\t" << end_site << "\t+\t" << site_seq << "\t" << seq_scores[ind] << "\t" << pos_score << endl;
      }
    }
    //cerr << "seq " << i << " wrote all plus strand sites" <<endl;
    // now minus strand scores
    // minus strand sequence is reverse complemented so I have to compute the site coordinates properly
    for(unsigned int ind = 0; ind < seq_rc_scores.size(); ind++) {
      if(seq_rc_scores[ind] > minimal_score) {
	double pos_score = w * exp(seq_rc_scores[ind]);
	site_seq = seq.substr(ind, motiflen);
	begin_site = seq_rc.size()-ind-1;
	end_site = begin_site + motiflen - 1;
	// print out this site
	outfile << sequences[i].give_name() << "\t" << begin_site << "\t" << end_site << "\t-\t" << site_seq << "\t" << seq_rc_scores[ind] << "\t" << pos_score << endl;
      }
    }
  }
  // close file
  outfile.close();
}



void write_seq_scores_new(string outfile_name,vector<Sequence> sequences,double lam) {
  ofstream FILE (outfile_name.c_str(), ios::out);
  if(!FILE.is_open()){
    cerr << "Cannot open file to write sequence scores, quitting..." << endl;
    exit(1);
  }

  for(unsigned int s = 0; s < sequences.size(); ++s) {
    Sequence curseq = sequences[s];

    std::vector<double> scores = curseq.give_scores();
    std::vector<double> rc_scores = curseq.give_rc_scores();
    unsigned int curlen = scores.size();
    double ls = 2.0*((double) curlen);
    
    double free_ener = 0;
    double site_count = 0;
    for(unsigned int i = 0; i < curlen; ++i) {
      double curscore = scores[i];
      free_ener += exp(curscore);
      curscore = rc_scores[i];
      free_ener += exp(curscore);
    }
    site_count = free_ener/(free_ener+lam * ls);
    free_ener = log(free_ener);

    //Writing to the output file
    FILE << curseq.give_name() << "\t" << free_ener << "\t" << site_count << endl;
  }
  FILE.close();
}

// Write a motif file.
// The first line holds "NA" and the TF name, the second the column titles;
// then follows one line per column pair (first < second, printed 1-based)
// with the pair counts of all letter combinations in the order of the titles.
// Exits with status 1 when the file cannot be opened.
void write_motif(string outfile_name, Alignment alignment) {
  ofstream out (outfile_name.c_str(), ios::out);
  if(!out.is_open()){
    cerr << "Cannot open motif file for writing, quitting..." << endl;
    exit(1);
  }

  // name of the motif, then the column titles
  const string tf_name = alignment.get_TF();
  out << "NA" << "\t" << tf_name << endl;
  out << "PO1\tPO2\tAA\tAC\tAG\tAT\tCA\tCC\tCG\tCT\tGA\tGC\tGG\tGT\tTA\tTC\tTG\tTT" << endl;

  for(short unsigned int first = 0; first < alignment.ncols(); first++){
    for(short unsigned int second = first+1; second < alignment.ncols(); second++){
      // positions are reported 1-based
      out << first+1 << "\t" << second+1;
      for(short unsigned int a = 0 ; a < ALPH_NUM ; a++){
        for(short unsigned int b = 0; b < ALPH_NUM; b++){
          out << "\t" << alignment.get_n_ij(a,b,first,second);
        }
      }
      out << endl;
    }
  }
  out.close();
}
