#include "data\\shaders\\common.h"
#include "data\\shaders\\input_formats.h"
#include "data\\shaders\\lights.h"
#include "data\\shaders\\pcss.h"

// Directional light parameters; main() indexes this with the draw instance id.
StructuredBuffer<DirLight> g_dir_lights : register(t0, space1);
// One view-projection matrix per shadow map; indexed with the same shadow index
// as g_shadow_maps (light.shadow_map_index plus the selected cascade offset).
StructuredBuffer<float4x4> g_shadows_view_proj : register(t4, space1);

// Scene depth buffer, sampled to reconstruct view-space positions.
Texture2D g_depth : register(t0);
// Descriptor array of shadow maps (cascades of each shadowed light).
Texture2D g_shadow_maps[10] : register(t0, space2);

// Sampler used for depth-buffer reads.
SamplerState g_sam_linear : register(s3);
// Comparison sampler used for shadow-map depth tests (SampleCmpLevelZero).
SamplerComparisonState g_sam_shadow : register(s6);

// Reconstructs the view-space position of the surface seen at screen `uv`.
// The depth-buffer value (first channel) is combined with the UV into a
// screen-space point and handed to ScreenToView (from an included header);
// only the xyz part of that result is returned.
float3 GetViewPos(float2 uv)
{
  const float raw_depth = g_depth.Sample(g_sam_linear, uv).r;
  const float4 screen_point = float4(uv, raw_depth, 1.0f);
  return ScreenToView(screen_point).xyz;
}

// Mie scattering approximated with the Henyey-Greenstein phase function:
//   p = (1 - g^2) / (4 * PI * (1 + g^2 - 2 * g * cosTheta)^1.5)
// https://www.alexandre-pestana.com/volumetric-lights/
//
// lightDotView : cosine of the angle between the light direction and the view ray.
// scattering   : the HG anisotropy parameter g.
// Returns the phase-function value for that angle.
float ComputeScattering(float lightDotView, float scattering)
{
  const float aniso_sq = scattering * scattering;
  const float denom_base = 1.0f + aniso_sq - 2.0f * scattering * lightDotView;
  return (1.0f - aniso_sq) / (4.0f * PI * pow(denom_base, 1.5f));
}

// Volumetric (in-scattered) light contribution of one directional light.
//
// Marches the view ray from the eye (view-space origin) to the reconstructed
// surface position in `sample_count` equal steps, starting from a dithered
// offset. Each sample adds a depth-based fog term plus a Henyey-Greenstein
// phase term; when the light has shadow maps (shadow_map_index >= 0) the
// sample is attenuated by a comparison lookup in the selected cascade.
//
// Returns rgb = averaged accumulation (clamped to >= 0) * light color * intensity,
//         a   = view-space z of the reconstructed surface.
float4 main(VertexTIOut pin) : SV_Target
{
  DirLight light = g_dir_lights[pin.instance_id];
  float2 half_size = g_screen_size / 2.0f;

  // Ray endpoints: the eye is the view-space origin, the end is the surface.
  float3 start = float3(0.0f, 0.0f, 0.0f);
  float3 end = GetViewPos(pin.uv);
  float3 view_dir = normalize(end);
  float3 l = light.direction;

  const uint sample_count = 8;
  float step_size = length(end - start) / sample_count;
  float3 step = view_dir * step_size;

  float fog_mul = light.scatt_density;
  float scatt_mul = light.scatt_multiplier;
  float scattering = light.scatt_falloff;

  // The light/view angle is the same for every sample along this ray, so the
  // phase term is loop-invariant: evaluate it once.
  float phase = ComputeScattering(saturate(dot(l, view_dir)), scattering) * scatt_mul;

  // First shadow map of this light; a negative value disables shadowing.
  int base_shadow_index = light.shadow_map_index;
  // Depth bias subtracted before the shadow comparison.
  const float shadow_depth_bias = 0.0009f;

  // Dithered start offset along the ray (presumably to hide banding from the
  // low sample count).
  float3 pos = start + step * Dither(pin.uv * half_size);
  float3 accumulation = 0.0f.xxx;
  for (uint i = 0; i < sample_count; ++i)
  {
    float3 attenuation = (1.0f - exp(-pos.z * 0.01f)) * fog_mul + phase;

    if (base_shadow_index >= 0)
    {
      // Pick the cascade from the sample's view-space depth.
      int shadow_index = base_shadow_index;
      if (pos.z > cascade_ranges[0])
      {
        ++shadow_index;
      }
      if (pos.z > cascade_ranges[1])
      {
        ++shadow_index;
      }

      float3 pos_w = mul(float4(pos, 1.0f), g_inv_view).xyz;
      float3 l_uv = WorldToScreen(pos_w, g_shadows_view_proj[shadow_index]);
      l_uv.z = saturate(l_uv.z - shadow_depth_bias);

      // shadow_index depends on per-sample depth and may differ between lanes
      // of a wave; indexing a descriptor array with a non-uniform index is
      // undefined unless it is wrapped in NonUniformResourceIndex.
      attenuation *= g_shadow_maps[NonUniformResourceIndex(shadow_index)]
        .SampleCmpLevelZero(g_sam_shadow, l_uv.xy, l_uv.z);
    }

    accumulation += attenuation;
    pos += step;
  }

  accumulation /= sample_count;

  return float4(max(accumulation, 0.0f.xxx) * light.color * light.intensity, end.z);
}