#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"

// Scene color; sampled at the reflection hit UV to produce the reflected color.
Texture2D g_color : register(t0);
// Depth buffer; .x is turned into a view-space position via ScreenToView() in GetViewPos().
Texture2D g_depth : register(t1);
// .xyz is used as the view-space normal in main(); .w is read during the ray march and
// compared against ray depths (presumably linear view-space depth — confirm against the
// G-buffer pass that writes it).
Texture2D g_normal : register(t2);
// Material parameters as read in main(): .x = roughness, .y = metalness.
Texture2D g_material : register(t3);
// Sampler used for every texture read in this shader (by its name, presumably linear filtering).
SamplerState g_sam_linear : register(s3);

// Maximum reflection ray length in view-space units; also the distance over which the
// reflection alpha fades to zero in main().
static const float MaxDistance = 1000.0f;

// Reconstructs the view-space position of the surface visible at `uv` by sampling the
// depth buffer and handing (uv, depth) to ScreenToView() (declared in an included header).
float3 GetViewPos(float2 uv)
{
	const float sampledDepth = g_depth.Sample(g_sam_linear, uv).x;
	const float4 screenPos = float4(uv, sampledDepth, 1.0f);
	return ScreenToView(screenPos).xyz;
}

// Depth-overlap test used by the ray march.
//   z    - depth value to test (the tracer passes the sampled scene depth)
//   minZ - near end of the ray's depth span for the current step
//   maxZ - far end of the ray's depth span for the current step
// z is first pushed back by g_ssr_z_thickness plus a distance-dependent bias of up to 2
// units (scaled by z * g_ssr_stride_z_cutoff, capped at 1). The function returns true when
// that biased depth lies within [minZ - g_ssr_z_thickness, maxZ].
bool IntersectsDepthBuffer(float z, float minZ, float maxZ)
{
	const float distanceBias = lerp(0.0f, 2.0f, min(1.0f, z * g_ssr_stride_z_cutoff));
	const float biasedZ = z + (g_ssr_z_thickness + distanceBias);

	const bool withinFar = maxZ >= biasedZ;
	const bool withinNear = (minZ - g_ssr_z_thickness) <= biasedZ;
	return withinFar && withinNear;
}

// Squared Euclidean distance between two 2D points (avoids the sqrt of distance()).
float DistanceSquared(float2 a, float2 b)
{
	const float2 offset = a - b;
	return dot(offset, offset);
}

// Marches a view-space ray across the screen, one (strided) pixel along the major screen
// axis per step, and stops where the ray's depth span overlaps the stored scene depth.
// The structure (P/Q/k interpolation, axis permutation, rayZMin/rayZMax bracket) appears to
// follow McGuire & Mara's DDA screen-space tracer.
//
//   origin         - ray start in view space
//   dir            - ray direction in view space (main() passes a normalized vector)
//   hitPixel       - out: pixel coordinates of the last sampled position, or the projected
//                    origin if no step was taken
//   hitPoint       - out: reconstructed view-space position at the final march step
//   iterationCount - out: number of march steps taken
// Returns the result of IntersectsDepthBuffer() for the final step (true on a hit).
//
// NOTE(review): the ray is not clipped to the near plane. If dir points toward the camera,
// endPoint can land behind it (H1.w <= 0) and the projected line is meaningless; callers
// are expected to fade such rays out.
bool ScreenSpaceRayTrace(float3 origin, float3 dir, out float2 hitPixel, out float3 hitPoint, out float iterationCount)
{
	float rayLength = MaxDistance;
	float3 endPoint = origin + dir * rayLength;

	// Project both endpoints to homogeneous clip space.
	float4 H0 = mul(float4(origin, 1.0), g_proj);
	float4 H1 = mul(float4(endPoint, 1.0), g_proj);

	float2 oneOverRegionSize = float2(1.0f, 1.0f) / g_screen_size;

	// 1/w at each endpoint. Q = viewPos * k and k both interpolate linearly in screen space,
	// so view-space values are recovered perspective-correctly as Q / k.
	float k0 = 1.0 / H0.w;
	float k1 = 1.0 / H1.w;

	float3 Q0 = origin * k0;
	float3 Q1 = endPoint * k1;

	// Clip space -> NDC -> pixel coordinates (NDC y is flipped to texture-space y).
	float2 P0 = H0.xy * k0;
	P0.x = P0.x * 0.5 + 0.5;
	P0.y = -P0.y * 0.5 + 0.5;
	P0.xy *= g_screen_size;

	float2 P1 = H1.xy * k1;
	P1.x = P1.x * 0.5 + 0.5;
	P1.y = -P1.y * 0.5 + 0.5;
	P1.xy *= g_screen_size;

	// Give hitPixel a defined value even if the march body never executes
	// (e.g. g_ssr_max_steps <= 0); otherwise the out-parameter would be uninitialized.
	hitPixel = P0;

	// Nudge the end point so a degenerate (near zero-length) screen line still has a direction.
	P1 += (DistanceSquared(P0, P1) < 0.0001f) ? float2(0.01f, 0.01f) : 0.0f;
	float2 delta = P1 - P0;

	// Always step along the major screen axis: swap x/y when the line is steeper than 45 degrees.
	bool permute = false;
	if (abs(delta.x) < abs(delta.y))
	{
		permute = true;
		delta = delta.yx;
		P0 = P0.yx;
		P1 = P1.yx;
	}

	float stepDir = sign(delta.x);
	float invdx = stepDir / delta.x;

	// Increments per pixel along the major axis.
	float3 dQ = (Q1 - Q0) * invdx;
	float dk = (k1 - k0) * invdx;
	float2 dP = float2(stepDir, delta.y * invdx);

	// Stride is 1 + g_ssr_stride at origin.z == 0 and shrinks to 1 pixel as
	// origin.z * g_ssr_stride_z_cutoff reaches 1.
	float strideScale = 1.0f - min(1.0f, origin.z * g_ssr_stride_z_cutoff);
	float stride = 1.0f + strideScale * g_ssr_stride;

	dP *= stride;
	dQ *= stride;
	dk *= stride;

	// Only P, Q.z and k are stepped inside the loop; Q.xy is reconstructed afterwards.
	float4 PQk = float4(P0, Q0.z, k0);
	float4 dPQk = float4(dP, dQ.z, dk);
	float3 Q = Q0;

	float end = P1.x * stepDir;

	float stepCount = 0.0f;
	float prevZMaxEstimate = origin.z;
	float rayZMin = prevZMaxEstimate;
	float rayZMax = prevZMaxEstimate;
	// Start the scene depth well beyond the ray so the first loop test does not report a hit.
	float sceneZMax = rayZMax + 100.0f;

	// Stop at the end of the line, after g_ssr_max_steps, on a hit, or when the sampled
	// scene depth is exactly 0 (presumably marks empty/sky pixels — confirm with G-buffer pass).
	for (;
		((PQk.x * stepDir) <= end) && (stepCount < g_ssr_max_steps) &&
		!IntersectsDepthBuffer(sceneZMax, rayZMin, rayZMax) &&
		(sceneZMax != 0.0f);
		++stepCount)
	{
		// Depth span covered by the ray for this step, estimated half a step ahead.
		rayZMin = prevZMaxEstimate;
		rayZMax = (dPQk.z * 0.5f + PQk.z) / (dPQk.w * 0.5f + PQk.w);
		prevZMaxEstimate = rayZMax;
		if (rayZMin > rayZMax)
		{
			float temp = rayZMax;
			rayZMax = rayZMin;
			rayZMin = temp;
		}

		hitPixel = permute ? PQk.yx : PQk.xy;

		// Scene depth at this pixel comes from g_normal.w.
		float2 uv = hitPixel * oneOverRegionSize;
		sceneZMax = g_normal.SampleLevel(g_sam_linear, uv, 0).w;

		PQk += dPQk;
	}

	// Reconstruct the view-space point for the final step. Q.xy was not stepped in the loop,
	// so advance it here; Q.z *was* stepped (as PQk.z) and must be taken from there,
	// otherwise hitPoint.z would be Q0.z / k instead of the interpolated depth.
	Q.xy += dQ.xy * stepCount;
	Q.z = PQk.z;
	hitPoint = Q * (1.0f / PQk.w);
	iterationCount = stepCount;
	return IntersectsDepthBuffer(sceneZMax, rayZMin, rayZMax);
}

// Cheap arithmetic hash: maps a 3-vector to a deterministic pseudo-random 3-vector
// whose components lie in [-0.5, 0.5) (frac() yields [0, 1), then 0.5 is subtracted).
float3 hash(float3 a)
{
	float3 h = frac(a * float3(.8, .8, .8));
	h += dot(h, h.yxz + 19.19);
	float3 scrambled = (h.xxy + h.yxx) * h.zyx;
	return frac(scrambled) - 0.5f;
}

// Screen-space reflection pixel shader.
// Traces a reflection ray from the surface at pin.uv. It returns the scene color at the hit
// in .rgb and a blend weight in .a. The weight is zero on a miss and is further reduced by
// low metalness, hits near the screen border, rays facing away from +z, long hit distances,
// and high iteration counts.
float4 main(VertexTOut pin) : SV_Target
{
	float2 uv = pin.uv;

	float3 viewPos = GetViewPos(uv);

	// Material texture: x = roughness, y = metalness.
	float4 mat = g_material.Sample(g_sam_linear, uv);
	float roughness = mat.x;
	float metalness = mat.y;

	//if (metalness < 0.01)
	//	discard;

	float3 viewNormal = g_normal.Sample(g_sam_linear, uv).xyz;

	// World-space position; used only to seed the roughness jitter hash.
	float4 wp = mul(float4(viewPos.xyz, 1.0), g_inv_view);
	//wp.xyz /= wp.w;

	// Mirror the view vector about the surface normal.
	float3 reflected = normalize(reflect(normalize(viewPos), normalize(viewNormal)));

	// Rougher surfaces perturb the ray more; hash() components lie in [-0.5, 0.5).
	float3 jitt = lerp(float3(0.0, 0.0, 0.0), hash(wp.xyz), saturate(roughness));

	float2 hitPixel;
	float3 hitPoint;
	float iterationCount;

	bool intersected = ScreenSpaceRayTrace(viewPos, normalize(reflected + jitt), hitPixel, hitPoint, iterationCount);

	// Hit position: pixel coordinates -> UV.
	float2 coords = hitPixel * (float2(1.0f, 1.0f) / g_screen_size);

	// Fade hits whose UV is within 0.2 of a screen edge (|0.5 - coord| in [0.3, 0.5]).
	float2 dCoords = smoothstep(0.3, 0.5, abs(float2(0.5, 0.5) - coords));
	float screenEdgefactor = saturate(1.0 - (dCoords.x + dCoords.y));

	float3 col = g_color.Sample(g_sam_linear, coords).xyz;

	// metalness^3 written as an explicit product: cheaper than pow() and, unlike pow(),
	// well-defined if metalness is ever negative.
	float metalnessCubed = metalness * metalness * metalness;

	// reflected.z scales the weight so rays with z <= 0 contribute nothing
	// (presumably view space looks down +z — confirm with ScreenToView/g_proj conventions).
	float ReflectionMultiplier = metalnessCubed * screenEdgefactor * reflected.z;

	float alpha = clamp(ReflectionMultiplier, 0.0, 0.9) * (intersected ? 1.0f : 0.0f);
	// Fade with hit distance relative to the maximum ray length.
	alpha *= 1.0 - saturate(distance(viewPos, hitPoint) / MaxDistance);
	// Fade rays that needed many march steps.
	alpha *= 1.0 - (iterationCount / g_ssr_max_steps);

	return float4(col, alpha);
}
