#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"
#include "data\\shaders\\noise.h"

// Scale applied to each DoF layer's CoC before saturate() to form its blend
// factor; with 2.0 a layer is fully blended in once its CoC reaches 0.5.
// Declared const so it is a compile-time constant rather than a mutable
// per-invocation static global.
static const float blend = 2.0f;

// Sharp scene color; sampled at pin.uv as the base the DoF layers blend over.
Texture2D g_color : register(t0);
// Full-resolution CoC; main reads its green channel as the far CoC.
Texture2D g_coc : register(t1);
// CoC texture presumably at reduced resolution (per its name) - TODO confirm.
// main reads green (via GatherGreen) as far CoC and red as near CoC.
Texture2D g_coc_half : register(t2);
// Near-field DoF layer; the only texture sampled with g_sam_linear.
Texture2D g_near_field : register(t3);
// Far-field DoF layer; point-sampled in a 2x2 quad for the bilateral upsample.
Texture2D g_far_field : register(t4);

// Filter/address modes are configured by the host code, not visible here.
SamplerState g_sam_point : register(s1);
SamplerState g_sam_linear : register(s3);

// Composites the depth-of-field layers over the sharp scene color.
//
// Far field: bilateral upsample of g_far_field. The 2x2 quad of far-field
// texels around this pixel is weighted by standard bilinear weights divided by
// how far each texel's far CoC (green of g_coc_half) is from this pixel's
// full-res far CoC (green of g_coc), so texels from a different depth layer
// contribute little. The result is blended in by saturate(coc_far * blend).
//
// Near field: linearly sampled g_near_field, blended in by
// saturate(coc_near * blend) with coc_near read from the red of g_coc_half.
//
// Returns the composited color with alpha forced to 1.
//
// NOTE(review): assumes g_far_field and g_coc_half share one resolution and
// that g_sam_point clamps at the borders (the quad may reach one texel past the
// edge) - confirm against the pass setup.
float4 main(VertexTOut pin) : SV_Target
{
  float4 color = g_color.SampleLevel(g_sam_point, pin.uv, 0);

  //far
  {
    // Tap offsets and bilinear fractions must be measured in the far field's
    // own texels: with g_color's full-res texel size the fraction would be the
    // same for every pixel and the taps would not match the gather footprint.
    float2 half_dim;
    g_far_field.GetDimensions(half_dim.x, half_dim.y);
    float2 half_texel = 1.0f.xx / half_dim;

    // Pixel position in far-field texel space, measured from texel centers;
    // base is the top-left texel of the surrounding 2x2 quad.
    float2 half_coord = pin.uv * half_dim - 0.5f;
    float2 base = floor(half_coord);
    float2 fractional = half_coord - base;

    // Centers of the four quad texels.
    float2 uv_00 = (base + 0.5f) * half_texel;
    float2 uv_10 = uv_00 + float2(half_texel.x, 0.0f);
    float2 uv_01 = uv_00 + float2(0.0f, half_texel.y);
    float2 uv_11 = uv_00 + half_texel;

    float coc_far = g_coc.SampleLevel(g_sam_point, pin.uv, 0).g;

    // Gathering at the quad's shared corner returns exactly the four texels
    // above. Gather returns them as (01, 11, 10, 00); .wzxy reorders to
    // (00, 10, 01, 11) to match the taps below.
    float2 uv_corner = (base + 1.0f) * half_texel;
    float4 coc_far_half = g_coc_half.GatherGreen(g_sam_point, uv_corner).wzxy;
    float4 coc_far_diffs = abs(coc_far.xxxx - coc_far_half);

    float4 dof_far_00 = g_far_field.SampleLevel(g_sam_point, uv_00, 0);
    float4 dof_far_10 = g_far_field.SampleLevel(g_sam_point, uv_10, 0);
    float4 dof_far_01 = g_far_field.SampleLevel(g_sam_point, uv_01, 0);
    float4 dof_far_11 = g_far_field.SampleLevel(g_sam_point, uv_11, 0);

    // Bilinear weights for taps (00, 10, 01, 11), attenuated by CoC mismatch.
    // The 0.001 bias keeps the division finite when the CoCs match exactly;
    // the bilinear weights sum to 1, so the total weight is always positive.
    float4 bilinear = float4(
      (1.0f - fractional.x) * (1.0f - fractional.y),
      fractional.x * (1.0f - fractional.y),
      (1.0f - fractional.x) * fractional.y,
      fractional.x * fractional.y);
    float4 weights = bilinear / (coc_far_diffs + 0.001f);

    float4 dof_far = dof_far_00 * weights.x + dof_far_10 * weights.y
                   + dof_far_01 * weights.z + dof_far_11 * weights.w;
    dof_far /= dot(weights, 1.0f.xxxx);

    color = lerp(color, dof_far, saturate(coc_far * blend));
  }

  //near
  {
    float2 dim;
    g_color.GetDimensions(dim.x, dim.y);
    float2 texel_size = 1.0f.xx / dim;

    // NOTE(review): the near CoC is read one full-res texel (of g_color) down
    // and to the right of this pixel; the intent is not evident here - confirm.
    float coc_near = g_coc_half.SampleLevel(g_sam_point, pin.uv + 1.0f * texel_size, 0).r;
    float4 dof_near = g_near_field.SampleLevel(g_sam_linear, pin.uv, 0);
    color = lerp(color, dof_near, saturate(coc_near * blend));
  }

  color.w = 1.0f;
  return color;
}