#include "Macros.fxh" //from FNA
#include "Lighting.fxh"
#include "Shadow.fxh"

// Number of shadow cascades. Sizes the CascadeFarPlanes array and bounds the
// cascade-selection loop in ComputeShadow. The shadow map textures and
// light-space matrices below are declared individually (One..Four), so they
// must be kept in step with this value by hand.
static const int NUM_SHADOW_CASCADES = 4;

// G-buffer inputs sampled in main_ps. The second macro argument is presumably
// the texture/sampler slot index (DECLARE_TEXTURE comes from Macros.fxh) --
// confirm against that header.
DECLARE_TEXTURE(gPosition, 0);          // .rgb read as the world-space position
DECLARE_TEXTURE(gAlbedo, 1);            // .rgb read as the albedo
DECLARE_TEXTURE(gNormal, 2);            // .xyz read as the normal (normalized in ComputeColor)
DECLARE_TEXTURE(gMetallicRoughness, 3); // .r read as metallic, .g read as roughness
// One shadow map per cascade; ComputeShadow treats cascade index 0 as the
// closest and pairs shadowMapOne..Four with LightSpaceMatrixOne..Four.
DECLARE_TEXTURE(shadowMapOne, 4);
DECLARE_TEXTURE(shadowMapTwo, 5);
DECLARE_TEXTURE(shadowMapThree, 6);
DECLARE_TEXTURE(shadowMapFour, 7);

// Shader parameters. BEGIN_CONSTANTS / MATRIX_CONSTANTS / END_CONSTANTS and the
// _ps()/_cb() register annotations are macros from Macros.fxh.
BEGIN_CONSTANTS

    // Subtracted from the surface position in ComputeColor to build the view vector.
    float3 EyePosition                                         _ps(c0)    _cb(c0);

    // Normalized in ComputeColor and used as the light vector L.
    // NOTE(review): whether this points toward or away from the light depends
    // on ComputeLight (Lighting.fxh) -- confirm.
    float3 DirectionalLightDirection                           _ps(c1)    _cb(c1);
    // Passed to ComputeLight as the light's radiance.
    float3 DirectionalLightColor                               _ps(c2)    _cb(c2);

    // Per-cascade limits compared against abs(camera-space z) in ComputeShadow;
    // the first entry the depth is below selects the cascade.
    float CascadeFarPlanes[NUM_SHADOW_CASCADES]                _ps(c3)    _cb(c3);

    // Forwarded to PoissonShadow; presumably the shadow map resolution in
    // texels -- confirm in Shadow.fxh.
    float ShadowMapSize                                        _ps(c7)    _cb(c7);

MATRIX_CONSTANTS

    // One light-space matrix per cascade, passed to PoissonShadow together
    // with the matching shadow map.
    float4x4 LightSpaceMatrixOne                               _ps(c8)    _cb(c8);
    float4x4 LightSpaceMatrixTwo                               _ps(c12)   _cb(c12);
    float4x4 LightSpaceMatrixThree                             _ps(c16)   _cb(c16);
    float4x4 LightSpaceMatrixFour                              _ps(c20)   _cb(c20);

    // used to select shadow cascade
    // (transforms the world-space position into camera space in ComputeShadow)
    float4x4 ViewMatrix                                        _ps(c24)   _cb(c24);

END_CONSTANTS

// Vertex shader input. main_vs copies both fields through unchanged.
struct VertexInput
{
    float4 Position : POSITION; // forwarded as-is to SV_POSITION, no transform applied
    float2 TexCoord : TEXCOORD; // used in main_ps to sample the G-buffer textures
};

// Vertex shader output / pixel shader input.
struct PixelInput
{
    float4 Position : SV_POSITION; // copied from VertexInput.Position by main_vs
    float2 TexCoord : TEXCOORD0;   // coordinate used for every G-buffer fetch in main_ps
};

// Pass-through vertex shader: hands the incoming position and texture
// coordinate to the pixel stage without applying any transform.
PixelInput main_vs(VertexInput input)
{
    PixelInput result;

    result.TexCoord = input.TexCoord;
    result.Position = input.Position;

    return result;
}

// Pixel Shader

// Computes the directional-light shadow term for one surface point using the
// cascaded shadow maps.
//
//   positionWorldSpace - surface position in world space
//   N, L               - surface normal and light vector, forwarded unchanged
//                        to PoissonShadow
//
// Returns whatever PoissonShadow returns for the selected cascade (helper is
// not defined in this file; presumably from Shadow.fxh).
float ComputeShadow(float3 positionWorldSpace, float3 N, float3 L)
{
    // Cascades are chosen by distance along the camera's view axis.
    float4 positionCameraSpace = mul(float4(positionWorldSpace, 1.0), ViewMatrix);
    float viewDepth = abs(positionCameraSpace.z);

    // 0 is closest. Start from the farthest cascade so that a fragment lying
    // beyond every far plane is shadowed with the cascade nearest to it,
    // instead of silently falling back to the closest cascade, whose shadow
    // map does not cover that region.
    int shadowCascadeIndex = NUM_SHADOW_CASCADES - 1;
    for (int i = 0; i < NUM_SHADOW_CASCADES; i++)
    {
        if (viewDepth < CascadeFarPlanes[i])
        {
            shadowCascadeIndex = i;
            break;
        }
    }

    // PCF + Poisson soft shadows
    //
    // Each branch names its cascade's matrix and shadow map together, so the
    // two can never get out of step with each other.
    if (shadowCascadeIndex == 0)
    {
        return PoissonShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixOne,
            SAMPLER(shadowMapOne),
            ShadowMapSize
        );
    }
    else if (shadowCascadeIndex == 1)
    {
        return PoissonShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixTwo,
            SAMPLER(shadowMapTwo),
            ShadowMapSize
        );
    }
    else if (shadowCascadeIndex == 2)
    {
        return PoissonShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixThree,
            SAMPLER(shadowMapThree),
            ShadowMapSize
        );
    }
    else
    {
        return PoissonShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixFour,
            SAMPLER(shadowMapFour),
            ShadowMapSize
        );
    }
}

// Shades one surface point with the directional light.
//
//   worldPosition - surface position in world space
//   worldNormal   - surface normal (normalized here before use)
//   albedo        - surface base color
//   metallic      - blend factor between the 0.04 base reflectance and albedo
//   roughness     - forwarded to ComputeLight
//
// Returns the RGB result of ComputeLight with alpha fixed at 1.0.
float4 ComputeColor(
    float3 worldPosition,
    float3 worldNormal,
    float3 albedo,
    float metallic,
    float roughness
) {
    float3 N = normalize(worldNormal);
    float3 V = normalize(EyePosition - worldPosition);
    float3 L = normalize(DirectionalLightDirection);

    // Reflectance starts at a constant 0.04 and moves toward albedo as
    // metallic approaches 1.
    float3 F0 = lerp(float3(0.04, 0.04, 0.04), albedo, metallic);

    float shadow = ComputeShadow(worldPosition, N, L);

    // The light color is used directly as the radiance.
    float3 litColor = ComputeLight(
        L,
        DirectionalLightColor,
        F0,
        V,
        N,
        albedo,
        metallic,
        roughness,
        shadow
    );

    return float4(litColor, 1.0);
}

// Pixel shader entry point: fetches this pixel's surface attributes from the
// G-buffer textures and returns the color computed by ComputeColor.
float4 main_ps(PixelInput input) : SV_TARGET0
{
    float2 uv = input.TexCoord;

    float3 surfacePosition = SAMPLE_TEXTURE(gPosition, uv).rgb;
    float3 surfaceNormal = SAMPLE_TEXTURE(gNormal, uv).xyz;
    float3 surfaceAlbedo = SAMPLE_TEXTURE(gAlbedo, uv).rgb;

    // Metallic is stored in the red channel, roughness in the green channel.
    float2 material = SAMPLE_TEXTURE(gMetallicRoughness, uv).rg;
    float metallic = material.r;
    float roughness = material.g;

    return ComputeColor(surfacePosition, surfaceNormal, surfaceAlbedo, metallic, roughness);
}

// Single-pass technique for the deferred directional light: main_vs passes the
// vertices through and main_ps shades each pixel from the G-buffer.
Technique DeferredPBR_Directional
{
    Pass
    {
        // Both stages are compiled against the shader model 3.0 profiles.
        VertexShader = compile vs_3_0 main_vs();
        PixelShader = compile ps_3_0 main_ps();
    }
}