
// MVDenoise: a motion-compensated temporal denoiser

// See legal notice in Copying.txt for more information

// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA, or visit
// http://www.gnu.org/copyleft/gpl.html .

#include "MVDenoise.h"
#include "CopyCode.h"
#include "Padding.h"

MVDenoise::MVDenoise(PClip _child, PClip mvsbw, PClip mvsfw, int thT, int thMC, int sadT, int thMV, const MotionParameters &params, IScriptEnvironment* env) :
GenericMotionFilter(_child, params, mvsbw->GetVideoInfo(), env), mvclipBW(mvsbw), mvclipFW(mvsfw)
{
	// Holders for the four vector streams that GetFrame() fetches for every frame.
	CreateFGOP(&fgopCN);
	CreateFGOP(&fgopCP);
	CreateFGOP(&fgopNoN);
	CreateFGOP(&fgopPoP);

	// User thresholds are given in full-pixel units; the motion related ones are
	// compared against squared lengths expressed in sub-pixel (nPel) units.
	const int pelSquared = nPel * nPel;

	TemporalThreshold        = ( thT < 0 ) ? 0 : thT;
	MotionCoherenceThreshold = thMC * thMC * pelSquared;
	SADThreshold             = sadT;
	MVThreshold              = thMV * thMV * pelSquared;

	// Geometry of the padded work planes: the picture sits at 'offset' inside a
	// buffer that has nHPadding / nVPadding extra pixels on every side.
	pitch  = nWidth + 2 * nHPadding;
	offset = nHPadding + pitch * nVPadding;
	size   = ( nHeight + 2 * nVPadding ) * pitch;

	// Allocate every padded work buffer (same allocation order as before).
	unsigned char **buffers[] = {
		&pFrame, &nFrame, &cFrame, &ppFrame, &nnFrame,
		&previousFrame, &nextFrame,
		&previousMask, &nextMask
	};
	const int bufferCount = sizeof(buffers) / sizeof(buffers[0]);
	for ( int b = 0; b < bufferCount; b++ )
		*buffers[b] = new unsigned char[size];
}

MVDenoise::~MVDenoise()
{
	// Release every padded work buffer allocated by the constructor.
	unsigned char *buffers[] = {
		pFrame, nFrame, ppFrame, nnFrame, cFrame,
		previousFrame, nextFrame,
		previousMask, nextMask
	};
	const int bufferCount = sizeof(buffers) / sizeof(buffers[0]);
	for ( int b = 0; b < bufferCount; b++ )
		delete[] buffers[b];
}

#define SATURATE(x,a,b) (((x) < (a)) ? (a) : (((x) > (b)) ? (b) : (x)))

void AddLuma(unsigned char *pDst, int nDstPitch, int luma)
{
    for ( int j = 0; j < 8; j++ )
    {
        for ( int i = 0; i < 8; i++ )
            pDst[i] = SATURATE(pDst[i] + luma, 0, 255);
        pDst += nDstPitch;
    }
}

void MVDenoise::DenoiseBlock(unsigned char *derp, int der_pitch)
{
	const FakePlaneOfBlocks &cpBlocks = fgopCP->GetPlane(0);
	const FakePlaneOfBlocks &cnBlocks = fgopCN->GetPlane(0);
	const FakePlaneOfBlocks &popBlocks = fgopPoP->GetPlane(0);
	const FakePlaneOfBlocks &nonBlocks = fgopNoN->GetPlane(0);

	bool doP = IsUsable(fgopCP);
	bool doN = IsUsable(fgopCN);
	bool doPoP = IsUsable(fgopPoP) && doP;
	bool doNoN = IsUsable(fgopNoN) && doN;

	int bw = cpBlocks.GetBlockSize();
	int bh = cpBlocks.GetBlockSize();

	previousFrame += offset;
	nextFrame += offset;
	previousMask += offset;
	nextMask += offset;
	pFrame += offset;
	ppFrame += offset;
	nFrame += offset;
	nnFrame += offset;

	if ( doPoP )
	{
		for ( int i = 0; i < cpBlocks.GetBlockCount(); i++ )
		{
			int x = cpBlocks[i].GetX();
			int y = cpBlocks[i].GetY();

			if ((popBlocks[i].GetSAD() < SADThreshold) && (popBlocks[i].GetMVLength() < MVThreshold))
			{
				BitBlt(previousFrame + x + y * pitch, pitch,
					ppFrame + x + popBlocks[i].GetMV().x + ( y + popBlocks[i].GetMV().y ) * pitch,
					pitch, bw, bh, isse);
                AddLuma(previousFrame + x + y * pitch, pitch, (popBlocks[i].GetLuma() - popBlocks[i].GetRefLuma()) / 64);
				MemZoneSet(previousMask, 1, bw, bh, x, y, pitch);
			}
			else MemZoneSet(previousMask, 0, bw, bh, x, y, pitch); 
		}
	}
	else MemZoneSet(previousMask, 0, nWidth, nHeight, 0, 0, pitch); 

	if ( doNoN )
	{
		for ( int i = 0; i < cpBlocks.GetBlockCount(); i++ )
		{
			int x = cpBlocks[i].GetX();
			int y = cpBlocks[i].GetY();

			if ((nonBlocks[i].GetSAD() < SADThreshold) && (nonBlocks[i].GetMVLength() < MVThreshold))
			{
				BitBlt(nextFrame + x + y * pitch, pitch,
					nnFrame + x + nonBlocks[i].GetMV().x + ( y + nonBlocks[i].GetMV().y ) * pitch,
					pitch, bw, bh, isse);
                AddLuma(nextFrame + x + y * pitch, pitch, (nonBlocks[i].GetLuma() - nonBlocks[i].GetRefLuma()) / 64);
				MemZoneSet(nextMask, 1, bw, bh, x, y, pitch);
			}
			else MemZoneSet(nextMask, 0, bw, bh, x, y, pitch); 
		}
	}
	else MemZoneSet(nextMask, 0, nWidth, nHeight, 0, 0, pitch); 

	Padding::PadReferenceFrame(previousFrame - offset,
		pitch, nHPadding, nVPadding, nWidth, nHeight);
	Padding::PadReferenceFrame(nextFrame - offset,
		pitch, nHPadding, nVPadding, nWidth, nHeight);
	Padding::PadReferenceFrame(previousMask - offset,
		pitch, nHPadding, nVPadding, nWidth, nHeight);
	Padding::PadReferenceFrame(nextMask - offset,
		pitch, nHPadding, nVPadding, nWidth, nHeight);

	for ( int i = 0; i < cpBlocks.GetBlockCount(); i++ )
	{
		int x = cpBlocks[i].GetX();
		int y = cpBlocks[i].GetY();
		unsigned char *d = derp + x + y * der_pitch;
		const unsigned char *pbp, *nbn, *pbpp, *nbnn;
		const unsigned char *pm = previousMask + x + y * pitch;
		const unsigned char *nm = nextMask + x + y * pitch;
        int pdelta = 0;
        int ndelta = 0;

		int prevPitch, nextPitch; 

		if ( doP && (cpBlocks[i].GetSAD() < SADThreshold) && (cpBlocks[i].GetMVLength() < MVThreshold))
		{
            pbp = pFrame + x + cpBlocks[i].GetMV().x + ( y + cpBlocks[i].GetMV().y ) * pitch;
			pbpp = previousFrame + x + cpBlocks[i].GetMV().x + ( y + cpBlocks[i].GetMV().y ) * pitch;
			prevPitch = pitch;
            pdelta = (cpBlocks[i].GetLuma() - cpBlocks[i].GetRefLuma()) / 64;
		}
		else {
			pbp = d;
			pbpp = d;
			prevPitch = der_pitch;
		}
		if ( doN && (cnBlocks[i].GetSAD() < SADThreshold) && (cnBlocks[i].GetMVLength() < MVThreshold))
		{
			nbn = nFrame + x + cnBlocks[i].GetMV().x + ( y + cnBlocks[i].GetMV().y ) * pitch;
            nbnn = nextFrame + x + cnBlocks[i].GetMV().x + ( y + cnBlocks[i].GetMV().y ) * pitch;
			nextPitch = pitch;
            ndelta = (cnBlocks[i].GetLuma() - cnBlocks[i].GetRefLuma()) / 64;
		}
		else {
			nbn = d;
			nbnn = d;
			nextPitch = der_pitch;
		}
		for ( int k = 0; k < cpBlocks.GetBlockSize(); k++ )
		{
			for ( int l = 0; l < cpBlocks.GetBlockSize(); l++ )
			{
				int count = 1;
				int newpix = ((int)d[l]);
				if (MABS((int)d[l] - (int)pbp[l] - pdelta) < TemporalThreshold)
				{
					newpix += (int)pbp[l] + pdelta;
					count += 1;
				}
				if (MABS((int)d[l] - (int)nbn[l] - ndelta) < TemporalThreshold)
				{
					newpix += (int)nbn[l] + ndelta;
					count += 1;
				}
				if ((MABS((int)d[l] - (int)pbpp[l] - pdelta) < TemporalThreshold) && (pm[l]))
				{
					newpix += pbpp[l] + pdelta;
					count++;
				}
				if ((MABS((int)d[l] - (int)nbnn[l] - ndelta) < TemporalThreshold) && (nm[l]))
				{
					newpix += nbnn[l] + ndelta;
					count++;
				}

				d[l] = (newpix + (count / 2 )) / count;
			}
			d += der_pitch;
			pbpp += prevPitch;
			nbnn += nextPitch;
			pbp += prevPitch;
			nbn += nextPitch;
			nm += pitch;
			pm += pitch;

		}
	}

	previousFrame -= offset;
	nextFrame -= offset;
	previousMask -= offset;
	nextMask -= offset;
	pFrame -= offset;
	ppFrame -= offset;
	nFrame -= offset;
	nnFrame -= offset;
}

// Produces output frame n. Chroma is passed through unchanged. Luma is copied and,
// when the frames n-2 .. n+2 all exist, denoised in place by DenoiseBlock().
PVideoFrame __stdcall MVDenoise::GetFrame(int n, IScriptEnvironment* env)
{
	PVideoFrame			src			 = child->GetFrame(n, env);
	const unsigned char	*srcp_y		 = src->GetReadPtr(PLANAR_Y);
	const unsigned char	*srcp_u		 = src->GetReadPtr(PLANAR_U);
	const unsigned char	*srcp_v		 = src->GetReadPtr(PLANAR_V);
	const int			src_pitch_y  = src->GetPitch(PLANAR_Y);
	const int			src_pitch_uv = src->GetPitch(PLANAR_U);

	PVideoFrame			der			 = env->NewVideoFrame(vi);
	unsigned char		*derp_y		 = der->GetWritePtr(PLANAR_Y);
	unsigned char		*derp_u		 = der->GetWritePtr(PLANAR_U);
	unsigned char		*derp_v		 = der->GetWritePtr(PLANAR_V);
	const int			der_pitch_y	 = der->GetPitch(PLANAR_Y);
	const int			der_pitch_uv = der->GetPitch(PLANAR_U);

	env->BitBlt(derp_u, der_pitch_uv, srcp_u, src_pitch_uv, nWidth >> 1, nHeight >> 1);
	env->BitBlt(derp_v, der_pitch_uv, srcp_v, src_pitch_uv, nWidth >> 1, nHeight >> 1);
	// Luma is always copied; DenoiseBlock() then works in place on the copy.
	env->BitBlt(derp_y, der_pitch_y, srcp_y, src_pitch_y, nWidth, nHeight);

	// Frames n-2 .. n+2 are required, so the two first and two last frames are
	// passed through unchanged.
	if (( n > 1 ) && ( n < vi.num_frames - 2 ))
	{
		GetVectorStream(n, env, mvclipBW, fgopCN);
		GetVectorStream(n, env, mvclipFW, fgopCP);
		GetVectorStream(n-1, env, mvclipFW, fgopPoP);
		GetVectorStream(n+1, env, mvclipBW, fgopNoN);

		// Keep the neighbour frames referenced while their pixels are read: a read
		// pointer taken from a temporary PVideoFrame dangles as soon as the frame
		// is released and its buffer reused.
		PVideoFrame prv = child->GetFrame(n - 1, env);
		PVideoFrame nxt = child->GetFrame(n + 1, env);
		PVideoFrame prvprv = child->GetFrame(n - 2, env);
		PVideoFrame nxtnxt = child->GetFrame(n + 2, env);

		// Each frame is copied with its own pitch.
		env->BitBlt(pFrame + offset, pitch,
			prv->GetReadPtr(PLANAR_Y), prv->GetPitch(PLANAR_Y), nWidth, nHeight);
		env->BitBlt(nFrame + offset, pitch,
			nxt->GetReadPtr(PLANAR_Y), nxt->GetPitch(PLANAR_Y), nWidth, nHeight);
		env->BitBlt(ppFrame + offset, pitch,
			prvprv->GetReadPtr(PLANAR_Y), prvprv->GetPitch(PLANAR_Y), nWidth, nHeight);
		env->BitBlt(nnFrame + offset, pitch,
			nxtnxt->GetReadPtr(PLANAR_Y), nxtnxt->GetPitch(PLANAR_Y), nWidth, nHeight);

		Padding::PadReferenceFrame(pFrame, pitch, nHPadding, nVPadding, nWidth, nHeight);
		Padding::PadReferenceFrame(nFrame, pitch, nHPadding, nVPadding, nWidth, nHeight);
		Padding::PadReferenceFrame(ppFrame, pitch, nHPadding, nVPadding, nWidth, nHeight);
		Padding::PadReferenceFrame(nnFrame, pitch, nHPadding, nVPadding, nWidth, nHeight);

		DenoiseBlock(derp_y, der_pitch_y);
	}

	return der;
}