//SuperCollider is under GNU GPL version 3, http://supercollider.sourceforge.net/
//these extensions released under the same license

/*
 *  Wavelets.cpp
 *
 *  Created by Nicholas Collins on 19/07/2011
 *  Copyright 2011 Nicholas M Collins. All rights reserved.
 *
 */


//building with CMake
//need additional libs in both architectures if linking in 
//cmake -DSC_PATH=/data/gitprojects/supercollider -DCMAKE_OSX_ARCHITECTURES='x86_64' ..


//example gsl wavelet code, see
//http://www.gnu.org/s/gsl/manual/html_node/DWT-Examples.html

#include "SC_PlugIn.h"

#include <stdio.h>
#include <math.h>
#include <gsl/gsl_wavelet.h>
//#include "/Users/nickcollins/Desktop/tosort/gsl_universal_1.14/gsl/gsl_wavelet.h"
//#include "/usr/local/include/gsl/gsl_wavelet.h"


//#if defined(__APPLE__) && !defined(SC_IPHONE)
//#include "vecLib/vDSP.h"
//#endif




static InterfaceTable *ft; 



// State shared by the forward (DWT) and inverse (IDWT) wavelet UGens.
struct DWTBase : public Unit
{
	// Buffer the forward transform writes its coefficients to and the
	// inverse transform reads them from.
	SndBuf *m_sndbuf;

	// Sample position inside the current hop.
	int m_pos;
	// "fullbufsize" includes any zero-padding, "audiosize" does not.
	int m_fullbufsize;
	int m_audiosize;
	// Ceiling log2 of the two sizes above.
	int m_log2n_full;
	int m_log2n_audio;

	// Buffer number read from the first input at construction time.
	uint32 m_bufnum;

	// These add up to m_audiosize
	int m_hopsize;
	int m_shuntsize;
	// Window selector; the derived constructors set it before the base ctor runs.
	int m_wintype;

	// Samples handled per calc call (block length, or 1 when not full rate).
	int m_numSamples;

	// Double-precision working copy of one frame (GSL transforms doubles).
	double *data_;
	// wavelet object
	gsl_wavelet *w_;
	// work space area
	gsl_wavelet_workspace *work_;
	// Window table; only allocated when m_wintype is not kRectWindow.
	float *window_;
};

// Forward transform UGen: adds the time-domain input accumulator.
struct DWT : public DWTBase
{
	// Holds m_audiosize input samples (shunted down by one hop per frame).
	float *m_inbuf;
};

// Inverse transform UGen: adds the overlap-add output state.
struct IDWT : public DWTBase
{
	// Overlap-add accumulator the output samples are read from.
	float *m_olabuf;
	// Windowed copy of the newest inverse-transformed frame, kept separate
	// so the coefficient buffer itself is never overwritten.
	float *temp_;
};



// Entry points registered with the server; declared with C linkage.
extern "C" {

	// Forward discrete wavelet transform UGen.
	void DWT_Ctor(DWT *unit);
	void DWT_Dtor(DWT *unit);
	void DWT_next(DWT *unit, int inNumSamples);
	void DWT_ClearUnitOutputs(DWT *unit, int wrongNumSamples);

	// Inverse discrete wavelet transform UGen.
	void IDWT_Ctor(IDWT *unit);
	void IDWT_Dtor(IDWT *unit);
	void IDWT_next(IDWT *unit, int inNumSamples);
}


//not needed: SC_fftlib.h already provides these window-type constants
//enum Wavelets_WindowFunction
//{
//	kRectWindow = -1,
//	kSineWindow = 0,
//	kHannWindow = 1
//};

//could actually use the SC_fftlib function call

// Single-precision constants used to compute the window phase increments
// in wavelets_create_window.
// NOTE(review): these are bare lowercase macros, so every later use of the
// identifiers `pi` / `twopi` in this translation unit is textually replaced
// by the float literal - confirm no included header relies on those names
// after this point.
#define pi 3.1415926535898f
#define twopi 6.28318530717952646f

// Allocate (from the real-time pool) and fill a window table of `size` floats.
//
//   unit    - owning unit; supplies the World used for RTAlloc
//   wintype - kSineWindow or kHannWindow; any other value yields an all-ones
//             (rectangular) table so that multiplying by it is harmless
//   size    - number of samples in the window
//
// Returns the table, or NULL when the allocation fails. The caller owns the
// memory and must release it with RTFree.
float* wavelets_create_window(DWTBase *unit, int wintype, int size)
{
	float *win = (float*)RTAlloc (unit->mWorld,size * sizeof(float));

	if (win == NULL)
		return NULL;

	// The index is an int: a 16-bit counter wraps before reaching `size`
	// for windows of 65536 samples or more and would loop forever.
	int i;
	double winc;

	switch(wintype) {
		case kSineWindow:
			winc = pi / size;
			for (i=0; i<size; ++i) {
				double w = i * winc;
				win[i] = sin(w);
			}
			break;
		case kHannWindow:
			winc = twopi / size;
			for (i=0; i<size; ++i) {
				double w = i * winc;
				win[i] = 0.5 - 0.5 * cos(w);
			}
			break;
		default:
			// Unknown type: never hand back uninitialised memory.
			for (i=0; i<size; ++i)
				win[i] = 1.f;
			break;
	}

	return win;

}





//functions for real time safe allocation and deallocation, adapted from gsl wavelet code

// Real-time-safe counterpart of gsl_wavelet_alloc: the wavelet struct comes
// from the server's RT pool instead of malloc.
//
//   unit - owning unit; supplies the World used for RTAlloc/RTFree
//   T    - wavelet family descriptor
//   k    - family member passed to the descriptor's init function
//
// Returns the initialised wavelet, or 0 when the allocation fails or the
// family's init function reports an error for k. Release with
// gsl_wavelet_free2.
gsl_wavelet * gsl_wavelet_alloc2 (DWTBase *unit, const gsl_wavelet_type * T, size_t k)
{
    int status;

    gsl_wavelet *w = (gsl_wavelet *) RTAlloc (unit->mWorld, sizeof (gsl_wavelet));

    // RT pool exhausted: report failure instead of writing through NULL.
    if (w == NULL)
    {
        return 0;
    }

    w->type = T;

    // The family's init fills in the filter pointers, length and offset.
    status = (T->init) (&(w->h1), &(w->g1), &(w->h2), &(w->g2),
                        &(w->nc), &(w->offset), k);

    if (status)
    {
        // invalid wavelet member
        RTFree(unit->mWorld, w);
        w = 0;
    }

    return w;
}

void gsl_wavelet_free2 (DWTBase *unit, gsl_wavelet * w)
{
    RTFree(unit->mWorld, w);
    
}


// Real-time-safe counterpart of gsl_wavelet_workspace_alloc: both the
// workspace struct and its scratch array come from the server's RT pool.
//
//   unit - owning unit; supplies the World used for RTAlloc/RTFree
//   n    - number of doubles of scratch space (must be positive)
//
// Returns the workspace, or NULL when n is 0 or either allocation fails.
// Release with gsl_wavelet_workspace_free2.
gsl_wavelet_workspace * gsl_wavelet_workspace_alloc2 (DWTBase *unit, size_t n)
{
    gsl_wavelet_workspace *work;

    // length n must be a positive integer
    if (n == 0)
    {
        return NULL;
    }

    work = (gsl_wavelet_workspace *) RTAlloc (unit->mWorld, sizeof (gsl_wavelet_workspace));

    // failed to allocate struct
    if (work == NULL)
    {
        return NULL;
    }

    work->n = n;
    work->scratch = (double *) RTAlloc (unit->mWorld, n * sizeof (double));

    if (work->scratch == NULL)
    {
        /* error in constructor, prevent memory leak; the struct is gone,
           so the caller must not receive a pointer to it */
        RTFree(unit->mWorld, work);
        return NULL;
    }

    return work;
}

// Release a workspace obtained from gsl_wavelet_workspace_alloc2.
// Passing NULL is allowed and does nothing, so callers may invoke this on a
// workspace that was never successfully created.
void gsl_wavelet_workspace_free2 (DWTBase *unit, gsl_wavelet_workspace * work)
{
    if (work == NULL)
    {
        return;
    }

    /* release scratch space */
    if (work->scratch != NULL)
    {
        RTFree(unit->mWorld, work->scratch);
        work->scratch = NULL;
    }

    RTFree(unit->mWorld, work);
}



// B-spline family members (the k argument handed to the wavelet allocator),
// indexed by (wavelet type - 20) for the plain family and
// (wavelet type - 31) for the centered family.
int g_bsplinek[11] = {
    103, 105,
    202, 204, 206, 208,
    301, 303, 305, 307, 309
};


// Shared constructor for the DWT and IDWT units.
//
//   unit             - unit being constructed; m_wintype must already be set
//   frmsizinput      - index of the input holding the audio frame size
//   wavelettypeinput - index of the input holding the wavelet type (0-41)
//
// Resolves the buffer named by input 0, validates the sizes, and allocates
// the double-precision frame, the wavelet, its workspace and (unless the
// window type is kRectWindow) the window table.
//
// Returns 1 on success and 0 on any failure. Every pointer that
// DWTBase_Dtor releases is set to NULL before the first possible early
// return, so the destructor is safe to run after a failed construction.
static int DWTBase_Ctor(DWTBase *unit, int frmsizinput, int wavelettypeinput)
{
	World *world = unit->mWorld;

	// Known state first: the destructor runs even if we bail out below.
	unit->m_sndbuf = NULL;
	unit->data_ = NULL;
	unit->w_ = NULL;
	unit->work_ = NULL;
	unit->window_ = NULL;

	uint32 bufnum = (uint32)ZIN0(0);

	// Numbers past the global buffers address the parent graph's local buffers.
	SndBuf *buf;
	if (bufnum >= world->mNumSndBufs) {
		int localBufNum = bufnum - world->mNumSndBufs;
		Graph *parent = unit->mParent;
		if(localBufNum <= parent->localMaxBufNum) {
			buf = parent->mLocalSndBufs + localBufNum;
		} else {
			if(unit->mWorld->mVerbosity > -1){ Print("DWTBase_Ctor error: invalid buffer number: %i.\n", bufnum); }
			return 0;
		}
	} else {
		buf = world->mSndBufs + bufnum;
	}

	if (!buf->data) {
		if(unit->mWorld->mVerbosity > -1){ Print("DWTBase_Ctor error: Buffer %i not initialised.\n", bufnum); }
		return 0;
	}

	unit->m_sndbuf = buf;
	unit->m_bufnum = bufnum;
	unit->m_fullbufsize = buf->samples;

	// A frame size below 1 means "use the whole buffer".
	int framesize = (int)ZIN0(frmsizinput);
	if(framesize < 1)
		unit->m_audiosize = buf->samples;
	else
		unit->m_audiosize = sc_min(buf->samples, framesize);

	unit->m_log2n_full  = LOG2CEIL(unit->m_fullbufsize);
	unit->m_log2n_audio = LOG2CEIL(unit->m_audiosize);

	const int blocksize = unit->mWorld->mFullRate.mBufLength;

	// Only power-of-two sizes are accepted, which keeps the windowing simple.
	if (!ISPOWEROFTWO(unit->m_fullbufsize)) {
		Print("DWTBase_Ctor error: buffer size (%i) not a power of two.\n", unit->m_fullbufsize);
		return 0;
	}
	if (!ISPOWEROFTWO(unit->m_audiosize)) {
		Print("DWTBase_Ctor error: audio frame size (%i) not a power of two.\n", unit->m_audiosize);
		return 0;
	}
	// Reported separately from the block-size check so the message matches the cause.
	if (unit->m_audiosize < SC_FFT_MINSIZE) {
		Print("DWTBase_Ctor error: audio frame size (%i) smaller than the minimum (%i).\n", unit->m_audiosize, (int)SC_FFT_MINSIZE);
		return 0;
	}
	if ((unit->m_audiosize / blocksize) * blocksize != unit->m_audiosize) {
		Print("DWTBase_Ctor error: audio frame size (%i) not a multiple of the block size (%i).\n", unit->m_audiosize, blocksize);
		return 0;
	}

	unit->m_pos = 0;

	//double precision calculations required for wavelet transform in gsl
	unit->data_ = (double*)RTAlloc(unit->mWorld, unit->m_fullbufsize * sizeof(double));
	if (!unit->data_) {
		Print("DWTBase_Ctor error: could not allocate real-time memory.\n");
		return 0;
	}

	int wavelettype = (int)ZIN0(wavelettypeinput);

	const gsl_wavelet_type * waveletname;
	int k;

	//available:
	//gsl_wavelet_daubechies 4, 6, 8,10,12, 14, 16, 18, 20  indices 0-8
	//gsl_wavelet_daubechies_centered 4, 6, 8,10,12, 14, 16, 18, 20  indices 9-17
	//gsl_wavelet_haar k=2 only index 18
	//gsl_wavelet_haar_centered k=2 only index 19
	//gsl_wavelet_bspline k=103, 105, 202, 204, 206, 208, 301, 303, 305 307, 309 indices 20-30
	//gsl_wavelet_bspline_centered k=103, 105, 202, 204, 206, 208, 301, 303, 305 307, 309 31-41

	// Out-of-range selections fall back to index 0.
	if(wavelettype>41) wavelettype = 0;
	if(wavelettype<0) wavelettype=0;

	if(wavelettype<9) {

		waveletname = gsl_wavelet_daubechies;
		k = wavelettype*2+4;

	} else if (wavelettype<18) {

		waveletname = gsl_wavelet_daubechies_centered;
		k = (wavelettype-9)*2+4;

	} else if (wavelettype==18) {

		waveletname = gsl_wavelet_haar;
		k=2;

	} else if (wavelettype==19) {

		waveletname = gsl_wavelet_haar_centered;
		k=2;

	} else if (wavelettype<31) {

		waveletname = gsl_wavelet_bspline;
		k = g_bsplinek[wavelettype-20];

	} else {

		waveletname = gsl_wavelet_bspline_centered;
		k = g_bsplinek[wavelettype-31];
	}

	unit->w_ = gsl_wavelet_alloc2 (unit, waveletname, k);
	if (!unit->w_) {
		Print("DWTBase_Ctor error: could not create wavelet.\n");
		return 0;
	}

	unit->work_ = gsl_wavelet_workspace_alloc2 (unit, unit->m_fullbufsize);
	if (!unit->work_) {
		Print("DWTBase_Ctor error: could not allocate real-time memory.\n");
		return 0;
	}

	//create window; a rectangular window needs no table
	if (unit->m_wintype!=kRectWindow) {
		unit->window_ = wavelets_create_window(unit,unit->m_wintype, unit->m_audiosize);
		if (!unit->window_) {
			Print("DWTBase_Ctor error: could not allocate real-time memory.\n");
			return 0;
		}
	}

	// Pass the buffer number through on the output.
	ZOUT0(0) = ZIN0(0);

	return 1;
}


// Shared destructor: releases whatever DWTBase_Ctor managed to allocate.
// Each resource is released only when its pointer is non-NULL, so this is
// safe after a constructor that returned 0 part-way through. Always returns 1.
static int DWTBase_Dtor(DWTBase *unit) {

	if (unit->data_) {
		RTFree(unit->mWorld, unit->data_);
		unit->data_ = NULL;
	}

	if (unit->w_) {
		gsl_wavelet_free2(unit, unit->w_);
		unit->w_ = NULL;
	}

	if (unit->work_) {
		gsl_wavelet_workspace_free2(unit, unit->work_);
		unit->work_ = NULL;
	}

	// The table only exists for non-rectangular window types.
	if (unit->window_) {
		RTFree(unit->mWorld,unit->window_);
		unit->window_ = NULL;
	}

	return 1;

}



// Constructor for the forward wavelet transform UGen.
//
// Inputs read here: 2 = hop size as a fraction of the frame (clamped to
// 0..1), 3 = window type; the frame size and wavelet type inputs (5 and 6)
// are read by DWTBase_Ctor. On any failure the unit is switched to
// DWT_ClearUnitOutputs.
void DWT_Ctor( DWT* unit ) {

	unit->m_wintype = (int)ZIN0(3); // wintype may be used by the base ctor

	// Null first, so the dtor never frees something that doesn't exist,
	// whichever early return is taken below.
	unit->m_inbuf = 0;

	if(!DWTBase_Ctor(unit, 5, 6)){
		SETCALC(DWT_ClearUnitOutputs);
		return;
	}

	const int blocksize = unit->mWorld->mFullRate.mBufLength;
	int audiosize = unit->m_audiosize * sizeof(float);

	int hopsize = (int)(sc_max(sc_min(ZIN0(2), 1.f), 0.f) * unit->m_audiosize);
	if (hopsize < blocksize) {
		// A hop of zero would never be reached by the write position in
		// DWT_next, which would then keep writing past the end of m_inbuf.
		// The base ctor guarantees the frame is at least one block long,
		// so one block is always a valid hop.
		Print("DWT_Ctor: hopsize (%i) smaller than SC's block size (%i) - automatically corrected.\n", hopsize, blocksize);
		hopsize = blocksize;
	} else if (((int)(hopsize / blocksize)) * blocksize != hopsize) {
		Print("DWT_Ctor: hopsize (%i) not an exact multiple of SC's block size (%i) - automatically corrected.\n", hopsize, unit->mWorld->mFullRate.mBufLength);
		hopsize = ((int)(hopsize / blocksize)) * blocksize;
	}
	unit->m_hopsize = hopsize;
	unit->m_shuntsize = unit->m_audiosize - hopsize;

	unit->m_inbuf = (float*)RTAlloc(unit->mWorld, audiosize);
	if (!unit->m_inbuf) {
		Print("DWT_Ctor: could not allocate real-time memory.\n");
		SETCALC(DWT_ClearUnitOutputs);
		return;
	}

	memset(unit->m_inbuf, 0, audiosize);

	// Full-rate input delivers a whole block per call, otherwise one sample.
	if (INRATE(1) == calc_FullRate) {
		unit->m_numSamples = unit->mWorld->mFullRate.mBufLength;
	} else {
		unit->m_numSamples = 1;
	}

	SETCALC(DWT_next);

}



// Destructor for the forward transform UGen: releases the input
// accumulator (when there is one) and then the shared base resources.
void DWT_Dtor(DWT *unit)
{
	float *inbuf = unit->m_inbuf;

	if (inbuf != 0) {
		RTFree(unit->mWorld, inbuf);
	}

	DWTBase_Dtor(unit);
}


// Ordinary ClearUnitOutputs outputs zero, potentially telling the IDWT (+ WV UGens) to act on buffer zero, so let's skip that:
// Calc function used when construction failed. It writes -1 to the output
// instead of the zero an ordinary ClearUnitOutputs would write, because
// zero could tell the IDWT (+ WV UGens) to act on buffer zero.
void DWT_ClearUnitOutputs(DWT *unit, int wrongNumSamples)
{
	const float noFrame = -1;

	ZOUT0(0) = noFrame;
}



// Calc function of the forward transform UGen.
//
// Appends the incoming samples to the input accumulator. Once a full hop
// has been collected (and the target buffer is still valid and unchanged
// in size) the frame is windowed, zero-padded, transformed, and written to
// the buffer; the output then carries the buffer number. On every other
// call the output is -1.
void DWT_next( DWT *unit, int inNumSamples ) {

	float *in = IN(1);
	float *out = unit->m_inbuf + unit->m_pos + unit->m_shuntsize;

	int numSamples = unit->m_numSamples;

	// copy input
	memcpy(out, in, numSamples * sizeof(float));

	unit->m_pos += numSamples;

	if (unit->m_pos != unit->m_hopsize || !unit->m_sndbuf->data || unit->m_sndbuf->samples != unit->m_fullbufsize) {
		// No frame this time; if the hop is complete but the buffer is
		// unusable, just start collecting the next hop.
		if(unit->m_pos == unit->m_hopsize)
			unit->m_pos = 0;
		ZOUT0(0) = -1.f;
		return;
	}

	unit->m_pos = 0;

	double * target = unit->data_;
	float * source = unit->m_inbuf;
	float * outputbuffer = unit->m_sndbuf->data;

	int i;

	// enveloping (float to double conversion happens on assignment)
	if (unit->m_wintype!=kRectWindow) {

		float * window = unit->window_;

		for (i=0; i<unit->m_audiosize; ++i)
			target[i] = source[i]*window[i];

	} else {

		for (i=0; i<unit->m_audiosize; ++i)
			target[i] = source[i];

	}

	//zero padding; will do nothing if audiosize=fullbufsize
	for (i=unit->m_audiosize; i<unit->m_fullbufsize; ++i)
		target[i] = 0.0;

	gsl_wavelet_transform_forward (unit->w_, target, 1, unit->m_fullbufsize, unit->work_);

	for (i=0; i<unit->m_fullbufsize; ++i)
		outputbuffer[i] = (float)target[i]; //double to float conversion

	// Signal downstream units that a new frame is available.
	ZOUT0(0) = unit->m_bufnum;

	// Shunt input buf down. Source and destination overlap whenever the
	// hop is shorter than the retained part, so this must be memmove.
	memmove(unit->m_inbuf, unit->m_inbuf + unit->m_hopsize, unit->m_shuntsize * sizeof(float));

}


// Constructor for the inverse wavelet transform UGen.
//
// Input 1 is the window type; the frame size and wavelet type inputs
// (2 and 3) are read by DWTBase_Ctor. On any failure the unit is switched
// to ClearUnitOutputs.
void IDWT_Ctor( IDWT* unit ) {

	unit->m_wintype = (int)ZIN0(1); // wintype may be used by the base ctor

	// Null both owned buffers first, so the dtor never frees something that
	// doesn't exist, whichever early return is taken below.
	unit->m_olabuf = 0;
	unit->temp_ = 0;

	if(!DWTBase_Ctor(unit, 2, 3)){
		SETCALC(*ClearUnitOutputs);
		return;
	}

	const size_t framebytes = unit->m_audiosize * sizeof(float);

	// This will hold the transformed and progressively overlap-added data ready for outputting.
	unit->m_olabuf = (float*)RTAlloc(unit->mWorld, framebytes);

	// Separate frame copy, so resynthesis doesn't write in place over the buffer data.
	unit->temp_ = (float*)RTAlloc(unit->mWorld, framebytes);

	if (!unit->m_olabuf || !unit->temp_) {
		Print("IDWT_Ctor: could not allocate real-time memory.\n");
		SETCALC(*ClearUnitOutputs);
		return;
	}

	// Zeroed overlap-add data gives silent output until the first frame arrives.
	memset(unit->m_olabuf, 0, framebytes);

	// "pos" is reset to zero whenever a frame comes in.
	unit->m_pos = 0;

	if (unit->mCalcRate == calc_FullRate) {
		unit->m_numSamples = unit->mWorld->mFullRate.mBufLength;
	} else {
		unit->m_numSamples = 1;
	}

	SETCALC(IDWT_next);

}





// Destructor for the inverse transform UGen: releases the overlap-add
// buffer and the frame copy (each only when set), then the shared base
// resources.
void IDWT_Dtor(IDWT *unit) {

	World *world = unit->mWorld;

	float *olabuf = unit->m_olabuf;
	if (olabuf != 0) {
		RTFree(world, olabuf);
	}

	float *frame = unit->temp_;
	if (frame != 0) {
		RTFree(world, frame);
	}

	DWTBase_Dtor(unit);
}



// Calc function of the inverse transform UGen.
//
// When the chain input (input 0) is non-negative, a new coefficient frame
// is ready: it is inverse-transformed, windowed, and overlap-added into
// the output accumulator. On every call the next numSamples of the
// accumulator are copied to the output, or silence once it has been fully
// played out.
void IDWT_next(IDWT *unit, int wrongNumSamples)
{
	float *out = OUT(0); // NB not ZOUT0

	// Load state from struct into local scope
	int pos     = unit->m_pos;
	int fullbufsize  = unit->m_fullbufsize;
	int audiosize = unit->m_audiosize;
	int numSamples = unit->m_numSamples;
	float *olabuf = unit->m_olabuf;
	float fbufnum = ZIN0(0);
	SndBuf *sndbuf = unit->m_sndbuf;

	// Only run the inverse transform if we're receiving a new block of input
	// data - otherwise just output data already received. As in DWT_next, the
	// buffer is skipped if it has lost its data or changed size since
	// construction, because fullbufsize samples are about to be read from it.
	if (fbufnum >= 0.f && sndbuf->data && sndbuf->samples == fullbufsize){

		float* buf = sndbuf->data;
		float * temp = unit->temp_;
		double * target = unit->data_;

		int i;

		for (i=0; i<fullbufsize; ++i)
			target[i] = buf[i]; //float to double conversion

		gsl_wavelet_transform_inverse(unit->w_,target, 1, fullbufsize, unit->work_);

		// Window into the separate frame copy (double to float conversion),
		// leaving the coefficient buffer untouched.
		if (unit->m_wintype!=kRectWindow) {

			float * window = unit->window_;

			for (i=0; i<audiosize; ++i)
				temp[i] = ((float)target[i])*window[i];

		} else {

			for (i=0; i<audiosize; ++i)
				temp[i] = (float)target[i];

		}

		// Then shunt the "old" time-domain output down by one hop
		int hopsamps = pos;
		int shuntsamps = audiosize - hopsamps;
		// There's only copying to be done if the position isn't all the way
		// to the end of the buffer. The regions overlap (or coincide when
		// hopsamps is 0), so this must be memmove.
		if(hopsamps != audiosize)
			memmove(olabuf, olabuf+hopsamps, shuntsamps * sizeof(float));

		// Then mix the "new" time-domain data in - adding at first, then just
		// setting (copying) where the "old" is supposed to be zero.
		for(i = 0; i < shuntsamps; ++i){
			olabuf[i] += temp[i];
		}

		memcpy(olabuf + shuntsamps, temp + shuntsamps, (hopsamps) * sizeof(float));

		// Move the pointer back to zero, which is where playback will next begin
		pos = 0;

	} // End of has-the-chain-fired

	// Now we can output some stuff, as long as there is still data waiting to be output.
	// If there is NOT data waiting to be output, we output zero. (Either irregular/negative-overlap
	//     firing, or the transform has given up, or at very start of execution.)
	if(pos >= audiosize)
		ClearUnitOutputs(unit, numSamples);
	else {
		memcpy(out, olabuf + pos, numSamples * sizeof(float));
		pos += numSamples;
	}
	unit->m_pos = pos;
}

//#include "WT_UGens.h"

// Defined outside this file; called below with the interface table.
extern void initWT(InterfaceTable *inTable); 

// Plugin entry point: stores the interface table in the file-level `ft`
// pointer, registers the two UGens defined in this file, and then calls
// initWT with the same table.
PluginLoad(Wavelets) {
	
	// Must be assigned before anything that goes through `ft` is used.
	ft = inTable;
	
	// Both units are registered with a destructor (DWT_Dtor / IDWT_Dtor).
    DefineDtorUnit(DWT);
	DefineDtorUnit(IDWT);
	
    initWT(ft); //WT_Units for the in chain processing

}









