#include "data\\shaders\\common.h"

// Full-resolution inputs. main() samples both with the same UVs and sizes its
// texel step from g_color, so g_coc is presumably the same size as g_color
// (NOTE(review): confirm on the host side).
// g_coc: main() reads .rg; .g is treated as the "far" CoC (presumably circle of
// confusion; .r presumably the near CoC - confirm against the CoC pass).
Texture2D g_coc : register(t0);
Texture2D g_color : register(t1);

// Half-resolution outputs, written at SV_DispatchThreadID.xy: one texel per
// 2x2 block of the full-resolution inputs.
RWTexture2D<float2> g_half_coc : register(u0);        // CoC pair (.r, .g) copied from g_coc
RWTexture2D<float4> g_half_color : register(u1);      // downsampled color
RWTexture2D<float4> g_half_far_color : register(u2);  // CoC-weighted color, multiplied by the .g CoC

// Filter modes are configured by the host, not here.
// NOTE(review): names suggest point and bilinear filtering respectively - confirm.
SamplerState g_sam_point : register(s1);   // used for the CoC fetches
SamplerState g_sam_linear : register(s3);  // used for the color fetches

// Half-resolution downsample of the CoC and color buffers.
//
// Each thread owns one half-resolution output pixel `globalID.xy`. That pixel
// covers the 2x2 block of full-resolution texels whose top-left texel is
// 2 * globalID.xy in g_color / g_coc. The thread writes:
//   g_half_coc       - the (.r, .g) CoC of the block's top-left texel;
//   g_half_color     - a single bilinear tap on the block's shared corner,
//                      i.e. the 2x2 box average if g_sam_linear is bilinear;
//   g_half_far_color - a CoC-aware weighted average of the block's four colors,
//                      multiplied by the top-left texel's .g ("far") CoC.
//
// Far weighting: the top-left texel gets a fixed weight of 1000. Every other
// texel gets 1 / (|its .g CoC - top-left .g CoC| + 0.001), which is also at
// most 1000. Texels whose far CoC differs strongly from the top-left one
// therefore contribute little.
//
// NOTE(review): texel_size comes from g_color only; this assumes g_coc has the
// same dimensions.
[numthreads(8, 8, 1)]
void main(uint3 globalID : SV_DispatchThreadID, uint3 localID : SV_GroupThreadID, uint localIndex : SV_GroupIndex, uint3 groupID : SV_GroupID)
{
  // Full-resolution source size and the UV size of one full-res texel.
  float2 dim;
  g_color.GetDimensions(dim.x, dim.y);
  float2 texel_size = 1.0f.xx / dim;

  // Centre of the 2x2 full-res block in UV space. This is the corner shared by
  // the four texels: full-res pixel coordinate 2 * id + 1.
  float2 uv = ((float2)globalID.xy + 0.5f) * 2.0f / dim;

  // Centres of the four full-res texels in the block, half a full-res texel
  // away from the shared corner on each axis.
  // The earlier +/-0.25 offsets around the top-left texel centre put all four
  // taps inside one texel. That collapsed every bilateral weight to 1000, and
  // the linear taps leaked color from the neighbouring block.
  float2 tex_coord_00 = uv + float2(-0.5f, -0.5f) * texel_size;
  float2 tex_coord_10 = uv + float2( 0.5f, -0.5f) * texel_size;
  float2 tex_coord_01 = uv + float2(-0.5f,  0.5f) * texel_size;
  float2 tex_coord_11 = uv + float2( 0.5f,  0.5f) * texel_size;

  // CoC of the top-left texel, which is the same texel the previous code
  // sampled. Sampling at its centre, not on the shared corner, keeps the
  // point-sampled texel choice independent of hardware rounding.
  float2 coc = g_coc.SampleLevel(g_sam_point, tex_coord_00, 0).rg;

  // One linear tap exactly on the shared corner weights all four texels equally.
  float4 color = g_color.SampleLevel(g_sam_linear, uv, 0);

  // Far (.g) CoC of each texel in the block; the top-left one is already in coc.
  float coc_far_00 = coc.g;
  float coc_far_10 = g_coc.SampleLevel(g_sam_point, tex_coord_10, 0).g;
  float coc_far_01 = g_coc.SampleLevel(g_sam_point, tex_coord_01, 0).g;
  float coc_far_11 = g_coc.SampleLevel(g_sam_point, tex_coord_11, 0).g;

  // Each far tap should fetch exactly one texel. The taps sit on texel centres,
  // and the point sampler guarantees no blending from floating-point error.
  float weight_00 = 1000.0f;
  float4 color_mul_far = weight_00 * g_color.SampleLevel(g_sam_point, tex_coord_00, 0);
  float weight_sum = weight_00;

  // The +0.001 epsilon avoids division by zero and caps each weight at 1000.
  float weight_10 = 1.0f / (abs(coc_far_00 - coc_far_10) + 0.001f);
  color_mul_far += weight_10 * g_color.SampleLevel(g_sam_point, tex_coord_10, 0);
  weight_sum += weight_10;

  float weight_01 = 1.0f / (abs(coc_far_00 - coc_far_01) + 0.001f);
  color_mul_far += weight_01 * g_color.SampleLevel(g_sam_point, tex_coord_01, 0);
  weight_sum += weight_01;

  float weight_11 = 1.0f / (abs(coc_far_00 - coc_far_11) + 0.001f);
  color_mul_far += weight_11 * g_color.SampleLevel(g_sam_point, tex_coord_11, 0);
  weight_sum += weight_11;

  // Normalise, then scale by the far CoC so the result is stored premultiplied.
  color_mul_far /= weight_sum;
  color_mul_far *= coc.g;

  g_half_coc[globalID.xy] = coc;
  g_half_color[globalID.xy] = color;
  g_half_far_color[globalID.xy] = color_mul_far;
}