#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <llvm-c/Core.h>
#include <llvm-c/Analysis.h>
#include <llvm-c/BitWriter.h>

#include "y.tab.h"
#include "ast.h"
#include "stack.h"

/* Input stream read by the generated parser; main() points it at the source file before yyparse(). */
extern FILE *yyin;
/* Parser work stack, created in main() and passed to yyparse(). */
Stack *stack;
/* Root of the parsed AST, read after yyparse() returns.
 * NOTE(review): nothing in this file assigns it — presumably set by the grammar actions; confirm. */
Node *rootNode;

/*
 * A named variable visible in a scope frame, accessed through memory.
 * `name` is owned by the frame: AddLocalVariable strdup's it and
 * PopScopeFrame frees it.
 * `pointer` is the address values are loaded from / stored to (an alloca
 * for declared locals and parameters, a struct GEP for struct fields).
 */
typedef struct LocalVariable
{
    char *name;
    LLVMValueRef pointer;
} LocalVariable;

/*
 * A named value used directly (FindVariableValue returns it without a load).
 * `name` is owned by the frame: AddFunctionArgument strdup's it and
 * PopScopeFrame frees it.
 * NOTE(review): nothing in this file calls AddFunctionArgument; CompileFunction
 * registers parameters as LocalVariables (stack copies) instead.
 */
typedef struct FunctionArgument
{
    char *name;
    LLVMValueRef value;
} FunctionArgument;

/*
 * One level of the scope stack. CreateScope makes the root frame;
 * CompileStruct and CompileFunction each push a frame for their body.
 * Both arrays grow by one element per addition and are released by
 * PopScopeFrame.
 */
typedef struct ScopeFrame
{
    /* Memory-backed variables introduced in this frame. */
    LocalVariable *localVariables;
    uint32_t localVariableCount;

    /* Directly-used values introduced in this frame. */
    FunctionArgument *arguments;
    uint32_t argumentCount;
} ScopeFrame;

/*
 * Stack of scope frames. The innermost frame is
 * scopeStack[scopeStackCount - 1]; name lookups walk from there outward.
 */
typedef struct Scope
{
    ScopeFrame *scopeStack;
    uint32_t scopeStackCount;
} Scope;

/* The compiler's single scope stack, created in main() via CreateScope(). */
Scope *scope;

/* A struct field name and its element index, as passed to LLVMBuildStructGEP. */
typedef struct StructTypeField
{
    /* Owned copy of the field name (strdup'd in AddStructDeclaration). */
    char *name;
    /* Position of the field in the struct body. */
    uint32_t index;
} StructTypeField;

/*
 * Field layout of one compiled struct.
 * `structType` is the lookup key: FindStructFieldPointer and AddStructVariables
 * compare it with LLVMTypeOf(<struct pointer>), so CompileStruct registers the
 * pointer-to-struct type here, not the struct type itself.
 */
typedef struct StructTypeFieldDeclaration
{
    LLVMTypeRef structType;
    StructTypeField *fields;
    uint32_t fieldCount;

} StructTypeFieldDeclaration;

/* Registry of every struct compiled so far; appended to by AddStructDeclaration. */
StructTypeFieldDeclaration *structTypeFieldDeclarations;
uint32_t structTypeFieldDeclarationCount;

static Scope* CreateScope()
{
    Scope *scope = malloc(sizeof(Scope));

    scope->scopeStack = malloc(sizeof(ScopeFrame));
    scope->scopeStack[0].argumentCount = 0;
    scope->scopeStack[0].arguments = NULL;
    scope->scopeStack[0].localVariableCount = 0;
    scope->scopeStack[0].localVariables = NULL;
    scope->scopeStackCount = 1;

    return scope;
}

/* Grow the scope stack by one and make the new top frame empty. */
static void PushScopeFrame(Scope *scope)
{
    uint32_t newCount = scope->scopeStackCount + 1;
    ScopeFrame *top;

    scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * newCount);
    top = &scope->scopeStack[newCount - 1];

    top->localVariables = NULL;
    top->localVariableCount = 0;
    top->arguments = NULL;
    top->argumentCount = 0;

    scope->scopeStackCount = newCount;
}

/*
 * Remove the innermost frame from the scope stack, freeing the names it owns
 * and its variable/argument arrays.
 *
 * Popping an empty stack is reported and ignored instead of indexing
 * scopeStack[-1]. The frame array is never resized to zero elements because
 * realloc(ptr, 0) is implementation-defined (undefined in C23); the next
 * PushScopeFrame reallocs it to the size it needs anyway.
 */
static void PopScopeFrame(Scope *scope)
{
    uint32_t i;
    ScopeFrame *frame;

    if (scope->scopeStackCount == 0)
    {
        fprintf(stderr, "Attempted to pop an empty scope stack!\n");
        return;
    }

    frame = &scope->scopeStack[scope->scopeStackCount - 1];

    /* A NULL array always has a zero count, so the loops and free(NULL) are safe. */
    for (i = 0; i < frame->argumentCount; i += 1)
    {
        free(frame->arguments[i].name);
    }
    free(frame->arguments);

    for (i = 0; i < frame->localVariableCount; i += 1)
    {
        free(frame->localVariables[i].name);
    }
    free(frame->localVariables);

    scope->scopeStackCount -= 1;

    if (scope->scopeStackCount > 0)
    {
        /* A failed shrink leaves the old (larger) block valid, so keep it. */
        ScopeFrame *shrunk = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount);
        if (shrunk != NULL)
        {
            scope->scopeStack = shrunk;
        }
    }
}

/*
 * Register a copy of `name` as a local variable of the innermost frame,
 * addressed through `pointer`.
 */
static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name)
{
    ScopeFrame *top = scope->scopeStack + (scope->scopeStackCount - 1);
    uint32_t grown = top->localVariableCount + 1;
    LocalVariable *slot;

    top->localVariables = realloc(top->localVariables, grown * sizeof(LocalVariable));
    slot = &top->localVariables[grown - 1];
    slot->name = strdup(name);
    slot->pointer = pointer;

    top->localVariableCount = grown;
}

/*
 * Register a copy of `name` as an argument of the innermost frame, bound
 * directly to `value` (no storage, no load on use).
 */
static void AddFunctionArgument(Scope *scope, LLVMValueRef value, char *name)
{
    ScopeFrame *top = scope->scopeStack + (scope->scopeStackCount - 1);
    uint32_t grown = top->argumentCount + 1;
    FunctionArgument *slot;

    top->arguments = realloc(top->arguments, grown * sizeof(FunctionArgument));
    slot = top->arguments + (grown - 1);
    slot->name = strdup(name);
    slot->value = value;

    top->argumentCount = grown;
}

/*
 * Emit a GEP (named "<name>_ptr") to field `name` of the struct that
 * `structPointer` points to.
 *
 * The struct is identified by LLVMTypeOf(structPointer), which must equal a
 * type registered through AddStructDeclaration. Returns NULL after printing a
 * diagnostic when no registered struct of that type has such a field.
 */
static LLVMValueRef FindStructFieldPointer(LLVMBuilderRef builder, LLVMValueRef structPointer, char *name)
{
    uint32_t i, j;
    LLVMTypeRef structType = LLVMTypeOf(structPointer);

    for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
    {
        StructTypeFieldDeclaration *declaration = &structTypeFieldDeclarations[i];

        if (declaration->structType != structType)
        {
            continue;
        }

        for (j = 0; j < declaration->fieldCount; j += 1)
        {
            if (strcmp(declaration->fields[j].name, name) == 0)
            {
                /* strdup(name) has no room for the suffix, so size the buffer explicitly. */
                size_t ptrNameSize = strlen(name) + sizeof("_ptr");
                char *ptrName = malloc(ptrNameSize);
                LLVMValueRef fieldPointer;

                snprintf(ptrName, ptrNameSize, "%s_ptr", name);
                fieldPointer = LLVMBuildStructGEP(
                    builder,
                    structPointer,
                    declaration->fields[j].index,
                    ptrName
                );
                /* LLVM copies value names, so the buffer can be released right away. */
                free(ptrName);
                return fieldPointer;
            }
        }
    }

    fprintf(stderr, "Failed to find struct field pointer: %s\n", name);
    return NULL;
}

static LLVMValueRef FindVariablePointer(char *name)
{
    int32_t i, j;

    for (i = scope->scopeStackCount - 1; i >= 0; i -= 1)
    {
        for (j = 0; j < scope->scopeStack[i].localVariableCount; j += 1)
        {
            if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0)
            {
                return scope->scopeStack[i].localVariables[j].pointer;
            }
        }
    }

    printf("Failed to find variable pointer!");
    return NULL;
}

/*
 * Return the current value of `name`, searching the global scope stack from
 * the innermost frame outward. Within a frame, arguments are checked before
 * locals. An argument's value is returned as-is; a local is loaded from its
 * storage, and the load is named after the variable.
 * Returns NULL after printing a diagnostic if the name is unknown.
 */
static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
{
    uint32_t i, j;

    /* Count down with an unsigned index; an empty stack simply skips the loop. */
    for (i = scope->scopeStackCount; i > 0; i -= 1)
    {
        ScopeFrame *frame = &scope->scopeStack[i - 1];

        for (j = 0; j < frame->argumentCount; j += 1)
        {
            if (strcmp(frame->arguments[j].name, name) == 0)
            {
                return frame->arguments[j].value;
            }
        }

        for (j = 0; j < frame->localVariableCount; j += 1)
        {
            if (strcmp(frame->localVariables[j].name, name) == 0)
            {
                return LLVMBuildLoad(builder, frame->localVariables[j].pointer, name);
            }
        }
    }

    fprintf(stderr, "Failed to find variable value: %s\n", name);
    return NULL;
}

/*
 * Append a struct layout to the registry so fields can later be resolved by
 * name (FindStructFieldPointer, AddStructVariables).
 *
 * wStructType is the lookup key. Field k gets element index k and a copy of
 * the name stored in fieldDeclarations[k]->children[1].
 */
static void AddStructDeclaration(
    LLVMTypeRef wStructType,
    Node **fieldDeclarations,
    uint32_t fieldDeclarationCount
) {
    uint32_t fieldIndex;
    StructTypeFieldDeclaration *entry;

    structTypeFieldDeclarations = realloc(
        structTypeFieldDeclarations,
        sizeof(StructTypeFieldDeclaration) * (structTypeFieldDeclarationCount + 1)
    );
    entry = &structTypeFieldDeclarations[structTypeFieldDeclarationCount];
    entry->structType = wStructType;
    entry->fieldCount = fieldDeclarationCount;
    entry->fields = NULL;

    if (fieldDeclarationCount > 0)
    {
        /* One allocation for every field instead of growing element by element. */
        entry->fields = malloc(sizeof(StructTypeField) * fieldDeclarationCount);
    }

    for (fieldIndex = 0; fieldIndex < fieldDeclarationCount; fieldIndex += 1)
    {
        entry->fields[fieldIndex].name = strdup(fieldDeclarations[fieldIndex]->children[1]->value.string);
        entry->fields[fieldIndex].index = fieldIndex;
    }

    structTypeFieldDeclarationCount += 1;
}

/*
 * For the struct that `structPointer` points to, emit one GEP per field
 * (named "<field>_ptr") and register each one in the innermost scope frame
 * under the bare field name. CompileFunction calls this so that method bodies
 * can read and assign fields as if they were locals.
 */
static void AddStructVariables(
    LLVMBuilderRef builder,
    LLVMValueRef structPointer
) {
    uint32_t i, j;
    LLVMTypeRef structPointerType = LLVMTypeOf(structPointer);

    for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
    {
        StructTypeFieldDeclaration *declaration = &structTypeFieldDeclarations[i];

        if (declaration->structType != structPointerType)
        {
            continue;
        }

        for (j = 0; j < declaration->fieldCount; j += 1)
        {
            char *fieldName = declaration->fields[j].name;
            /* strdup(fieldName) has no room for the suffix, so size the buffer explicitly. */
            size_t ptrNameSize = strlen(fieldName) + sizeof("_ptr");
            char *ptrName = malloc(ptrNameSize);
            LLVMValueRef elementPointer;

            snprintf(ptrName, ptrNameSize, "%s_ptr", fieldName);
            elementPointer = LLVMBuildStructGEP(
                builder,
                structPointer,
                declaration->fields[j].index,
                ptrName
            );
            free(ptrName);

            AddLocalVariable(scope, elementPointer, fieldName);
        }
    }
}

/* Forward declaration: expression compilation is mutually recursive with the helpers below. */
static LLVMValueRef CompileExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *expression
);

/*
 * Associates a user-declared type name with its LLVM type
 * (CompileStruct registers each struct type under its name).
 */
typedef struct CustomTypeMap
{
    LLVMTypeRef type;
    /* Owned copy of the name (strdup'd in RegisterCustomType). */
    char *name;
} CustomTypeMap;

/* Registry appended to by RegisterCustomType and searched by LookupCustomType. */
CustomTypeMap *customTypes;
uint32_t customTypeCount;

/*
 * Register a copy of `name` as a name for `type`, so LookupCustomType can
 * resolve it later.
 *
 * The element size is taken from the array itself (CustomTypeMap). Note that
 * `CustomType` is a type-tag constant compared against node types elsewhere,
 * not this struct, so sizeof(CustomType) would under-allocate the array.
 */
static void RegisterCustomType(LLVMTypeRef type, char *name)
{
    customTypes = realloc(customTypes, sizeof(*customTypes) * (customTypeCount + 1));
    customTypes[customTypeCount].type = type;
    customTypes[customTypeCount].name = strdup(name);
    customTypeCount += 1;
}

/*
 * Return the LLVM type registered under `name`, or NULL if there is none.
 * When a name was registered more than once, the earliest entry wins.
 */
static LLVMTypeRef LookupCustomType(char *name)
{
    uint32_t index = 0;

    while (index < customTypeCount)
    {
        if (!strcmp(customTypes[index].name, name))
        {
            return customTypes[index].type;
        }
        index += 1;
    }

    return NULL;
}

/*
 * Map a built-in Wraith type to its LLVM type. Int and UInt both become
 * 64-bit integers, Bool becomes i1, and Void becomes void. Any other value
 * (presumably including CustomType) is reported and yields NULL.
 */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    LLVMTypeRef result = NULL;

    switch (type)
    {
        case Int:
        case UInt:
            result = LLVMInt64Type();
            break;

        case Bool:
            result = LLVMInt1Type();
            break;

        case Void:
            result = LLVMVoidType();
            break;

        default:
            break;
    }

    if (result == NULL)
    {
        fprintf(stderr, "Unrecognized type!");
    }

    return result;
}

/* Materialize a numeric literal as a 64-bit integer constant (SignExtend = 0). */
static LLVMValueRef CompileNumber(
    Node *numberExpression
) {
    LLVMTypeRef int64Type = LLVMInt64Type();

    return LLVMConstInt(int64Type, numberExpression->value.number, 0);
}

/*
 * Compile both operands (left first), then combine them with the integer
 * instruction for the operator. Operators other than Add, Subtract and
 * Multiply yield NULL.
 */
static LLVMValueRef CompileBinaryExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *binaryExpression
) {
    Node *leftNode = binaryExpression->children[0];
    Node *rightNode = binaryExpression->children[1];
    LLVMValueRef lhs = CompileExpression(wStructValue, builder, function, leftNode);
    LLVMValueRef rhs = CompileExpression(wStructValue, builder, function, rightNode);
    LLVMValueRef result = NULL;

    switch (binaryExpression->operator.binaryOperator)
    {
        case Add:
            result = LLVMBuildAdd(builder, lhs, rhs, "tmp");
            break;

        case Subtract:
            result = LLVMBuildSub(builder, lhs, rhs, "tmp");
            break;

        case Multiply:
            result = LLVMBuildMul(builder, lhs, rhs, "tmp");
            break;

        default:
            break;
    }

    return result;
}

/*
 * Compile a function call expression.
 *
 * FIXME: calls are not lowered yet. Each argument expression (the children of
 * expression->children[1]) is still compiled, so the instructions it needs are
 * emitted, but no call is built and NULL is returned. Lowering will need the
 * callee named by expression->children[0] and must pass wStructValue as
 * argument 0, because CompileFunction gives every method the struct pointer
 * as its first parameter.
 */
static LLVMValueRef CompileFunctionCallExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *expression
) {
    uint32_t i;
    Node *arguments = expression->children[1];

    /*
     * The compiled values are discarded until calls are implemented. No
     * argument array is declared: a VLA sized by the argument count would be
     * zero-length (undefined behavior) for a call with no arguments.
     */
    for (i = 0; i < arguments->childCount; i += 1)
    {
        CompileExpression(wStructValue, builder, function, arguments->children[i]);
    }

    return NULL;
}

/*
 * Compute the address of `accessee.accessor` as a store target. The accessee
 * is looked up as a local variable holding a registered struct; no load is
 * emitted.
 */
static LLVMValueRef CompileAccessExpressionForStore(
    LLVMBuilderRef builder,
    LLVMValueRef wStructValue,
    LLVMValueRef function,
    Node *expression
) {
    char *structName = expression->children[0]->value.string;
    char *fieldName = expression->children[1]->value.string;

    return FindStructFieldPointer(builder, FindVariablePointer(structName), fieldName);
}

/*
 * Read `accessee.accessor`: GEP to the field, then load it. The load is named
 * after the field.
 */
static LLVMValueRef CompileAccessExpression(
    LLVMBuilderRef builder,
    LLVMValueRef wStructValue,
    LLVMValueRef function,
    Node *expression
) {
    char *fieldName = expression->children[1]->value.string;
    LLVMValueRef structPointer = FindVariablePointer(expression->children[0]->value.string);
    LLVMValueRef fieldPointer = FindStructFieldPointer(builder, structPointer, fieldName);

    return LLVMBuildLoad(builder, fieldPointer, fieldName);
}

/*
 * Compile an expression by dispatching on its syntax kind, and return the
 * LLVM value it produces. Identifiers are resolved through the scope stack
 * via FindVariableValue. Unknown kinds are reported and yield NULL.
 */
static LLVMValueRef CompileExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *expression
) {
    switch (expression->syntaxKind)
    {
        case AccessExpression:
            return CompileAccessExpression(builder, wStructValue, function, expression);

        case BinaryExpression:
            return CompileBinaryExpression(wStructValue, builder, function, expression);

        case FunctionCallExpression:
            return CompileFunctionCallExpression(wStructValue, builder, function, expression);

        case Identifier:
            return FindVariableValue(builder, expression->value.string);

        case Number:
            return CompileNumber(expression);

        default:
            break;
    }

    fprintf(stderr, "Unknown expression kind!\n");
    return NULL;
}

/* Compile `return <expr>;`: evaluate the operand and emit `ret` with its value. */
static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatement)
{
    LLVMValueRef expression = CompileExpression(wStructValue, builder, function, returnStatement->children[0]);
    LLVMBuildRet(builder, expression);
}

/* Compile a bare `return;` as `ret void` at the builder's position. */
static void CompileReturnVoid(LLVMBuilderRef builder)
{
    (void)LLVMBuildRetVoid(builder);
}

/*
 * Compile `target = value;`.
 *
 * The right-hand side is compiled first. The store address is then resolved
 * either as a struct field (AccessExpression) or as a local variable
 * (Identifier). If the target kind is unsupported, or either side cannot be
 * compiled, a diagnostic is printed and no store is emitted. Storing through
 * an uninitialized or NULL pointer would otherwise crash LLVM.
 */
static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
    Node *target = assignmentStatement->children[0];
    LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]);
    LLVMValueRef destination = NULL;

    if (target->syntaxKind == AccessExpression)
    {
        destination = CompileAccessExpressionForStore(builder, wStructValue, function, target);
    }
    else if (target->syntaxKind == Identifier)
    {
        destination = FindVariablePointer(target->value.string);
    }
    else
    {
        fprintf(stderr, "Unsupported assignment target!\n");
        return;
    }

    if (destination == NULL || result == NULL)
    {
        fprintf(stderr, "Failed to compile assignment!\n");
        return;
    }

    LLVMBuildStore(builder, result, destination);
}

/*
 * Compile a local declaration `<type> <name>;`. It allocates a stack slot
 * named "<name>_ptr" and registers it in the innermost scope frame under
 * <name>.
 *
 * Custom types are resolved with LookupCustomType, built-in types with
 * WraithTypeToLLVMType. If the type cannot be resolved, a diagnostic is
 * printed and the variable is not declared.
 */
static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration)
{
    Node *typeNode = variableDeclaration->children[0];
    char *variableName = variableDeclaration->children[1]->value.string;
    LLVMTypeRef variableType;
    LLVMValueRef variable;
    size_t ptrNameSize;
    char *ptrName;

    if (typeNode->type == CustomType)
    {
        variableType = LookupCustomType(typeNode->children[0]->value.string);
    }
    else
    {
        variableType = WraithTypeToLLVMType(typeNode->type);
    }

    if (variableType == NULL)
    {
        fprintf(stderr, "Unknown type for variable %s!\n", variableName);
        return;
    }

    /* strdup(variableName) has no room for the suffix, so size the buffer explicitly. */
    ptrNameSize = strlen(variableName) + sizeof("_ptr");
    ptrName = malloc(ptrNameSize);
    snprintf(ptrName, ptrNameSize, "%s_ptr", variableName);

    variable = LLVMBuildAlloca(builder, variableType, ptrName);
    free(ptrName);

    AddLocalVariable(scope, variable, variableName);
}

/*
 * Compile one statement of a function body.
 * Returns 1 when the statement emitted a return, otherwise 0. Unknown
 * statement kinds are reported and count as non-returning.
 */
static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    uint8_t returned = 0;

    switch (statement->syntaxKind)
    {
        case Assignment:
            CompileAssignment(wStructValue, builder, function, statement);
            break;

        case Declaration:
            CompileFunctionVariableDeclaration(builder, statement);
            break;

        case Return:
            CompileReturn(wStructValue, builder, function, statement);
            returned = 1;
            break;

        case ReturnVoid:
            CompileReturnVoid(builder);
            returned = 1;
            break;

        default:
            fprintf(stderr, "Unknown statement kind!\n");
            break;
    }

    return returned;
}

/*
 * Compile one struct method into an LLVM function named after the method.
 *
 * Every method takes the struct pointer as an implicit first parameter,
 * followed by its declared parameters. The body gets its own scope frame.
 * In that frame the struct's fields (via AddStructVariables) and a stack copy
 * of each parameter are visible as locals. A void function without an
 * explicit return gets an implicit `ret void`; a missing return in a non-void
 * function is reported.
 *
 * fieldDeclarations / fieldDeclarationCount are currently unused. Field
 * lookup goes through the registry filled by AddStructDeclaration.
 */
static void CompileFunction(
    LLVMModuleRef module,
    LLVMTypeRef wStructPointerType,
    Node **fieldDeclarations,
    uint32_t fieldDeclarationCount,
    Node *functionDeclaration
) {
    uint32_t i;
    uint8_t hasReturn = 0;
    Node *functionSignature = functionDeclaration->children[0];
    Node *functionBody = functionDeclaration->children[1];
    Node *parameters = functionSignature->children[2];
    uint32_t argumentCount = parameters->childCount + 1; /* struct is implicit argument */
    LLVMTypeRef paramTypes[argumentCount];

    PushScopeFrame(scope);

    paramTypes[0] = wStructPointerType;

    for (i = 0; i < parameters->childCount; i += 1)
    {
        paramTypes[i + 1] = WraithTypeToLLVMType(parameters->children[i]->children[0]->type);
    }

    LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
    LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0);
    LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);

    LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
    LLVMBuilderRef builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(builder, entry);

    LLVMValueRef wStructPointer = LLVMGetParam(function, 0);

    AddStructVariables(builder, wStructPointer);

    /* Copy each declared parameter into an alloca so it can be read and assigned like a local. */
    for (i = 0; i < parameters->childCount; i += 1)
    {
        char *parameterName = parameters->children[i]->children[1]->value.string;
        /* strdup(parameterName) has no room for the suffix, so size the buffer explicitly. */
        size_t ptrNameSize = strlen(parameterName) + sizeof("_ptr");
        char *ptrName = malloc(ptrNameSize);
        LLVMValueRef argument = LLVMGetParam(function, i + 1);
        LLVMValueRef argumentCopy;

        snprintf(ptrName, ptrNameSize, "%s_ptr", parameterName);
        argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
        LLVMBuildStore(builder, argument, argumentCopy);
        free(ptrName);

        AddLocalVariable(scope, argumentCopy, parameterName);
    }

    for (i = 0; i < functionBody->childCount; i += 1)
    {
        hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]);
    }

    if (!hasReturn)
    {
        if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind)
        {
            LLVMBuildRetVoid(builder);
        }
        else
        {
            fprintf(stderr, "Return statement not provided!\n");
        }
    }

    /* A builder is created per function; release it once the body is emitted. */
    LLVMDisposeBuilder(builder);

    PopScopeFrame(scope);
}

/*
 * Compile a struct declaration in two passes.
 *   1. Collect its field declarations and build the named (packed) LLVM
 *      struct type. Register its field layout under the pointer type, and
 *      its type name with RegisterCustomType.
 *   2. Compile each method. Every method receives the struct pointer as an
 *      implicit first argument.
 * Field types go through WraithTypeToLLVMType, so only built-in field types
 * are handled here.
 */
static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t i;
    uint32_t fieldCount = 0;
    uint32_t declarationCount = node->children[1]->childCount;
    /* VLA bounds must be > 0; an empty struct body would otherwise be undefined behavior. */
    uint32_t capacity = declarationCount > 0 ? declarationCount : 1;
    uint8_t packed = 1;
    LLVMTypeRef types[capacity];
    Node *currentDeclarationNode;
    Node *fieldDeclarations[capacity];

    PushScopeFrame(scope);

    LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string);
    /* Address space 0 is LLVM's default address space. */
    LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0);

    /* Pass 1: build the struct body from its field declarations. */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];

        switch (currentDeclarationNode->syntaxKind)
        {
            case Declaration: /* a field declaration inside the struct body */
                types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type);
                fieldDeclarations[fieldCount] = currentDeclarationNode;
                fieldCount += 1;
                break;

            default:
                break;
        }
    }

    LLVMStructSetBody(wStruct, types, fieldCount, packed);
    /* Field lookups key on LLVMTypeOf(<struct pointer>), so register the pointer type. */
    AddStructDeclaration(wStructPointerType, fieldDeclarations, fieldCount);
    RegisterCustomType(wStruct, node->children[0]->value.string);

    /* Pass 2: compile the methods now that the struct layout is known. */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];

        switch (currentDeclarationNode->syntaxKind)
        {
            case FunctionDeclaration:
                CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
                break;

            default:
                break;
        }
    }

    PopScopeFrame(scope);
}

/*
 * Walk the AST depth-first (pre-order) and compile every StructDeclaration
 * node. A struct's own children are still visited after it is compiled.
 */
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t childIndex;

    if (node->syntaxKind == StructDeclaration)
    {
        CompileStruct(module, context, node);
    }

    for (childIndex = 0; childIndex < node->childCount; childIndex += 1)
    {
        Compile(module, context, node->children[childIndex]);
    }
}

/*
 * Entry point. Parses the file named by argv[1], prints the AST, compiles
 * every struct into an LLVM module, and verifies it (aborting the process on
 * invalid IR). The bitcode is written to "test.bc".
 *
 * Returns 1 on a missing argument, an unreadable file, a parse failure, or a
 * failed bitcode write; 0 otherwise.
 */
int main(int argc, char *argv[])
{
    if (argc < 2)
    {
        printf("Please provide a file.\n");
        return 1;
    }

    scope = CreateScope();

    structTypeFieldDeclarations = NULL;
    structTypeFieldDeclarationCount = 0;

    customTypes = NULL;
    customTypeCount = 0;

    stack = CreateStack();

    FILE *fp = fopen(argv[1], "r");
    if (fp == NULL)
    {
        /* Otherwise yyin would be NULL and the parser would crash. */
        perror(argv[1]);
        return 1;
    }

    yyin = fp;
    int parseResult = yyparse(fp, stack);
    fclose(fp);

    /* yyparse returns 0 on success; after a failure the AST cannot be trusted. */
    if (parseResult != 0 || rootNode == NULL)
    {
        fprintf(stderr, "Failed to parse %s\n", argv[1]);
        return 1;
    }

    PrintTree(rootNode, 0);

    LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
    LLVMContextRef context = LLVMGetGlobalContext();

    Compile(module, context, rootNode);

    char *error = NULL;
    LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
    LLVMDisposeMessage(error);

    int exitCode = 0;
    if (LLVMWriteBitcodeToFile(module, "test.bc") != 0) {
        fprintf(stderr, "error writing bitcode to file\n");
        exitCode = 1;
    }

    LLVMDisposeModule(module);
    return exitCode;
}