From 62f42e47b91e3f6d954d244862d14a00b1acfcc1 Mon Sep 17 00:00:00 2001
From: cosmonaut <evan@moonside.games>
Date: Thu, 29 Apr 2021 23:49:35 -0700
Subject: [PATCH] initial for loop range implementation

---
 euler001.w            |  19 ++++++
 generators/wraith.lex |   2 +
 generators/wraith.y   |  30 ++++++--
 src/ast.c             |  17 +++++
 src/ast.h             |  11 ++-
 src/codegen.c         | 154 +++++++++++++++++++++++++++++++++---------
 6 files changed, 196 insertions(+), 37 deletions(-)
 create mode 100644 euler001.w

diff --git a/euler001.w b/euler001.w
new file mode 100644
index 0000000..2c5c49e
--- /dev/null
+++ b/euler001.w
@@ -0,0 +1,19 @@
+struct Program
+{
+    static Main(): int
+    {
+        sum: int;
+
+        sum = 0;
+
+        for (i in [1..1000])
+        {
+            if ((i % 3 == 0) || (i % 5 == 0))
+            {
+                sum = sum + i;
+            }
+        }
+
+        return sum;
+    }
+}
diff --git a/generators/wraith.lex b/generators/wraith.lex
index b870eb3..ee84f7b 100644
--- a/generators/wraith.lex
+++ b/generators/wraith.lex
@@ -19,6 +19,8 @@
 "alloc"                     return ALLOC;
 "if"                        return IF;
 "else"                      return ELSE;
+"in"                        return IN;
+"for"                       return FOR;
 [0-9]+                      return NUMBER;
 [a-zA-Z][a-zA-Z0-9]*        return ID;
 \"[a-zA-Z][a-zA-Z0-9]*\"    return STRING_LITERAL;
diff --git a/generators/wraith.y b/generators/wraith.y
index 831bb76..fc3e0ab 100644
--- a/generators/wraith.y
+++ b/generators/wraith.y
@@ -32,6 +32,8 @@ extern FILE *yyin;
 %token ALLOC
 %token IF
 %token ELSE
+%token IN
+%token FOR
 %token NUMBER
 %token ID
 %token STRING_LITERAL
@@ -65,10 +67,10 @@ extern FILE *yyin;
 
 %define parse.error verbose
 
-%left GREATER_THAN LESS_THAN
+%left GREATER_THAN LESS_THAN EQUAL
 %left PLUS MINUS
-%left STAR
-%left BANG
+%left STAR PERCENT
+%left BANG BAR
 %left LEFT_PAREN RIGHT_PAREN
 
 %%
@@ -139,10 +141,12 @@ AccessExpression        : Identifier POINT AccessExpression
                             $$ = $1;
                         }
 
-PrimaryExpression       : NUMBER
+Number                  : NUMBER
                         {
                             $$ = MakeNumberNode(yytext);
                         }
+
+PrimaryExpression       : Number
                         | STRING
                         {
                             $$ = MakeStringNode(yytext);
@@ -172,6 +176,10 @@ BinaryExpression        : Expression PLUS Expression
                         {
                             $$ = MakeBinaryNode(Multiply, $1, $3);
                         }
+                        | Expression PERCENT Expression
+                        {
+                            $$ = MakeBinaryNode(Mod, $1, $3);
+                        }
                         | Expression LESS_THAN Expression
                         {
                             $$ = MakeBinaryNode(LessThan, $1, $3);
@@ -180,6 +188,14 @@ BinaryExpression        : Expression PLUS Expression
                         {
                             $$ = MakeBinaryNode(GreaterThan, $1, $3);
                         }
+                        | Expression EQUAL EQUAL Expression
+                        {
+                            $$ = MakeBinaryNode(Equal, $1, $4);
+                        }
+                        | Expression BAR BAR Expression
+                        {
+                            $$ = MakeBinaryNode(LogicalOr, $1, $4);
+                        }
 
 Expression              : BinaryExpression
                         | UnaryExpression
@@ -240,8 +256,14 @@ Conditional             : IfStatement
                             $$ = MakeIfElseNode($1, $3);
                         }
 
+ForStatement            : FOR LEFT_PAREN Identifier IN LEFT_BRACKET Number POINT POINT Number RIGHT_BRACKET RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE
+                        {
+                            $$ = MakeForLoopNode($3, $6, $9, $13);
+                        }
+
 Statement               : PartialStatement SEMICOLON
                         | Conditional
+                        | ForStatement
                         ;
 
 Statements              : Statement
diff --git a/src/ast.c b/src/ast.c
index b736938..24bdeb9 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -440,6 +440,23 @@ Node* MakeIfElseNode(
     return node;
 }
 
+Node* MakeForLoopNode(
+    Node *identifierNode,
+    Node *startNumberNode,
+    Node *endNumberNode,
+    Node *statementSequenceNode
+) {
+    Node* node = (Node*) malloc(sizeof(Node));
+    node->syntaxKind = ForLoop;
+    node->childCount = 4;
+    node->children = (Node**) malloc(sizeof(Node*) * 4);
+    node->children[0] = identifierNode;
+    node->children[1] = startNumberNode;
+    node->children[2] = endNumberNode;
+    node->children[3] = statementSequenceNode;
+    return node;
+}
+
 static const char* PrimitiveTypeToString(PrimitiveType type)
 {
     switch (type)
diff --git a/src/ast.h b/src/ast.h
index ac65995..8e49ace 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -47,8 +47,11 @@ typedef enum
     Add,
     Subtract,
     Multiply,
+    Mod,
+    Equal,
     LessThan,
-    GreaterThan
+    GreaterThan,
+    LogicalOr
 } BinaryOperator;
 
 typedef enum
@@ -200,6 +203,12 @@ Node* MakeIfElseNode(
     Node *ifNode,
     Node *statementSequenceNode
 );
+Node* MakeForLoopNode(
+    Node *identifierNode,
+    Node *startNumberNode,
+    Node *endNumberNode,
+    Node *statementSequenceNode
+);
 
 void PrintTree(Node *node, uint32_t tabCount);
 
diff --git a/src/codegen.c b/src/codegen.c
index 4719317..29f48c2 100644
--- a/src/codegen.c
+++ b/src/codegen.c
@@ -19,6 +19,7 @@ typedef struct LocalVariable
 {
     char *name;
     LLVMValueRef pointer;
+    LLVMValueRef value;
 } LocalVariable;
 
 typedef struct FunctionArgument
@@ -111,14 +112,19 @@ static void PopScopeFrame(Scope *scope)
     scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount);
 }
 
-static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name)
-{
+static void AddLocalVariable(
+    Scope *scope,
+    LLVMValueRef pointer, /* can be NULL */
+    LLVMValueRef value, /* can be NULL */
+    char *name
+) {
     ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
     uint32_t index = scopeFrame->localVariableCount;
 
     scopeFrame->localVariables = realloc(scopeFrame->localVariables, sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1));
     scopeFrame->localVariables[index].name = strdup(name);
     scopeFrame->localVariables[index].pointer = pointer;
+    scopeFrame->localVariables[index].value = value;
 
     scopeFrame->localVariableCount += 1;
 }
@@ -220,7 +226,14 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
         {
             if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0)
             {
-                return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name);
+                if (scope->scopeStack[i].localVariables[j].value != NULL)
+                {
+                    return scope->scopeStack[i].localVariables[j].value;
+                }
+                else
+                {
+                    return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name);
+                }
             }
         }
     }
@@ -414,6 +427,7 @@ static void AddStructVariablesToScope(
                 AddLocalVariable(
                     scope,
                     elementPointer,
+                    NULL,
                     structTypeDeclarations[i].fields[j].name
                 );
             }
@@ -456,6 +470,15 @@ static LLVMValueRef CompileBinaryExpression(
 
         case GreaterThan:
             return LLVMBuildICmp(builder, LLVMIntSGT, left, right, "greaterThanResult");
+
+        case Mod:
+            return LLVMBuildSRem(builder, left, right, "modResult");
+
+        case Equal:
+            return LLVMBuildICmp(builder, LLVMIntEQ, left, right, "equalResult");
+
+        case LogicalOr:
+            return LLVMBuildOr(builder, left, right, "orResult");
     }
 
     return NULL;
@@ -581,20 +604,22 @@ static LLVMValueRef CompileExpression(
     return NULL;
 }
 
-static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
+static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
 
-static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement)
+static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
 {
     LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
     LLVMBuildRet(builder, expression);
+    return LLVMGetLastBasicBlock(function);
 }
 
-static void CompileReturnVoid(LLVMBuilderRef builder)
+static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef function)
 {
     LLVMBuildRetVoid(builder);
+    return LLVMGetLastBasicBlock(function);
 }
 
-static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
+static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
 {
     LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
     LLVMValueRef identifier;
@@ -609,14 +634,16 @@ static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
     else
     {
         printf("Identifier not found!");
-        return;
+        return LLVMGetLastBasicBlock(function);
     }
 
     LLVMBuildStore(builder, result, identifier);
+
+    return LLVMGetLastBasicBlock(function);
 }
 
 /* FIXME: path for reference types */
-static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration)
+static LLVMBasicBlockRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
 {
     LLVMValueRef variable;
     char *variableName = variableDeclaration->children[1]->value.string;
@@ -627,10 +654,12 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var
 
     free(ptrName);
 
-    AddLocalVariable(scope, variable, variableName);
+    AddLocalVariable(scope, variable, NULL, variableName);
+
+    return LLVMGetLastBasicBlock(function);
 }
 
-static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
+static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
 {
     uint32_t i;
     LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
@@ -649,9 +678,11 @@ static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, No
 
     LLVMBuildBr(builder, afterCond);
     LLVMPositionBuilderAtEnd(builder, afterCond);
+
+    return afterCond;
 }
 
-static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
+static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
 {
     uint32_t i;
     LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]);
@@ -686,45 +717,102 @@ static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function
     }
 
     LLVMBuildBr(builder, afterCond);
-
     LLVMPositionBuilderAtEnd(builder, afterCond);
+
+    return afterCond;
 }
 
-static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
+static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement)
+{
+    uint32_t i;
+    LLVMBasicBlockRef entryBlock = LLVMAppendBasicBlock(function, "loopEntry");
+    LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
+    LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
+    LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
+    char *iteratorVariableName = forLoopStatement->children[0]->value.string;
+
+    PushScopeFrame(scope);
+
+    LLVMBuildBr(builder, entryBlock);
+
+    LLVMPositionBuilderAtEnd(builder, entryBlock);
+    LLVMBuildBr(builder, checkBlock);
+
+    LLVMPositionBuilderAtEnd(builder, checkBlock);
+    LLVMValueRef iteratorValue = LLVMBuildPhi(builder, LLVMInt64Type(), iteratorVariableName);
+    AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName);
+
+    LLVMPositionBuilderAtEnd(builder, bodyBlock);
+    LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next");
+
+    LLVMPositionBuilderAtEnd(builder, checkBlock);
+
+    LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]);
+    LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
+
+    LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);
+
+    LLVMPositionBuilderAtEnd(builder, bodyBlock);
+
+    LLVMBasicBlockRef lastBlock;
+    for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1)
+    {
+        lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]);
+    }
+
+    LLVMBuildBr(builder, checkBlock);
+
+    LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock));
+
+    LLVMValueRef incomingValues[2];
+    incomingValues[0] = CompileNumber(forLoopStatement->children[1]);
+    incomingValues[1] = nextValue;
+
+    LLVMBasicBlockRef incomingBlocks[2];
+    incomingBlocks[0] = entryBlock;
+    incomingBlocks[1] = lastBlock;
+
+    LLVMAddIncoming(iteratorValue, incomingValues, incomingBlocks, 2);
+
+    LLVMPositionBuilderAtEnd(builder, afterLoopBlock);
+
+    PopScopeFrame(scope);
+
+    return afterLoopBlock;
+}
+
+static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
 {
     switch (statement->syntaxKind)
     {
         case Assignment:
-            CompileAssignment(builder, statement);
-            return 0;
+            return CompileAssignment(builder, function, statement);
+
+        case Declaration:
+            return CompileFunctionVariableDeclaration(builder, function, statement);
+
+        case ForLoop:
+            return CompileForLoopStatement(builder, function, statement);
 
         case FunctionCallExpression:
             CompileFunctionCallExpression(builder, statement);
-            return 0;
-
-        case Declaration:
-            CompileFunctionVariableDeclaration(builder, statement);
-            return 0;
+            return LLVMGetLastBasicBlock(function);
 
         case IfStatement:
-            CompileIfStatement(builder, function, statement);
-            return 0;
+            return CompileIfStatement(builder, function, statement);
 
         case IfElseStatement:
-            CompileIfElseStatement(builder, function, statement);
-            return 0;
+            return CompileIfElseStatement(builder, function, statement);
 
         case Return:
-            CompileReturn(builder, statement);
-            return 1;
+            return CompileReturn(builder, function, statement);
 
         case ReturnVoid:
-            CompileReturnVoid(builder);
-            return 1;
+            return CompileReturnVoid(builder, function);
     }
 
     fprintf(stderr, "Unknown statement kind!\n");
-    return 0;
+    return NULL;
 }
 
 static void CompileFunction(
@@ -800,14 +888,16 @@ static void CompileFunction(
         LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
         LLVMBuildStore(builder, argument, argumentCopy);
         free(ptrName);
-        AddLocalVariable(scope, argumentCopy, functionSignature->children[2]->children[i]->children[1]->value.string);
+        AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string);
     }
 
     for (i = 0; i < functionBody->childCount; i += 1)
     {
-        hasReturn |= CompileStatement(builder, function, functionBody->children[i]);
+        CompileStatement(builder, function, functionBody->children[i]);
     }
 
+    hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
+
     if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
     {
         LLVMBuildRetVoid(builder);