//
//  MatchingPursuit.cpp
//  DBMCrossSynth
//
//  Created by Nicholas Collins on 17/05/2011.
//  Copyright 2011 Nick Collins. All rights reserved.
//

#include "MatchingPursuit.h"


// Build an empty pursuit that will query the inner-product engine `ip`.
// Nothing is allocated here: every owned array starts null and every count
// starts at zero until Compute() or Load() fills them in.
// `ip` is stored as given; ~MatchingPursuit() does not delete it.
MatchingPursuit::MatchingPursuit(DBMInnerProducts* ip)
    : ips_(ip),
      iterations_(0),       // per-iteration records (none yet)
      numiterations_(0),    // capacity of iterations_
      numiterationsnow_(0), // records currently valid
      residual_(0),
      output_(0),
      numsamples_(0) {      // no signal analysed yet
}

// Free the arrays this object owns. Each pointer is either null (never
// allocated) or came from new[] in Compute()/Load(), so delete[] is valid
// for all of them. ips_ is not owned and is deliberately left alone.
MatchingPursuit::~MatchingPursuit() {
    delete [] output_;
    delete [] residual_;
    delete [] iterations_;
}

// Greedy matching pursuit: approximate `target` by repeatedly subtracting the
// best-matching grain taken from `source`, recording each choice in iterations_.
//
// source     - corpus the atoms are read from; the atom for an iteration starts
//              whichfreq_*hop samples into it. Its length is not checked here
//              (presumably ips_ only reports in-range offsets - TODO confirm).
// target     - signal to approximate; its first numsamples samples are copied.
// numsamples - length of target; residual_ and output_ are reallocated to it.
// iterations - number of atoms to extract; iterations_ is grown if too small.
//
// On return residual_ holds target minus every subtracted atom, output_ holds
// 0.5 * the sum of those atoms (clipped to the signal length), and
// numiterationsnow_ == iterations.
void MatchingPursuit::Compute(float * source, float * target, int numsamples, int iterations) {
    
    int i,j;
    
    // Grow the record array only when more slots are needed than we have.
    if (iterations>numiterations_) {
        
        delete [] iterations_; 
        
        iterations_ = new IterationData[iterations];
        
        numiterations_ = iterations; 
        
    }

    numsamples_ = numsamples; 
    
    // (Re)allocate the residual and output buffers for this signal length.
    delete [] residual_; 
    residual_ = new float[numsamples]; 
    
    delete [] output_; 
    output_ = new float[numsamples]; 
    
    // The residual starts as a copy of the target. The output must start at
    // zero: atoms are accumulated into it with += below, and new float[]
    // leaves its contents indeterminate.
    for ( i = 0 ; i<numsamples; ++i) {
        residual_[i] = target[i]; 
        output_[i] = 0.0f; 
    }
    
    // Region whose inner products ips_ is asked to (re)compute: the whole
    // signal on the first iteration, then only the span the previous atom changed.
    int regionstart= 0; 
    int regionend = numsamples-1;
    
    for ( i = 0; i<iterations; ++i) {
        
        if(i%10==0) {
            
            std::cout << "iteration " << i << std::endl;  
        }
        
        // Calculate any IPs in the invalidated area.
        ips_->Compute(regionstart, regionend, source, residual_);
        
        // Record the maximal atom reported by ips_.
        IterationData * iterationnow = iterations_+i; 
        
        int whichscale = ips_->maxindexscale_; 
        DBMInnerProductsAtScale * pscale = ips_->ipatscales_[whichscale]; 
        
        iterationnow->weight_ = ips_->maxscale_;
        iterationnow->whichscale_ = whichscale; 
        iterationnow->whichframe_ = pscale->maxframeatscale_; 
        iterationnow->whichfreq_ = pscale->maxindexperframe_[iterationnow->whichframe_];
        
        // The chosen atom covers `size` samples starting at frame*hop.
        int size = pscale->scaleatoms_->scale_;
        int hop = pscale->scaleatoms_->timeresolution_;
        regionstart = hop*iterationnow->whichframe_; 
        regionend = regionstart + size - 1; 
        
        // Atom samples come straight from the source corpus at offset
        // whichfreq_*hop, so whichfreq_ indexes a source grain position here.
        float * atom = source +(iterationnow->whichfreq_*hop);
        float beta = iterationnow->weight_; 
        float * envelope = pscale->scaleatoms_->gaussianwindow_; 
        float norm = pscale->scaleatoms_->normalizationconstants_[iterationnow->whichfreq_]; 
        beta *= norm; 
        
        // Clamp the update to the allocated buffers. Without this, an atom whose
        // frame lies near the end of the signal would write past residual_/output_.
        int writesize = size; 
        if (regionstart + writesize > numsamples)
            writesize = numsamples - regionstart; 
        
        if (writesize > 0) {
            
            float * targetnow = residual_+regionstart; 
            
            float * outputnow = output_+regionstart; 
            
            // Remove the windowed, weighted atom from the residual and add half
            // of it to the synthesized output.
            for ( j = 0 ; j<writesize; ++j) {
                float atomnow = atom[j]*envelope[j]*beta;
                targetnow[j] -= atomnow;
                outputnow[j] += 0.5f*atomnow;
            }
        }
        
//        Disabled experiment: synthesize a longer output grain than the atom itself,
//        with 256-sample fades, into outputnow.
//        
//        //synthesize larger output than atom itself
//        int outputsize = size/2; 
//        
//        if((regionstart+outputsize)>numsamples)
//            outputsize = size; 
//        
//        if((iterationnow->whichfreq_*hop)+outputsize > pscale->scaleatoms_->length_) 
//            outputsize = size; 
//        
//        int fadeout = outputsize-256; 
//        
//        for ( j = 0 ; j<outputsize; ++j) {
//            
//            float atomnow = atom[j]*envelope[j]*beta;
//            
//            if(j<256) atomnow *= j/256.0f; 
//            
//            if(j>=fadeout) atomnow *= 1.0f-((j-fadeout)/256.0f); 
//            
//            outputnow[j] += 0.25*atomnow;
//        }
        
    }
    
    numiterationsnow_ = iterations; 
    
}












//void MatchingPursuit::Synthesize(float * output, int numsamples) {
//    
//    int i,j; 
//    
//    //add up atoms into output
//    
//    //zero output
//    for ( i = 0; i<numsamples; ++i) 
//        output[i] = 0.0f; 
//
//    int regionstart= 0; 
//    int regionend = numsamples-1;
//    
//    for ( i = 0; i<numiterationsnow_; ++i) {
//        
//        IterationData * iterationnow = iterations_+i;
//        
//        DBMInnerProductsAtScale * pscale = ips_->ipatscales_[iterationnow->whichscale_];
//        float * atom = pscale->scaleatoms_->atoms_[iterationnow->whichfreq_]; 
//        float beta = iterationnow->weight_; 
// 
//        int size = pscale->scaleatoms_->scale_;
//        int hop = pscale->scaleatoms_->timeresolution_;
//        regionstart = hop*iterationnow->whichframe_; 
//        regionend = regionstart + size - 1; 
//    
//        //CHECK VALID; can also synthesize in a lesser range then too
//        if (regionend < numsamples) {
//        
//        float * target = output + regionstart;
//        
//        for ( j = 0 ; j<size; ++j) 
//            target[j] += atom[j]*beta;
//            
//        }
//        
//    }
//    
//}


//
//
////transformative
//void MatchingPursuit::Synthesize2(float * output, int numsamples) {
//    
//    int i,j; 
//    
//    //add up atoms into output
//    
//    //zero output
//    for ( i = 0; i<numsamples; ++i) 
//        output[i] = 0.0f; 
//    
//    int regionstart= 0; 
//    int regionend = numsamples-1;
//    
//    for ( i = 0; i<numiterationsnow_; ++i) {
//        
//        IterationData * iterationnow = iterations_+i;
//
//        DBMInnerProductsAtScale * pscale = ips_->ipatscales_[iterationnow->whichscale_];
//        
//        int freqbin = iterationnow->whichfreq_; 
//        
//        if(freqbin>20) {
//        
//        float * atom = pscale->scaleatoms_->atoms_[iterationnow->whichfreq_]; 
//        float beta = iterationnow->weight_; 
//        
//        int size = pscale->scaleatoms_->scale_;
//        int hop = pscale->scaleatoms_->timeresolution_;
//        regionstart = hop*iterationnow->whichframe_; 
//        regionend = regionstart + size - 1; 
//        
//        //CHECK VALID; can also synthesize in a lesser range then too
//        if (regionend < numsamples) {
//            
//            float * target = output + regionstart;
//            
//            for ( j = 0 ; j<size; ++j) 
//                target[j] += atom[j]*beta;
//            
//        }
//            
//        }
//        
//    }
//    
//}
//





#include <fstream>
//#include <stdio.h>
//#include <math.h>
//#include <string>

using std::endl;
using std::ifstream; 
using std::ofstream;
using std::cout; 
//using std::cin; 
//using std::string; 

void MatchingPursuit::Save(char * filename) {
    
    ofstream out(filename); //no point using binary unless make it all read and write, but easier if just make whitespace based with ascii files, std::ios::binary);
    
    if (out.is_open()) {
        
        //out << string("ListeningLearning") << endl; 
        //int version=2; 
        
        out << numsamples_ << endl; 
        
        out << numiterationsnow_ << endl; 
        
        for (int i = 0; i<numiterationsnow_; ++i) {
            
            IterationData * iterationnow = iterations_+i;
        
            out << iterationnow->weight_ << " " << iterationnow->whichscale_ << " " << iterationnow->whichframe_ << " " << iterationnow->whichfreq_ << endl; 
        }
      
 
        out.close();
        cout << "saved" << filename << endl; 
        
    }
    else cout << "Unable to save file"; 
    
}

void MatchingPursuit::Load(char * filename) {
    
    ifstream in(filename); 
    
    if (in.is_open()) {

        in >> numsamples_; //for external reference
        
        in >> numiterationsnow_; 
        
        if (numiterationsnow_>numiterations_) {
            
            delete [] iterations_; 
            
            iterations_ = new IterationData[numiterationsnow_];
            
            numiterations_ = numiterationsnow_; 
            
        }
        
        for (int i = 0; i<numiterationsnow_; ++i) {
            
            IterationData * iterationnow = iterations_+i;
            
            in >> iterationnow->weight_ >> iterationnow->whichscale_ >> iterationnow->whichframe_ >> iterationnow->whichfreq_; 
        }
        
        in.close();
        cout << "loaded" << filename << endl; 
    }
    else cout << "Unable to load file";    


}




