#include "ast.h"

#include <stdlib.h>
#include <stdio.h>
#include <string.h>

/*
 * Return a heap-allocated copy of the NUL-terminated string `s`, or NULL if
 * the allocation fails. The caller owns the result and must free() it.
 * `s` must not be NULL.
 *
 * NOTE(review): this shares its name with the POSIX / C23 strdup() that
 * <string.h> may also declare, and identifiers beginning with "str" plus a
 * lowercase letter are reserved for <string.h>. Renaming it to a
 * project-prefixed helper would avoid the clash but requires updating
 * every caller.
 */
char* strdup (const char* s)
{
  size_t size = strlen(s) + 1; /* include the terminating NUL */
  char* copy = malloc(size);

  if (copy != NULL)
  {
    memcpy(copy, s, size);
  }
  return copy;
}

/*
 * Return a human-readable name for `syntaxKind`, used when printing the
 * tree. Values without a case below map to "Unknown". The result is a
 * string literal and must not be freed or modified.
 */
const char* SyntaxKindString(SyntaxKind syntaxKind)
{
    switch(syntaxKind)
    {
        case AccessExpression: return "AccessExpression";
        case Assignment: return "Assignment";
        case BinaryExpression: return "BinaryExpression";
        case Comment: return "Comment";
        case Declaration: return "Declaration";
        case DeclarationSequence: return "DeclarationSequence";
        case FunctionArgumentSequence: return "FunctionArgumentSequence";
        case FunctionCallExpression: return "FunctionCallExpression";
        case FunctionDeclaration: return "FunctionDeclaration";
        case FunctionModifiers: return "FunctionModifiers";
        case FunctionSignature: return "FunctionSignature";
        case FunctionSignatureArguments: return "FunctionSignatureArguments";
        case Identifier: return "Identifier";
        case Number: return "Number";
        case Return: return "Return";
        /* Produced by MakeReturnVoidStatementNode; previously printed as "Unknown". */
        case ReturnVoid: return "ReturnVoid";
        case StatementSequence: return "StatementSequence";
        case StaticModifier: return "StaticModifier";
        case StringLiteral: return "StringLiteral";
        case StructDeclaration: return "StructDeclaration";
        case Type: return "Type";
        case UnaryExpression: return "UnaryExpression";
        default: return "Unknown";
    }
}

/*
 * Create a leaf Type node for the primitive `type`.
 *
 * Returns NULL if allocation fails. The node has no children, and
 * `children` is set to NULL so later code never reads an indeterminate
 * pointer from it.
 */
Node* MakeTypeNode(
    PrimitiveType type
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = Type;
    node->type = type;
    node->childCount = 0;
    node->children = NULL;
    return node;
}

/*
 * Create a Type node of kind CustomType whose single child (children[0]) is
 * `identifierNode`, the name of the user-defined type. The child pointer is
 * stored, not copied.
 *
 * Returns NULL if either allocation fails. Nothing is leaked in that case,
 * and `identifierNode` is left untouched.
 */
Node* MakeCustomTypeNode(
    Node *identifierNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = Type;
    node->type = CustomType;
    node->childCount = 1;
    node->children[0] = identifierNode;
    return node;
}

Node* MakeIdentifierNode(
    const char *id
) {
    Node* node = (Node*) malloc(sizeof(Node));
    node->syntaxKind = Identifier;
    node->value.string = strdup(id);
    node->childCount = 0;
    return node;
}

Node* MakeNumberNode(
    const char *numberString
) {
    char *ptr;
    Node* node = (Node*) malloc(sizeof(Node));
    node->syntaxKind = Number;
    node->value.number = strtoul(numberString, &ptr, 10);
    node->childCount = 0;
    return node;
}

Node* MakeStringNode(
    const char *string
) {
    Node* node = (Node*) malloc(sizeof(Node));
    node->syntaxKind = StringLiteral;
    node->value.string = strdup(string);
    node->childCount = 0;
    return node;
}

Node* MakeStaticNode()
{
    Node* node = (Node*) malloc(sizeof(Node));
    node->syntaxKind = StaticModifier;
    node->childCount = 0;
    return node;
}

/*
 * Create a FunctionModifiers node whose children are
 * pModifierNodes[0..modifierCount-1], stored in the SAME order (unlike the
 * sequence constructors in this file, which reverse). Only the pointers are
 * copied; the modifier nodes themselves are stored, not duplicated.
 *
 * With modifierCount == 0, `children` is NULL rather than left
 * uninitialized. Returns NULL if an allocation fails, without leaking the
 * partially built node.
 */
Node* MakeFunctionModifiersNode(
    Node **pModifierNodes,
    uint32_t modifierCount
) {
    uint32_t i;
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = FunctionModifiers;
    node->childCount = modifierCount;
    node->children = NULL;

    if (modifierCount > 0)
    {
        /* calloc checks count * size for overflow. */
        node->children = calloc(modifierCount, sizeof *node->children);
        if (node->children == NULL)
        {
            free(node);
            return NULL;
        }

        for (i = 0; i < modifierCount; i += 1)
        {
            node->children[i] = pModifierNodes[i];
        }
    }

    return node;
}

/*
 * Create a UnaryExpression node applying `operator` to `child`, which
 * becomes children[0]. The child pointer is stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * `child` is left untouched.
 */
Node* MakeUnaryNode(
    UnaryOperator operator,
    Node *child
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = UnaryExpression;
    node->operator.unaryOperator = operator;
    node->children[0] = child;
    node->childCount = 1;
    return node;
}

/*
 * Create a BinaryExpression node applying `operator` to `left` (children[0])
 * and `right` (children[1]). The child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * `left`/`right` are left untouched.
 */
Node* MakeBinaryNode(
    BinaryOperator operator,
    Node *left,
    Node *right
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = BinaryExpression;
    node->operator.binaryOperator = operator;
    node->children[0] = left;
    node->children[1] = right;
    node->childCount = 2;
    return node;
}

/*
 * Create a Declaration node with children[0] = `typeNode` and
 * children[1] = `identifierNode`. The child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeDeclarationNode(
    Node* typeNode,
    Node* identifierNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = Declaration;
    node->childCount = 2;
    node->children[0] = typeNode;
    node->children[1] = identifierNode;
    return node;
}

/*
 * Create an Assignment node with children[0] = `left` and
 * children[1] = `right`. Presumably these are the target and the assigned
 * value; confirm with the parser. The child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeAssignmentNode(
    Node *left,
    Node *right
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = Assignment;
    node->childCount = 2;
    node->children[0] = left;
    node->children[1] = right;
    return node;
}

/*
 * Create a StatementSequence node from pNodes[0..nodeCount-1], stored in
 * REVERSE order: children[0] = pNodes[nodeCount-1], ...,
 * children[nodeCount-1] = pNodes[0].
 * NOTE(review): presumably the caller collects statements back-to-front;
 * confirm against the parser.
 *
 * With nodeCount == 0, `children` is NULL. Returns NULL if an allocation
 * fails, without leaking the partially built node.
 */
Node* MakeStatementSequenceNode(
    Node** pNodes,
    uint32_t nodeCount
) {
    uint32_t i;
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = StatementSequence;
    node->childCount = nodeCount;
    node->children = NULL;

    if (nodeCount > 0)
    {
        /* calloc checks count * size for overflow. */
        node->children = calloc(nodeCount, sizeof *node->children);
        if (node->children == NULL)
        {
            free(node);
            return NULL;
        }

        /* Unsigned forward index: avoids converting a uint32_t count into an
         * int32_t, which misbehaves for counts above INT32_MAX. */
        for (i = 0; i < nodeCount; i += 1)
        {
            node->children[i] = pNodes[nodeCount - 1 - i];
        }
    }

    return node;
}

/*
 * Create a Return node whose single child (children[0]) is
 * `expressionNode`, the returned expression. For a return without a value,
 * see MakeReturnVoidStatementNode. The child pointer is stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * `expressionNode` is left untouched.
 */
Node* MakeReturnStatementNode(
    Node *expressionNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = Return;
    node->childCount = 1;
    node->children[0] = expressionNode;
    return node;
}

/*
 * Create a ReturnVoid node (a return with no value). It has no children, and
 * `children` is NULL.
 *
 * Returns NULL if allocation fails.
 */
Node* MakeReturnVoidStatementNode(void)
{
    Node *node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = ReturnVoid;
    node->childCount = 0;
    node->children = NULL;
    return node;
}

/*
 * Create a FunctionSignatureArguments node from
 * pArgumentNodes[0..argumentCount-1], stored in REVERSE order:
 * children[0] = pArgumentNodes[argumentCount-1], ...,
 * children[argumentCount-1] = pArgumentNodes[0].
 * NOTE(review): presumably the caller collects parameters back-to-front;
 * confirm against the parser.
 *
 * With argumentCount == 0, `children` is NULL. Returns NULL if an
 * allocation fails, without leaking the partially built node.
 */
Node *MakeFunctionSignatureArgumentsNode(
    Node **pArgumentNodes,
    uint32_t argumentCount
) {
    uint32_t i;
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = FunctionSignatureArguments;
    node->childCount = argumentCount;
    node->children = NULL;

    if (argumentCount > 0)
    {
        /* calloc checks count * size for overflow. */
        node->children = calloc(argumentCount, sizeof *node->children);
        if (node->children == NULL)
        {
            free(node);
            return NULL;
        }

        /* Unsigned forward index: avoids converting a uint32_t count into an
         * int32_t, which misbehaves for counts above INT32_MAX. */
        for (i = 0; i < argumentCount; i += 1)
        {
            node->children[i] = pArgumentNodes[argumentCount - 1 - i];
        }
    }

    return node;
}

/*
 * Create a FunctionSignature node with four children, in this fixed slot
 * order:
 *   children[0] = identifierNode  (function name)
 *   children[1] = typeNode        (presumably the return type; confirm)
 *   children[2] = arguments       (parameter list node)
 *   children[3] = modifiersNode   (function modifiers node)
 * The child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeFunctionSignatureNode(
    Node *identifierNode,
    Node* typeNode,
    Node* arguments,
    Node* modifiersNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(4 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = FunctionSignature;
    node->childCount = 4;
    node->children[0] = identifierNode;
    node->children[1] = typeNode;
    node->children[2] = arguments;
    node->children[3] = modifiersNode;
    return node;
}

/*
 * Create a FunctionDeclaration node with children[0] =
 * `functionSignatureNode` and children[1] = `functionBodyNode`. The child
 * pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeFunctionDeclarationNode(
    Node* functionSignatureNode,
    Node* functionBodyNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = FunctionDeclaration;
    node->childCount = 2;
    node->children[0] = functionSignatureNode;
    node->children[1] = functionBodyNode;
    return node;
}

/*
 * Create a StructDeclaration node with children[0] = `identifierNode` (the
 * struct name) and children[1] = `declarationSequenceNode` (its fields).
 * The child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeStructDeclarationNode(
    Node *identifierNode,
    Node *declarationSequenceNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = StructDeclaration;
    node->childCount = 2;
    node->children[0] = identifierNode;
    node->children[1] = declarationSequenceNode;
    return node;
}

/*
 * Create a DeclarationSequence node from pNodes[0..nodeCount-1], stored in
 * REVERSE order: children[0] = pNodes[nodeCount-1], ...,
 * children[nodeCount-1] = pNodes[0].
 * NOTE(review): presumably the caller collects declarations back-to-front;
 * confirm against the parser.
 *
 * With nodeCount == 0, `children` is NULL. Returns NULL if an allocation
 * fails, without leaking the partially built node.
 */
Node* MakeDeclarationSequenceNode(
    Node **pNodes,
    uint32_t nodeCount
) {
    uint32_t i;
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = DeclarationSequence;
    node->childCount = nodeCount;
    node->children = NULL;

    if (nodeCount > 0)
    {
        /* calloc checks count * size for overflow. */
        node->children = calloc(nodeCount, sizeof *node->children);
        if (node->children == NULL)
        {
            free(node);
            return NULL;
        }

        /* Unsigned forward index: avoids converting a uint32_t count into an
         * int32_t, which misbehaves for counts above INT32_MAX. */
        for (i = 0; i < nodeCount; i += 1)
        {
            node->children[i] = pNodes[nodeCount - 1 - i];
        }
    }

    return node;
}

/*
 * Create a FunctionArgumentSequence node (the arguments of a call) from
 * pArgumentNodes[0..argumentCount-1], stored in REVERSE order:
 * children[0] = pArgumentNodes[argumentCount-1], ...,
 * children[argumentCount-1] = pArgumentNodes[0].
 * NOTE(review): presumably the caller collects arguments back-to-front;
 * confirm against the parser.
 *
 * With argumentCount == 0, `children` is NULL. Returns NULL if an
 * allocation fails, without leaking the partially built node.
 */
Node *MakeFunctionArgumentSequenceNode(
    Node **pArgumentNodes,
    uint32_t argumentCount
) {
    uint32_t i;
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = FunctionArgumentSequence;
    node->childCount = argumentCount;
    node->children = NULL;

    if (argumentCount > 0)
    {
        /* calloc checks count * size for overflow. */
        node->children = calloc(argumentCount, sizeof *node->children);
        if (node->children == NULL)
        {
            free(node);
            return NULL;
        }

        /* Unsigned forward index: avoids converting a uint32_t count into an
         * int32_t, which misbehaves for counts above INT32_MAX. */
        for (i = 0; i < argumentCount; i += 1)
        {
            node->children[i] = pArgumentNodes[argumentCount - 1 - i];
        }
    }

    return node;
}

/*
 * Create a FunctionCallExpression node with children[0] = `identifierNode`
 * (the callee) and children[1] = `argumentSequenceNode` (the call
 * arguments). The child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeFunctionCallExpressionNode(
    Node *identifierNode,
    Node *argumentSequenceNode
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = FunctionCallExpression;
    node->childCount = 2;
    node->children[0] = identifierNode;
    node->children[1] = argumentSequenceNode;
    return node;
}

/*
 * Create an AccessExpression node with children[0] = `accessee` (the
 * object being accessed) and children[1] = `accessor` (the member). The
 * child pointers are stored, not copied.
 *
 * Returns NULL if an allocation fails. Nothing is leaked in that case, and
 * the argument nodes are left untouched.
 */
Node* MakeAccessExpressionNode(
    Node *accessee,
    Node *accessor
) {
    Node* node = malloc(sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->children = malloc(2 * sizeof *node->children);
    if (node->children == NULL)
    {
        free(node);
        return NULL;
    }

    node->syntaxKind = AccessExpression;
    node->childCount = 2;
    node->children[0] = accessee;
    node->children[1] = accessor;
    return node;
}

/*
 * Map a PrimitiveType to its enumerator name for tree dumps. Any value not
 * listed yields "Unknown". The result is a string literal; do not free it.
 * There is deliberately no `default` label, so -Wswitch can flag newly
 * added enumerators.
 */
static const char* PrimitiveTypeToString(PrimitiveType type)
{
    const char *name = "Unknown";

    switch (type)
    {
        case Int:        name = "Int";        break;
        case UInt:       name = "UInt";       break;
        case Bool:       name = "Bool";       break;
        case Void:       name = "Void";       break;
        case CustomType: name = "CustomType"; break;
    }

    return name;
}

/*
 * Print the source symbol for binary operator `op` to stdout, with no
 * trailing newline. An operator with no symbol listed here prints as "?"
 * followed by its numeric value, instead of being silently omitted from
 * the dump.
 */
static void PrintBinaryOperator(BinaryOperator op)
{
    switch (op)
    {
        case Add:
            printf("+");
            break;

        case Subtract:
            printf("-");
            break;

        case Multiply:
            printf("*");
            break;

        default:
            printf("?%d", (int) op);
            break;
    }
}

/*
 * Print one line describing `node` to stdout. The line consists of
 * `tabCount` levels of two-space indentation, the node's syntax-kind name,
 * and then a kind-specific detail: the operator symbol for
 * BinaryExpression, the primitive type name for Type, the text for
 * Identifier, or the value for Number. Other kinds print only their name.
 * Children are not printed here. A non-positive `tabCount` means no
 * indentation.
 */
static void PrintNode(Node *node, int tabCount)
{
    int level;

    /* A signed counter: comparing an unsigned index against the signed
     * tabCount would turn a negative count into a huge unsigned bound. */
    for (level = 0; level < tabCount; level += 1)
    {
        printf("  ");
    }

    printf("%s: ", SyntaxKindString(node->syntaxKind));
    switch (node->syntaxKind)
    {
        case BinaryExpression:
            PrintBinaryOperator(node->operator.binaryOperator);
            break;

        case Type:
            printf("%s", PrimitiveTypeToString(node->type));
            break;

        case Identifier:
            printf("%s", node->value.string);
            break;

        case Number:
            /* The cast makes the argument match %lu exactly, whatever integer
             * type value.number is declared with. */
            printf("%lu", (unsigned long) node->value.number);
            break;

        default:
            /* Declaration and all other kinds carry no extra detail. */
            break;
    }

    printf("\n");
}

/*
 * Recursively print `node` and all of its descendants to stdout in
 * pre-order, one node per line. Each child is indented one level deeper
 * than its parent, starting from `tabCount` levels for `node`. A NULL
 * node or child is skipped rather than dereferenced.
 */
void PrintTree(Node *node, uint32_t tabCount)
{
    uint32_t i;

    if (node == NULL)
    {
        return;
    }

    PrintNode(node, (int) tabCount);
    for (i = 0; i < node->childCount; i += 1)
    {
        PrintTree(node->children[i], tabCount + 1);
    }
}