/*
	SuperCollider real time audio synthesis system
 Copyright (c) 2002 James McCartney. All rights reserved.
	http://www.audiosynth.com
 
 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.
 
 This program is distributed in the hope that it will be useful,
 but WITHOUT ANY WARRANTY; without even the implied warranty of
 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 GNU General Public License for more details.
 
 You should have received a copy of the GNU General Public License
 along with this program; if not, write to the Free Software
 Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301  USA
 */

//UGens by Nick Collins  http://www.informatics.sussex.ac.uk/users/nc81/
//Released under the GNU GPL as extensions for SuperCollider 3

//building: SC 3.5/3.4 compatible:
//cmake -DSC_PATH=/data/gitprojects/SuperCollider-Source3.5 -DCMAKE_OSX_ARCHITECTURES='i386;x86_64' ..

//NOTES
//Vnrt is not strictly needed: the UGen cannot collect new spectral data while the NRT calculation is running, so V is not overwritten during it


#include "SC_PlugIn.h"
#include <math.h>
#include <stdlib.h>
//#include "/Users/nickcollins/Desktop/tosort/gsl_universal_1.14/gsl/gsl_cblas.h"
#include "gsl_cblas.h"
#include "FFT_UGens.h"


InterfaceTable *ft; 



//data to be shared between RT and NRT threads 
struct CmdData {

	//kinds of asynchronous job this plugin can post; the stage functions switch on this
	enum Type {
		NRTSourceSeparationCalculation
	};

	Type type;            //selects the branch taken in each command stage
	Unit * unit;          //UGen instance that posted the command (cast back to SourceSeparation in the stages)
	void * nrtallocated;  //not read by any stage in this file
	float samplingrate_;  //not read by any stage in this file

};




struct SourceSeparation : public Unit
{
    //state machine flags
    bool rtcollectionunderway_;   //true while spectral frames are being gathered into V_
    int collectionpos_;           //column of V_ that the next collected frame is written to
    bool nrtcalculationunderway_; //set when the NMF command is launched; _next does no work while it is set

    //problem dimensions
    int numsources_;              //r, read from input 1
    int fftbins_;                 //n, W buffer sample count divided by numsources_
    int numframes_;               //m, read from input 2

    //W_ and H_ point at the data of the buffers given by inputs 4 and 5
    float * W_;
    float * H_;
    //V_ is the n by m matrix of collected intensities, allocated by the constructor
    float * V_;

    //working copies used by the asynchronous calculation
    float * Wnrt_;
    float * Hnrt_;
    float * Vnrt_;
    //scratch space for the intermediate matrix products
    float * calctemp_;

    //previous value of the trigger input (input 3)
    float lastinput_;

};


struct PV_SourceSeparationMask : public Unit
{
    float lastuseh_;  //previous value of the useh input (input 3), used to spot it turning positive
    int hframepos_;   //index of the H frame applied to the next FFT frame during playback
    bool hplayback_;  //true while the H activations are being stepped through
};


extern "C" {

    //analysis UGen: constructor, destructor and calculation function
    void SourceSeparation_Ctor(SourceSeparation* unit);
    void SourceSeparation_Dtor(SourceSeparation* unit);
    void SourceSeparation_next(SourceSeparation* unit, int inNumSamples);

    //masking UGen: constructor and calculation function (it has no destructor)
    void PV_SourceSeparationMask_Ctor(PV_SourceSeparationMask* unit);
    void PV_SourceSeparationMask_next(PV_SourceSeparationMask* unit, int inNumSamples);

}




//adapted to work on a whole array at once
//http://www.taygeta.com/random/gaussian.html
//ftp://ftp.taygeta.com/pub/c/boxmuller.c
/*
 * Normal random variate generator (polar Box-Muller), filling a whole array.
 *
 * m          mean of the distribution
 * s          standard deviation of the distribution
 * target     destination array, at least targetsize floats
 * targetsize number of values to write; nothing is written when it is <= 0
 *
 * Uses rand() as its uniform source. Each accepted point yields two variates;
 * the second is kept in static storage and handed out first on the next
 * request, including across calls, so the function is not reentrant.
 */
void boxmullerarray(float m, float s, float * target, int targetsize)
{
	float x1, x2, w, y1;
	static float y2;          /* spare variate from the last generated pair */
	static int use_last = 0;  /* non-zero while y2 has not been handed out  */

	for (int ii=0; ii<targetsize; ++ii) {

		if (use_last) {
			/* use value from previous pair */
			y1 = y2;
			use_last = 0;
		} else {
			/* pick a point inside the unit circle; the origin is rejected too,
			   because log(0)/0 below would otherwise produce NaN */
			do {
				x1 = 2.0 * ((float)rand()/RAND_MAX) - 1.0;
				x2 = 2.0 * ((float)rand()/RAND_MAX) - 1.0;
				w = x1 * x1 + x2 * x2;
			} while ( w >= 1.0 || w == 0.0 );

			w = sqrt( (-2.0 * log( w ) ) / w );
			y1 = x1 * w;
			y2 = x2 * w;
			use_last = 1;
		}

		target[ii] = m + y1 * s;
	}
}



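//asynchronous command stages: the NMF calculation (stage 2), copying of the results into the W and H buffers (stage 3) and release of the command data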

//Stage 2 of the asynchronous command (marked NRT).
//Factorises the collected matrix V (n by m) into W (n by r) times H (r by m) using
//multiplicative updates, where n = fftbins_, m = numframes_ and r = numsources_:
//    W <- W .* (V*H')  ./ (W*H*H')
//    H <- H .* (W'*V)  ./ (W'*W*H)
//All matrices are row major. Input is Vnrt_; the results are left in Wnrt_ and Hnrt_.
//calctemp_ must hold at least 2*n*r + n*m + 2*r*m floats for the scratch layout below.
//Returns true when the command type was handled, false otherwise.
bool cmdStage2(World* inWorld, CmdData* cmd) // NRT
{
	switch (cmd->type) {
		case CmdData::NRTSourceSeparationCalculation: {

            Print("Source Separation NRT calculation\n");

            //actual NMF calculation
            //prototyped in cblastest.c

            SourceSeparation* pUnit = (SourceSeparation*)cmd->unit;
            const int n = pUnit->fftbins_;
            const int m = pUnit->numframes_;
            const int r = pUnit->numsources_;

            const int numiterations = 200;

            //added to every denominator: a row of V that is entirely zero drives the matching
            //row of W to zero, after which the unguarded update would evaluate 0/0 and the
            //resulting NaN would spread through every later matrix product
            const float tiny = 1e-9f;

            int ii,jj;

            float * V = pUnit->Vnrt_;
            float * W = pUnit->Wnrt_;
            float * H = pUnit->Hnrt_;

            //initialise W to all ones
            for (jj=0; jj<(n*r); ++jj) W[jj]=1.0f;

            //initialise H to normal distribution, mean 1.0, stddev 0.1
            boxmullerarray(1.0,0.1,H,r*m);

            //carve the scratch block into the five intermediate matrices
            float * calctemp = pUnit->calctemp_;

            float * temp1 = calctemp;           //n by r : V*H'
            float * temp2a = calctemp + (n*r);  //n by m : W*H
            float * temp2 = temp2a + (n*m);     //n by r : (W*H)*H'
            float * temp3 = temp2 + (n*r);      //r by m : W'*V
            float * temp4 = temp3 + (r*m);      //r by m : W'*(W*H)

            for (ii=0; ii<numiterations; ++ii) {

                //---- update W ----

                //temp1 = V*H'       rows V and temp1, columns H' and temp1, cols V and rows H'
                cblas_sgemm (CblasRowMajor, CblasNoTrans, CblasTrans, n, r, m, 1.0, V, m, H, m, 0.0, temp1, r);

                //temp2a = W*H       rows W and temp2a, cols H and temp2a, cols W and rows H
                cblas_sgemm (CblasRowMajor, CblasNoTrans, CblasNoTrans, n, m, r, 1.0, W, r, H, m, 0.0, temp2a, m);

                //temp2 = (W*H)*H'
                cblas_sgemm (CblasRowMajor, CblasNoTrans, CblasTrans, n, r, m, 1.0, temp2a, m, H, m, 0.0, temp2, r);

                //element-wise update, done in place
                for (jj=0; jj<(n*r); ++jj)
                    W[jj] = W[jj]*temp1[jj]/(temp2[jj] + tiny);

                //---- update H, using the W just computed ----

                //temp3 = W'*V       r by n  *  n by m
                cblas_sgemm (CblasRowMajor, CblasTrans, CblasNoTrans, r, m, n, 1.0, W, r, V, m, 0.0, temp3, m);

                //temp2a = W*H, as above
                cblas_sgemm (CblasRowMajor, CblasNoTrans, CblasNoTrans, n, m, r, 1.0, W, r, H, m, 0.0, temp2a, m);

                //temp4 = W'*(W*H)   r by n  *  n by m
                cblas_sgemm (CblasRowMajor, CblasTrans, CblasNoTrans, r, m, n, 1.0, W, r, temp2a, m, 0.0, temp4, m);

                //element-wise update, done in place
                for (jj=0; jj<(r*m); ++jj)
                    H[jj] = H[jj]*temp3[jj]/(temp4[jj] + tiny);

            }

        }
			return true;

	}

	return false;
}

//Stage 3 of the asynchronous command (marked RT).
//Copies the factorisation out of the working arrays Wnrt_ and Hnrt_ into the W and H
//buffers (W_ and H_), then marks the unit as idle so that a later trigger can start
//another collection.
//Returns true when the command type was handled, false otherwise.
bool cmdStage3(World* world, CmdData* cmd) // RT
{
	switch (cmd->type) {
		case CmdData::NRTSourceSeparationCalculation: {

            SourceSeparation* unit = (SourceSeparation*)cmd->unit;

            float * Wnrt = unit->Wnrt_;
            float * Hnrt = unit->Hnrt_;

            float * W = unit->W_;
            float * H = unit->H_;

            int j;

            //may be better to pass the transpose W' here so rows themselves are the masks;
            //not too expensive to resolve in the mask UGen though. The calculation was on the
            //power spectrum but the mask modifies the magnitude spectrum, so square root both:
            // W * H = V so sqrt(W) * sqrt(H) = sqrt(V) for individual terms
            for (j=0; j<(unit->fftbins_*unit->numsources_); ++j)
                W[j] =   sqrt(Wnrt[j]);

            for (j=0; j<(unit->numframes_*unit->numsources_); ++j)
                H[j] =   sqrt(Hnrt[j]);

            //the calculation function ignores triggers while this flag is set and nothing
            //else clears it, so without this the unit could only ever run one analysis
            unit->nrtcalculationunderway_ = false;

            Print("Source Separation NRT calculation done\n");

		}
			return true;
	}
	return false;
}

//Stage 4 placeholder: does no work and always reports success.
bool cmdStage4(World* world, CmdData* cmd) // NRT
{
	(void)world;
	(void)cmd;

	return true;
}

//Releases the command data block through RTFree once the command is finished with.
void cmdCleanup(World* world, void* cmd)
{
	void * finishedcommand = cmd;

	RTFree(world, finishedcommand);
}













void SourceSeparation_Ctor( SourceSeparation* unit ) {
	
    //Print("SourceSeparation starting\n"); 
    //printf("SourceSeparation starting\n"); 
    
	//ZOUT0(0) = 0.0; 
	ZOUT0(0) = ZIN0(0);
    
    unit->numsources_ = ZIN0(1); 
    
    unit->numframes_ = ZIN0(2);
    
    unit->nrtcalculationunderway_ = false; 
    unit->rtcollectionunderway_ = false; 
    
    unit->collectionpos_ = 0; 
    
    //W buffer
    int ibufnum = ZIN0(4); 
    
    World *world = unit->mWorld;
    SndBuf *buf; 
    int localBufNum; 
    Graph *parent; 
    
    if (ibufnum >= world->mNumSndBufs) {
        localBufNum = ibufnum - world->mNumSndBufs;
        parent = unit->mParent;
        if(localBufNum <= parent->localBufNum) {
            buf = parent->mLocalSndBufs + localBufNum;
        } else {
            buf = world->mSndBufs;
        }
    } else {
        buf = world->mSndBufs + ibufnum;
    }

    //W is n by r, which is fftbins  * numsources
    unit->W_ = buf->data; 
    int siz = buf->samples; 
    
    unit->fftbins_ = siz/unit->numsources_; 
    
    //H buffer
    ibufnum = ZIN0(5); 
    
    if (ibufnum >= world->mNumSndBufs) {
        localBufNum = ibufnum - world->mNumSndBufs;
        parent = unit->mParent;
        if(localBufNum <= parent->localBufNum) {
            buf = parent->mLocalSndBufs + localBufNum;
        } else {
            buf = world->mSndBufs;
        }
    } else {
        buf = world->mSndBufs + ibufnum;
    }

    //H  is r * m  which is numsources * numframes
     unit->H_ = buf->data; 
    
    if(buf->samples != (unit->numsources_ *  unit->numframes_)) {
        
        printf("Hbuffer has wrong size: should be %d and is %d\n",buf->samples,unit->numsources_ *  unit->numframes_); 
        
    }
    
    unit->Wnrt_ = (float*)RTAlloc(world, sizeof(float) * (unit->fftbins_ * unit->numsources_));
    unit->Hnrt_ = (float*)RTAlloc(world, sizeof(float) * (unit->numsources_ *  unit->numframes_));
    unit->Vnrt_ = (float*)RTAlloc(world, sizeof(float) * (unit->fftbins_ * unit->numframes_));
    unit->V_ = (float*)RTAlloc(world, sizeof(float) * (unit->fftbins_ * unit->numframes_));
 
    //2n*(r+m) + r*m
    //(2*unit->fftbins_*(unit->numframes_ + unit->numsources_))  +  (unit->numframes_ * unit->numsources_);
    
    unit->calctemp_ = (float*)RTAlloc(world, sizeof(float) * ((2*unit->fftbins_*(unit->numframes_ + unit->numsources_))  +  (unit->numframes_ * unit->numsources_)));
    
    unit->lastinput_ = 0.0f; 

	SETCALC(SourceSeparation_next);
	
}


//Destructor: hands back the five arrays the constructor obtained with RTAlloc.
//W_ and H_ are not freed here; they point into buffers.
//Assumes any NRT functionality has finished by now, since the asynchronous
//calculation reads and writes these same arrays.
void SourceSeparation_Dtor( SourceSeparation* unit ) {

    World * world = unit->mWorld;

    float * owned[] = { unit->Wnrt_, unit->Hnrt_, unit->Vnrt_, unit->V_, unit->calctemp_ };
    const int numowned = sizeof(owned) / sizeof(owned[0]);

    for (int k=0; k<numowned; ++k) {
        RTFree(world, owned[k]);
    }

}


//Calculation function for the analysis UGen.
//Always copies the chain buffer number (input 0) to the output so the UGen can sit in an
//FFT chain. A rising positive trigger (input 3) starts a collection; while collecting,
//each FFT frame is converted to squared magnitudes and stored as one column of V_.
//Once numframes_ columns are filled, V_ is copied to Vnrt_ and the NMF calculation is
//launched as an asynchronous command. Nothing else happens while nrtcalculationunderway_
//is set.
void SourceSeparation_next( SourceSeparation *unit, int inNumSamples ) {

    int j;

    float fbufnum = ZIN0(0);

    //copy through, so can chain
    OUT0(0) = fbufnum;

    //nothing can be done while background calculation proceeding
    if(unit->nrtcalculationunderway_) {
        return;
    }

    float trigger = ZIN0(3);

    if ((trigger>0.0f) && ((trigger-unit->lastinput_) > 0.0000001f)) {

        //if already collecting, continue from where you're up to
        if(!unit->rtcollectionunderway_) {

            unit->collectionpos_=0;
            unit->rtcollectionunderway_=true;

            //W and H are re-initialised at the beginning of the NRT work, not here
            //(newline added so the message does not run into the next line of output)
            printf("SourceSeparation UGen: triggered spectral frames collection\n");
        }
    }

    unit->lastinput_ = trigger;

    if(!unit->rtcollectionunderway_) {
        return;
    }

    //only act when the next FFT buffer is ready
    //assuming at this point that buffer precalculated for any resampling
    if (!(fbufnum > -0.01f)) {
        return;
    }

    int ibufnum = (uint32)fbufnum;

    World *world = unit->mWorld;
    SndBuf *buf;

    if (ibufnum >= world->mNumSndBufs) {
        int localBufNum = ibufnum - world->mNumSndBufs;
        Graph *parent = unit->mParent;
        if(localBufNum <= parent->localBufNum) {
            buf = parent->mLocalSndBufs + localBufNum;
        } else {
            buf = world->mSndBufs;
        }
    } else {
        buf = world->mSndBufs + ibufnum;
    }

    const int fftbins = unit->fftbins_;

    //the loop below reads data[0] up to data[2*fftbins-3], and the nyquist row index is
    //fftbins-1, so the frame must hold at least 2*(fftbins-1) values and fftbins at least 2
    if (fftbins < 2 || buf->data == NULL || buf->samples < 2*(fftbins-1)) {
        printf("SourceSeparation UGen: FFT buffer (%d samples) does not match %d analysis bins; collection abandoned\n", buf->samples, fftbins);
        unit->rtcollectionunderway_ = false;
        return;
    }

    //make sure in real and imag form
    float * data = (float *)ToComplexApx(buf);

    const int rowsize = unit->numframes_;
    const int framenow = unit->collectionpos_;

    float * target = unit->V_;

    //row major storage, so each spectral frame becomes one column
    for (j=1; j<(fftbins-1); ++j) {

        float real = data[2*j];
        float imag = data[2*j+1];

        target[j*rowsize + framenow] = (real*real) + (imag*imag);
    }

    //sort DC and nyquist separately: they are the first two values of the frame
    target[framenow] = data[0]* data[0];
    target[(fftbins-1)*rowsize+framenow] = data[1] * data[1];

    ++unit->collectionpos_;

    //carry on collecting until enough data has been gathered to analyze
    if(unit->collectionpos_ < unit->numframes_) {
        return;
    }

    //get the command block first: if that fails the collection is dropped and the unit
    //stays free to be triggered again, instead of writing through a NULL pointer
    CmdData* cmd = (CmdData*)RTAlloc(unit->mWorld, sizeof(CmdData));

    if (cmd == NULL) {
        printf("SourceSeparation UGen: could not allocate NRT command; collection abandoned\n");
        unit->rtcollectionunderway_ = false;
        return;
    }

    //copy over data from RT to NRT storage for safety
    float * Vnrt = unit->Vnrt_;

    for (j=0; j<(fftbins*unit->numframes_); ++j)
        Vnrt[j] =   target[j];

    printf("commencing NRT\n");

    unit->rtcollectionunderway_=false;

    //set NRT calculation flag, can't interrupt a second time
    unit->nrtcalculationunderway_=true;

    cmd->unit = (Unit *)unit;
    cmd->type = CmdData::NRTSourceSeparationCalculation;
    //not read by the stages in this file; given defined values all the same
    cmd->nrtallocated = NULL;
    cmd->samplingrate_ = 0.0f;

    DoAsynchronousCommand(unit->mWorld, 0, "", (void*)cmd,
                          (AsyncStageFn)cmdStage2,
                          (AsyncStageFn)cmdStage3,
                          NULL,
                          cmdCleanup,
                          0, 0);

}







//Constructor for the masking UGen: passes input 0 through to the output, installs the
//calculation function and starts with H playback switched off at frame 0.
void PV_SourceSeparationMask_Ctor(PV_SourceSeparationMask* unit) {

	ZOUT0(0) = ZIN0(0);

    SETCALC(PV_SourceSeparationMask_next);

    unit->hplayback_ = false;
    unit->hframepos_ = 0;
    unit->lastuseh_ = 0.f;
}



//Calculation function for the masking UGen.
//Inputs: 0 chain buffer, 1 number of sources, 2 which source to keep, 3 useh,
//4 W buffer number, 5 H buffer number.
//Scales the magnitudes of the current FFT frame by the chosen source's column of W
//(row 0 for DC, rows 1..numbins for the bins, row numbins+1 for nyquist), each gain
//limited to at most 1. While useh is positive the gains are also multiplied by the
//source's activation in H, stepping one H frame per FFT frame and wrapping at the end;
//playback restarts from frame 0 each time useh turns positive.
//The frame is left unmasked when the W buffer size does not match the FFT size.
void PV_SourceSeparationMask_next(PV_SourceSeparationMask* unit, int inNumSamples) {

    float useh = ZIN0(3);

    if(useh>0.0f) {

        //just turned on: restart H playback from its first frame
        if(unit->lastuseh_<=0.0f) {
            unit->hplayback_ = true;
            unit->hframepos_=0;
        }

    } else {

        unit->hplayback_ = false;
    }

    unit->lastuseh_ = useh;

    PV_GET_BUF

	SCPolarBuf *p = ToPolarApx(buf);

    int numsources =  ZIN0(1);
    int whichsource = ZIN0(2);

    //the source count is used as a divisor and a modulus below
    if(numsources<1) return;

    if(whichsource<0) whichsource = 0;
    if(whichsource>=numsources) whichsource = whichsource%numsources;

    //W buffer
    ibufnum = ZIN0(4);

    int localBufNum2;
    Graph *parent2;

    if (ibufnum >= world->mNumSndBufs) {
        localBufNum2 = ibufnum - world->mNumSndBufs;
        parent2 = unit->mParent;
        if(localBufNum2 <= parent2->localBufNum) {
            buf = parent2->mLocalSndBufs + localBufNum2;
        } else {
            buf = world->mSndBufs;
        }
    } else {
        buf = world->mSndBufs + ibufnum;
    }

    //W is n by r, which is fftbins  * numsources
    float * W = buf->data;
    int Wsiz = buf->samples;

    //H buffer
    ibufnum = ZIN0(5);

    if (ibufnum >= world->mNumSndBufs) {
        localBufNum2 = ibufnum - world->mNumSndBufs;
        parent2 = unit->mParent;
        if(localBufNum2 <= parent2->localBufNum) {
            buf = parent2->mLocalSndBufs + localBufNum2;
        } else {
            buf = world->mSndBufs;
        }
    } else {
        buf = world->mSndBufs + ibufnum;
    }

    //H is r by m, which is numsources by numframes
    float * H = buf->data;
    int Hsiz = buf->samples;

    //safety check: W must have one row per bin plus the DC and nyquist rows
    if(numbins != ((Wsiz/numsources)-2)) return;

    //gain contributed by H; stays at 1 when H playback is off or H holds no complete frame
    float mix = 1.0f;

    if(unit->hplayback_) {

        int numframes = Hsiz/numsources;

        if(numframes>0) {

            //safety, local copy checked against buffer available
            int hframepos = (unit->hframepos_)%numframes;

            mix = H[whichsource*numframes + hframepos];

            unit->hframepos_ = (unit->hframepos_+1)%numframes;
        }
    }

    p->dc  = p->dc  * sc_min(W[whichsource] * mix,1.0f);

    //nyquist is the last row of W, one past the row used by the final regular bin
    p->nyq = p->nyq * sc_min(W[(numbins+1)*numsources+whichsource] * mix,1.0f);

    for (int i=0; i<numbins; ++i) {
        float mag = p->bin[i].mag;
        p->bin[i].mag = mag * (sc_min(W[(i+1)*numsources+whichsource] * mix,1.0f));
    }
}


//Registers a unit that has a constructor but no destructor.
//The size is taken from the unit's own struct rather than from PV_Unit, because
//PV_SourceSeparationMask carries state fields of its own that its constructor writes to;
//registering it with a smaller size would make those writes land outside the unit.
//The #undef is harmless when no earlier definition exists.
#undef DefinePVUnit
#define DefinePVUnit(name) \
(*ft->fDefineUnit)(#name, sizeof(name), (UnitCtorFunc)&name##_Ctor, 0, 0);


//Plugin entry point: stores the interface table and registers both UGens.
PluginLoad(SourceSeparation)
{

	init_SCComplex(inTable);

    ft = inTable;

    //analysis UGen, registered with its destructor
	DefineDtorCantAliasUnit(SourceSeparation);

    //masking UGen, no destructor
    DefinePVUnit(PV_SourceSeparationMask);

}


