#include "Macros.fxh"

// Material texture inputs sampled by the pixel shader variants below.
// DECLARE_TEXTURE comes from Macros.fxh; the second argument is presumably the
// texture/sampler slot index - confirm against the macro definition.

// Surface base color; its full float4 texel is written to gAlbedo.
DECLARE_TEXTURE(AlbedoTexture, 0);
// Tangent-space normal map; decoded from [0,1] to [-1,1] in GetNormalFromMap.
DECLARE_TEXTURE(NormalTexture, 1);
// Metallic/roughness map; its full float4 texel is written to gMetallicRoughness.
DECLARE_TEXTURE(MetallicRoughnessTexture, 2);

// Effect parameters. BEGIN_CONSTANTS / MATRIX_CONSTANTS / END_CONSTANTS and the
// _vs / _ps / _cb annotations are macros from Macros.fxh; presumably they bind
// each parameter to a vertex-shader, pixel-shader, or constant-buffer register
// - confirm against the macro definitions.
BEGIN_CONSTANTS

    // Constant material values, used by the pixel shader variants that do not
    // sample the corresponding texture.
    float3 AlbedoValue                         _ps(c0)    _cb(c0);
    float MetallicValue                        _ps(c1)    _cb(c1);
    float RoughnessValue                       _ps(c2)    _cb(c2);

MATRIX_CONSTANTS

    // Transforms consumed by main_vs:
    //   World                 - object position -> PositionWorld
    //   WorldInverseTranspose - object normal   -> NormalWorld (upper 3x3 only)
    //   WorldViewProjection   - object position -> output Position
    float4x4 World                  _vs(c0)               _cb(c7);
    float4x4 WorldInverseTranspose  _vs(c4)               _cb(c11);
    float4x4 WorldViewProjection    _vs(c8)               _cb(c15);

END_CONSTANTS

// Per-vertex attributes consumed by main_vs.
struct VertexInput
{
    // Vertex position; multiplied by World and by WorldViewProjection in main_vs.
    float4 Position : POSITION;
    // Vertex normal; multiplied by the 3x3 part of WorldInverseTranspose in main_vs.
    float3 Normal   : NORMAL;
    // Texture coordinates; passed through to the pixel shader unchanged.
    float2 TexCoord : TEXCOORD0;
};

// Output of main_vs and input of every pixel shader variant.
struct PixelInput
{
    // Vertex position multiplied by WorldViewProjection.
    float4 Position      : SV_POSITION;
    // Vertex position multiplied by World (xyz only).
    float3 PositionWorld : TEXCOORD0;
    // Vertex normal multiplied by WorldInverseTranspose. Not normalized by the
    // vertex shader; the pixel shaders normalize it before use.
    float3 NormalWorld   : TEXCOORD1;
    // Texture coordinates copied from the vertex.
    float2 TexCoord      : TEXCOORD2;
};

// G-buffer layout: one float4 per color output (COLOR0..COLOR3).
struct PixelOutput
{
    // xyz = PositionWorld, w = 0.
    float4 gPosition          : COLOR0;
    // xyz = unit-length surface normal (interpolated or normal-mapped), w = 0.
    float4 gNormal            : COLOR1;
    // Either float4(AlbedoValue, 1) or the full AlbedoTexture texel.
    float4 gAlbedo            : COLOR2;
    // Either float4(MetallicValue, RoughnessValue, 0, 0) or the full
    // MetallicRoughnessTexture texel (channel layout of that texture is not
    // visible in this file - confirm with the consumer of this target).
    float4 gMetallicRoughness : COLOR3;
};

// Vertex Shader

// Transforms a vertex for the G-buffer pass.
//
// Produces the WorldViewProjection-transformed position for rasterization and
// forwards the world-transformed position, the world-transformed (not
// normalized) normal, and the untouched texture coordinates to the pixel shader.
PixelInput main_vs(VertexInput input)
{
    float4 worldPosition = mul(input.Position, World);
    float3 worldNormal = mul(input.Normal, (float3x3)WorldInverseTranspose);

    PixelInput result;
    result.Position = mul(input.Position, WorldViewProjection);
    result.PositionWorld = worldPosition.xyz;
    result.NormalWorld = worldNormal;
    result.TexCoord = input.TexCoord;

    return result;
}

// Pixel Shaders

// Easy trick to get tangent-normals to world-space to keep PBR code simplified:
// the tangent frame is rebuilt per pixel from screen-space derivatives of the
// world position and texture coordinates, so no per-vertex tangents are needed.
//
//   worldPos  - interpolated world-space position of the pixel
//   texCoords - texture coordinates used to sample NormalTexture
//   normal    - interpolated world-space normal (need not be unit length)
//
// Returns a unit-length world-space normal. If the derivative-based frame is
// degenerate (for example texture coordinates that do not vary across the
// primitive), normalizing it would divide by zero and poison the G-buffer, so
// the normalized geometric normal is returned instead.
float3 GetNormalFromMap(float3 worldPos, float2 texCoords, float3 normal)
{
    // Decode the stored [0,1] texel into a [-1,1] tangent-space vector.
    float3 tangentNormal = SAMPLE_TEXTURE(NormalTexture, texCoords).xyz * 2.0 - 1.0;

    float3 Q1  = ddx(worldPos);
    float3 Q2  = ddy(worldPos);
    float2 st1 = ddx(texCoords);
    float2 st2 = ddy(texCoords);

    float3 N = normalize(normal);

    // Unnormalized tangent; exactly zero when the derivatives vanish.
    float3 rawT = Q1*st2.y - Q2*st1.y;
    float3 T = normalize(rawT);
    float3 crossNT = cross(N, T);

    // The frame is usable only if both vectors about to be normalized have
    // non-zero length. Comparisons against an undefined T evaluate to false,
    // which also selects the fallback.
    bool frameValid = (dot(rawT, rawT) > 0.0) && (dot(crossNT, crossNT) > 0.0);

    float3 B = -normalize(crossNT);

    // Rows are T, B, N, so mul(v, TBN) = v.x*T + v.y*B + v.z*N.
    float3x3 TBN = float3x3(T, B, N);
    float3 mappedNormal = normalize(mul(tangentNormal, TBN));

    // Select instead of branching so the derivative and sampling instructions
    // above stay outside of any flow control.
    return frameValid ? mappedNormal : N;
}

// G-buffer variant for a material without texture maps: albedo, metallic and
// roughness come from the material constants and the normal is the
// interpolated vertex normal.
PixelOutput NonePS(PixelInput input)
{
    float3 surfaceNormal = normalize(input.NormalWorld);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = float4(AlbedoValue, 1.0);
    result.gMetallicRoughness = float4(MetallicValue, RoughnessValue, 0.0, 0.0);

    return result;
}

// G-buffer variant with an albedo map only: albedo is the AlbedoTexture texel,
// metallic/roughness come from the material constants and the normal is the
// interpolated vertex normal.
PixelOutput AlbedoPS(PixelInput input)
{
    float4 albedoTexel = SAMPLE_TEXTURE(AlbedoTexture, input.TexCoord);
    float3 surfaceNormal = normalize(input.NormalWorld);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = albedoTexel;
    result.gMetallicRoughness = float4(MetallicValue, RoughnessValue, 0.0, 0.0);

    return result;
}

// G-buffer variant with a metallic/roughness map only: that target receives
// the MetallicRoughnessTexture texel, albedo comes from the material constant
// and the normal is the interpolated vertex normal.
PixelOutput MetallicRoughnessPS(PixelInput input)
{
    float4 metallicRoughnessTexel = SAMPLE_TEXTURE(MetallicRoughnessTexture, input.TexCoord);
    float3 surfaceNormal = normalize(input.NormalWorld);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = float4(AlbedoValue, 1.0);
    result.gMetallicRoughness = metallicRoughnessTexel;

    return result;
}

// G-buffer variant with a normal map only: the normal comes from
// GetNormalFromMap, while albedo, metallic and roughness come from the
// material constants.
PixelOutput NormalPS(PixelInput input)
{
    float3 surfaceNormal = GetNormalFromMap(input.PositionWorld, input.TexCoord, input.NormalWorld);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = float4(AlbedoValue, 1.0);
    result.gMetallicRoughness = float4(MetallicValue, RoughnessValue, 0.0, 0.0);

    return result;
}

// G-buffer variant with albedo and metallic/roughness maps: both targets
// receive their texture's texel and the normal is the interpolated vertex
// normal.
PixelOutput AlbedoMetallicRoughnessPS(PixelInput input)
{
    float4 albedoTexel = SAMPLE_TEXTURE(AlbedoTexture, input.TexCoord);
    float4 metallicRoughnessTexel = SAMPLE_TEXTURE(MetallicRoughnessTexture, input.TexCoord);
    float3 surfaceNormal = normalize(input.NormalWorld);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = albedoTexel;
    result.gMetallicRoughness = metallicRoughnessTexel;

    return result;
}

// G-buffer variant with albedo and normal maps: albedo is the AlbedoTexture
// texel, the normal comes from GetNormalFromMap, and metallic/roughness come
// from the material constants.
PixelOutput AlbedoNormalPS(PixelInput input)
{
    float3 surfaceNormal = GetNormalFromMap(input.PositionWorld, input.TexCoord, input.NormalWorld);
    float4 albedoTexel = SAMPLE_TEXTURE(AlbedoTexture, input.TexCoord);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = albedoTexel;
    result.gMetallicRoughness = float4(MetallicValue, RoughnessValue, 0.0, 0.0);

    return result;
}

// G-buffer variant with metallic/roughness and normal maps: that target
// receives the MetallicRoughnessTexture texel, the normal comes from
// GetNormalFromMap, and albedo comes from the material constant.
PixelOutput MetallicRoughnessNormalPS(PixelInput input)
{
    float3 surfaceNormal = GetNormalFromMap(input.PositionWorld, input.TexCoord, input.NormalWorld);
    float4 metallicRoughnessTexel = SAMPLE_TEXTURE(MetallicRoughnessTexture, input.TexCoord);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = float4(AlbedoValue, 1.0);
    result.gMetallicRoughness = metallicRoughnessTexel;

    return result;
}

// G-buffer variant with all three maps: albedo and metallic/roughness receive
// their texture's texel and the normal comes from GetNormalFromMap.
PixelOutput AlbedoMetallicRoughnessNormalMapPS(PixelInput input)
{
    float3 surfaceNormal = GetNormalFromMap(input.PositionWorld, input.TexCoord, input.NormalWorld);
    float4 albedoTexel = SAMPLE_TEXTURE(AlbedoTexture, input.TexCoord);
    float4 metallicRoughnessTexel = SAMPLE_TEXTURE(MetallicRoughnessTexture, input.TexCoord);

    PixelOutput result;
    result.gPosition = float4(input.PositionWorld, 0.0);
    result.gNormal = float4(surfaceNormal, 0.0);
    result.gAlbedo = albedoTexel;
    result.gMetallicRoughness = metallicRoughnessTexel;

    return result;
}

// Compiled pixel shader permutations, one per combination of material maps.
// The position in this array is the value the technique looks up, so the
// order must not change:
//   0 = no maps
//   1 = albedo          2 = metallic/roughness   3 = normal
//   4 = albedo + metallic/roughness
//   5 = albedo + normal
//   6 = metallic/roughness + normal
//   7 = albedo + metallic/roughness + normal
PixelShader PSArray[8] =
{
    compile ps_3_0 NonePS(),

    compile ps_3_0 AlbedoPS(),
    compile ps_3_0 MetallicRoughnessPS(),
    compile ps_3_0 NormalPS(),

    compile ps_3_0 AlbedoMetallicRoughnessPS(),
    compile ps_3_0 AlbedoNormalPS(),
    compile ps_3_0 MetallicRoughnessNormalPS(),

    compile ps_3_0 AlbedoMetallicRoughnessNormalMapPS()
};

// Maps ShaderIndex to a slot in PSArray. Currently the identity mapping, so
// ShaderIndex selects the PSArray entry with the same number.
int PSIndices[8] =
{
    0, 1, 2, 3, 4, 5, 6, 7
};

// Selects the pixel shader permutation used by the GBuffer technique.
// Defaults to 0 (NonePS). Presumably set by the host application per material
// - confirm against the calling code.
int ShaderIndex = 0;

// Single-pass technique that writes the PixelOutput targets. The vertex shader
// is always main_vs; the pixel shader is looked up at
// PSArray[PSIndices[ShaderIndex]].
Technique GBuffer
{
    Pass
    {
        VertexShader = compile vs_3_0 main_vs();
        PixelShader = (PSArray[PSIndices[ShaderIndex]]);
    }
}