#include "Macros.fxh"
#include "Shadow.fxh"
#include "Dither.fxh"

// Number of cascaded shadow-map splits.
// NOTE(review): several places are written by hand for exactly 4 cascades:
//   - the constant-register layout (CascadeFarPlanes presumably occupies c4..c7,
//     with ShadowMapSize right after at c8),
//   - the four shadowMap* textures and the four LightSpaceMatrix* constants,
//   - the per-cascade branch chain in ComputeShadow.
// Changing this value requires updating all of them.
static const int NUM_SHADOW_CASCADES = 4;

// G-buffer inputs, sampled per pixel at input.TexCoord by the pixel shaders:
//   gPosition          - world-space surface position (.rgb)
//   gAlbedo            - surface base colour (.rgb)
//   gNormal            - surface normal (.xyz), re-normalized before use
//   gMetallicRoughness - .r used as metallic, .g as roughness (per the name; TODO confirm)
// NOTE(review): the second DECLARE_TEXTURE argument is presumably the
// texture/sampler slot; see Macros.fxh.
DECLARE_TEXTURE(gPosition, 0);
DECLARE_TEXTURE(gAlbedo, 1);
DECLARE_TEXTURE(gNormal, 2);
DECLARE_TEXTURE(gMetallicRoughness, 3);
// One shadow map per cascade, nearest (One) to farthest (Four).
// These match cascade indices 0..3 in ComputeShadow.
DECLARE_TEXTURE(shadowMapOne, 4);
DECLARE_TEXTURE(shadowMapTwo, 5);
DECLARE_TEXTURE(shadowMapThree, 6);
DECLARE_TEXTURE(shadowMapFour, 7);

// Shader constants set by the host application.
// NOTE(review): _ps(cN) / _cb(cN) presumably pin each value to constant
// register N for the ps_3_0 and constant-buffer back ends respectively
// (see Macros.fxh). Register indices are laid out by hand, so reordering or
// resizing any entry requires shifting everything after it.
BEGIN_CONSTANTS

// Camera position in world space; the view vector is EyePosition - worldPosition.
float3 EyePosition                           _ps(c0)     _cb(c0);

// Used (after normalize) as the surface-to-light vector L: dot(N, L) > 0 counts as lit.
float3 DirectionalLightDirection             _ps(c1)     _cb(c1);
// Light tint multiplied into the diffuse term.
float3 DirectionalLightColor                 _ps(c2)     _cb(c2);

// Upper bound of each cascade, compared against the view-space |z| of the
// pixel; nearest cascade first. Presumably spans c4..c7 (one register per element).
float CascadeFarPlanes[NUM_SHADOW_CASCADES]  _ps(c4)     _cb(c4);

// Forwarded unchanged to HardShadow (Shadow.fxh); presumably the shadow-map
// resolution in texels.
float ShadowMapSize                          _ps(c8)     _cb(c8);

MATRIX_CONSTANTS

// Per-cascade matrices passed to HardShadow together with the matching
// shadow map; presumably world -> light clip space. Each float4x4 spans 4 registers.
float4x4 LightSpaceMatrixOne                 _ps(c9)     _cb(c9);
float4x4 LightSpaceMatrixTwo                 _ps(c13)    _cb(c13);
float4x4 LightSpaceMatrixThree               _ps(c17)    _cb(c17);
float4x4 LightSpaceMatrixFour                _ps(c21)    _cb(c21);

// World -> view transform, applied as mul(float4(pos, 1), ViewMatrix) (row-vector
// convention) to obtain the depth used for cascade selection.
float4x4 ViewMatrix                          _ps(c25)    _cb(c25);

END_CONSTANTS

// Per-vertex input of the lighting pass (presumably a full-screen quad).
struct VertexInput
{
    float4 Position : POSITION;  // forwarded unchanged to SV_POSITION by main_vs
    float2 TexCoord : TEXCOORD;  // G-buffer lookup coordinate
};

// Data interpolated from main_vs to FlatShadow / DitheredShadow.
struct PixelInput
{
    float4 Position : SV_POSITION; // DitheredShadow reads .xy as the pixel position for dithering
    float2 TexCoord : TEXCOORD0;   // UV used to sample every G-buffer texture
};

// Pass-through vertex shader: the incoming position is emitted as-is, so the
// geometry must already be in clip space (presumably a full-screen quad), and
// the texture coordinate is forwarded for the G-buffer lookups.
PixelInput main_vs(VertexInput input)
{
    PixelInput passthrough;
    passthrough.TexCoord = input.TexCoord;
    passthrough.Position = input.Position;
    return passthrough;
}

// Shadow factor for a world-space surface point, using cascaded shadow maps.
//
// The cascade chosen is the first (nearest) one whose CascadeFarPlanes entry
// exceeds the point's view-space depth |z|. Points beyond every far plane use
// the last (farthest) cascade. The matching light-space matrix and shadow map
// are then handed to HardShadow (Shadow.fxh).
//
// positionWorldSpace - surface position read from the G-buffer
// N                  - normalized surface normal
// L                  - normalized surface-to-light direction
// Returns the HardShadow result. The shaders in this file treat 1 as fully lit
// and values below 1 as shadowed (TODO confirm range against Shadow.fxh).
float ComputeShadow(float3 positionWorldSpace, float3 N, float3 L)
{
    float4 positionCameraSpace = mul(float4(positionWorldSpace, 1.0), ViewMatrix);
    float viewDepth = abs(positionCameraSpace.z);

    // Default to the farthest cascade. A point past every far plane then still
    // uses the coarsest cascade instead of cascade 0, which is only meant for
    // the nearest depth range.
    int shadowCascadeIndex = NUM_SHADOW_CASCADES - 1; // 0 is closest
    for (int i = 0; i < NUM_SHADOW_CASCADES; i++)
    {
        if (viewDepth < CascadeFarPlanes[i])
        {
            shadowCascadeIndex = i;
            break;
        }
    }

    // A single branch per cascade selects both the matrix and the shadow map.
    // Presumably this is written out because ps_3_0 samplers cannot be
    // indexed dynamically.
    if (shadowCascadeIndex == 0)
    {
        return HardShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixOne,
            SAMPLER(shadowMapOne),
            ShadowMapSize
        );
    }
    else if (shadowCascadeIndex == 1)
    {
        return HardShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixTwo,
            SAMPLER(shadowMapTwo),
            ShadowMapSize
        );
    }
    else if (shadowCascadeIndex == 2)
    {
        return HardShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixThree,
            SAMPLER(shadowMapThree),
            ShadowMapSize
        );
    }
    else
    {
        return HardShadow(
            positionWorldSpace,
            N,
            L,
            LightSpaceMatrixFour,
            SAMPLER(shadowMapFour),
            ShadowMapSize
        );
    }
}

// Two-band toon ramp for the diffuse term.
// Returns 1.0 when the surface faces the light (NdotL strictly positive),
// otherwise the constant floor 0.25.
float IntensityBanding(float NdotL)
{
    const float litBand = 1.0;
    const float unlitBand = 0.25;
    return (NdotL > 0.0) ? litBand : unlitBand;
}

// Deferred toon lighting for one pixel with a plain (non-dithered) shadow.
//
// Reads the G-buffer at input.TexCoord and sums three terms:
//   - a two-band diffuse term (IntensityBanding) tinted by DirectionalLightColor,
//   - a thresholded Blinn-Phong highlight, disabled where metallic (.r) == 0,
//   - a rim light whose cut-off is 1 - .g (roughness, per the texture name).
// The sum is multiplied by the albedo and by the cascade shadow factor.
float4 FlatShadow(PixelInput input) : SV_TARGET0
{
    float3 worldPosition = SAMPLE_TEXTURE(gPosition, input.TexCoord).rgb;
    float3 normal = SAMPLE_TEXTURE(gNormal, input.TexCoord).xyz;
    float3 albedo = SAMPLE_TEXTURE(gAlbedo, input.TexCoord).rgb;
    float2 metallicRoughness = SAMPLE_TEXTURE(gMetallicRoughness, input.TexCoord).rg;

    // Exponent base: 16 for metallic == 0 up to 64 for metallic == 1. It is
    // squared in the pow() below, so the higher the glossiness, the smaller
    // and sharper the specular highlight.
    float glossiness = lerp(64, 16, 1.0 - metallicRoughness.r);

    float3 V = normalize(EyePosition - worldPosition);
    float3 L = normalize(DirectionalLightDirection);
    float3 N = normalize(normal);
    float3 H = normalize(V + L);

    float NdotL = dot(N, L);
    float NdotH = max(dot(N, H), 0.0);

    float lightIntensity = IntensityBanding(NdotL);
    float3 light = lightIntensity * DirectionalLightColor;

    // On the unlit band the base is at most 0.25. Raised to an exponent of at
    // least 256, that leaves no visible highlight. smoothstep then snaps the
    // highlight to a near-binary toon spot.
    float specularIntensity = pow(NdotH * lightIntensity, glossiness * glossiness);
    float specularSmooth = smoothstep(0.005, 0.01, specularIntensity);

    float3 specular = specularSmooth * float3(1.0, 1.0, 1.0);

    // Surfaces whose metallic channel is exactly zero get no highlight.
    if (metallicRoughness.r == 0.0) { specular = float3(0.0, 0.0, 0.0); }

    // Rim light: grows toward grazing view angles and is limited to the lit
    // side (pow(max(NdotL, 0), 0.1) is 0 where NdotL <= 0). The cut-off is
    // 1 - roughness, so rougher surfaces show a wider rim.
    float3 rimColor = float3(1.0, 1.0, 1.0);
    float rimThreshold = 0.1;
    float rimAmount = 1 - metallicRoughness.g;
    float rimDot = 1 - dot(V, N);
    float rimIntensity = rimDot * pow(max(NdotL, 0.0), rimThreshold);
    rimIntensity = smoothstep(rimAmount - 0.01, rimAmount + 0.01, rimIntensity);
    float3 rim = rimIntensity * rimColor;

    float shadow = ComputeShadow(worldPosition, N, L);
    float3 color = albedo * (light + specular + rim) * shadow;

    return float4(color, 1.0);
}

// Dithered counterpart of FlatShadow.
// TODO: this shares almost all of its lighting code with FlatShadow; factor
// the common part into a helper.
//
// The specular and rim terms are the same as in FlatShadow. The unlit
// diffuse band and the shadow factor are not applied as plain multipliers.
// Wherever either value is below 1, the colour is instead multiplied by
// dither(value, pixel position).
// NOTE(review): dither() comes from Dither.fxh. Presumably it returns a
// per-pixel on/off pattern whose density follows its first argument; confirm.
float4 DitheredShadow(PixelInput input) : SV_TARGET0
{
    float2 uv = input.TexCoord;
    float2 pixelPos = input.Position.xy;

    // G-buffer fetch.
    float3 surfacePos = SAMPLE_TEXTURE(gPosition, uv).rgb;
    float3 baseColor = SAMPLE_TEXTURE(gAlbedo, uv).rgb;
    float3 N = normalize(SAMPLE_TEXTURE(gNormal, uv).xyz);
    float2 matParams = SAMPLE_TEXTURE(gMetallicRoughness, uv).rg;
    float metallic = matParams.r;
    float roughness = matParams.g;

    // Lighting vectors.
    float3 L = normalize(DirectionalLightDirection);
    float3 V = normalize(EyePosition - surfacePos);
    float3 H = normalize(V + L);
    float NdotL = dot(N, L);
    float NdotH = max(dot(N, H), 0.0);

    // Diffuse: full light colour, stippled wherever the band is below 1.
    float band = IntensityBanding(NdotL);
    float3 diffuse = DirectionalLightColor;
    if (band < 1)
    {
        diffuse *= dither(band, pixelPos);
    }

    // Specular: exponent base 16..64 (by metallic), squared. Higher means a
    // sharper spot. Disabled entirely where metallic is exactly zero.
    float shininess = lerp(64, 16, 1.0 - metallic);
    float spot = smoothstep(0.005, 0.01, pow(NdotH * band, shininess * shininess));
    float3 highlight = (metallic == 0.0) ? float3(0.0, 0.0, 0.0) : spot * float3(1.0, 1.0, 1.0);

    // Rim: cut-off is 1 - roughness, so rougher surfaces show a wider rim.
    float rimCutoff = 1 - roughness;
    float rimRaw = (1 - dot(V, N)) * pow(max(NdotL, 0.0), 0.1);
    float3 rim = smoothstep(rimCutoff - 0.01, rimCutoff + 0.01, rimRaw) * float3(1.0, 1.0, 1.0);

    float3 color = baseColor * (diffuse + highlight + rim);

    // Shadow: stipple the whole result instead of darkening it.
    float shadow = ComputeShadow(surfacePos, N, L);
    if (shadow < 1)
    {
        color *= dither(shadow, pixelPos);
    }

    return float4(color, 1.0);
}

// Pixel-shader variants selectable at runtime:
//   0 - FlatShadow:     diffuse band and shadow applied as plain multipliers
//   1 - DitheredShadow: unlit band and shadow applied through a screen-space dither
PixelShader PSArray[2] =
{
    compile ps_3_0 FlatShadow(),
    compile ps_3_0 DitheredShadow()
};

// Maps ShaderIndex to a PSArray slot. It is currently the identity mapping.
int PSIndices[2] =
{
    0, 1
};

// Selects the pixel-shader variant used by Deferred_Toon
// (0 = FlatShadow, 1 = DitheredShadow).
// NOTE(review): presumably set by the host application; values outside 0..1
// would index past PSIndices.
int ShaderIndex = 0;

// Single-pass deferred toon lighting. main_vs forwards the geometry
// (presumably a full-screen quad) unchanged, and the pixel shader is picked
// from PSArray via ShaderIndex.
Technique Deferred_Toon
{
    Pass
    {
        VertexShader = compile vs_3_0 main_vs();
        PixelShader = (PSArray[PSIndices[ShaderIndex]]);
    }
}