From f9df54730215c4c2d4eff1d6cd3f421907ee40ae Mon Sep 17 00:00:00 2001
From: Evan Hemsley <2342303+ehemsley@users.noreply.github.com>
Date: Sun, 29 Dec 2019 22:19:10 -0800
Subject: [PATCH] fast AABB transformation

---
 Bonk/AABB.cs                   | 98 +++++++++++++++++++---------------
 Bonk/BroadPhase/SpatialHash.cs | 25 ++++-----
 Bonk/IShape2D.cs               |  4 +-
 Bonk/Shapes/Circle.cs          | 11 ++--
 Bonk/Shapes/Line.cs            | 34 ++++++------
 Bonk/Shapes/Point.cs           | 22 ++++----
 Bonk/Shapes/Polygon.cs         | 25 +++++----
 Bonk/Shapes/Rectangle.cs       |  9 ++--
 Bonk/Shapes/Simplex.cs         | 16 ++----
 Test/Equality.cs               | 12 ++---
 10 files changed, 136 insertions(+), 120 deletions(-)

diff --git a/Bonk/AABB.cs b/Bonk/AABB.cs
index 0ecbfdd..8075275 100644
--- a/Bonk/AABB.cs
+++ b/Bonk/AABB.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Collections.Generic;
 using System.Numerics;
 using MoonTools.Core.Structs;
@@ -10,50 +10,74 @@ namespace MoonTools.Core.Bonk
     /// </summary>
     public struct AABB : IEquatable<AABB>
     {
-        public float MinX { get; private set; }
-        public float MinY { get; private set; }
-        public float MaxX { get; private set; }
-        public float MaxY { get; private set; }
+        public Vector2 Min { get; private set; }
+        public Vector2 Max { get; private set; }
 
-        public float Width { get { return MaxX - MinX; } }
-        public float Height { get { return MaxY - MinY; } }
+        public float Width { get { return Max.X - Min.X; } }
+        public float Height { get { return Max.Y - Min.Y; } }
 
-        public static AABB FromTransformedVertices(IEnumerable<Position2D> vertices, Transform2D transform)
+        public AABB(float minX, float minY, float maxX, float maxY)
         {
-            float minX = float.MaxValue;
-            float minY = float.MaxValue;
-            float maxX = float.MinValue;
-            float maxY = float.MinValue;
+            Min = new Vector2(minX, minY);
+            Max = new Vector2(maxX, maxY);
+        }
+
+        public AABB(Vector2 min, Vector2 max)
+        {
+            Min = min;
+            Max = max;
+        }
+
+        private static Matrix4x4 AbsoluteMatrix(Matrix4x4 matrix)
+        {
+            return new Matrix4x4
+            (
+                Math.Abs(matrix.M11), Math.Abs(matrix.M12), Math.Abs(matrix.M13), Math.Abs(matrix.M14),
+                Math.Abs(matrix.M21), Math.Abs(matrix.M22), Math.Abs(matrix.M23), Math.Abs(matrix.M24),
+                Math.Abs(matrix.M31), Math.Abs(matrix.M32), Math.Abs(matrix.M33), Math.Abs(matrix.M34),
+                Math.Abs(matrix.M41), Math.Abs(matrix.M42), Math.Abs(matrix.M43), Math.Abs(matrix.M44)
+            );
+        }
+
+        public static AABB Transformed(AABB aabb, Transform2D transform)
+        {
+            var center = (aabb.Min + aabb.Max) / 2f;
+            var extent = (aabb.Max - aabb.Min) / 2f;
+
+            var newCenter = Vector2.Transform(center, transform.TransformMatrix);
+            var newExtent = Vector2.TransformNormal(extent, AbsoluteMatrix(transform.TransformMatrix));
+
+            return new AABB(newCenter - newExtent, newCenter + newExtent);
+        }
+
+        public static AABB FromVertices(IEnumerable<Position2D> vertices)
+        {
+            var minX = float.MaxValue;
+            var minY = float.MaxValue;
+            var maxX = float.MinValue;
+            var maxY = float.MinValue;
 
             foreach (var vertex in vertices)
             {
-                var transformedVertex = Vector2.Transform(vertex, transform.TransformMatrix);
-
-                if (transformedVertex.X < minX)
+                if (vertex.X < minX)
                 {
-                    minX = transformedVertex.X;
+                    minX = vertex.X;
                 }
-                if (transformedVertex.Y < minY)
+                if (vertex.Y < minY)
                 {
-                    minY = transformedVertex.Y;
+                    minY = vertex.Y;
                 }
-                if (transformedVertex.X > maxX)
+                if (vertex.X > maxX)
                 {
-                    maxX = transformedVertex.X;
+                    maxX = vertex.X;
                 }
-                if (transformedVertex.Y > maxY)
+                if (vertex.Y > maxY)
                 {
-                    maxY = transformedVertex.Y;
+                    maxY = vertex.Y;
                 }
             }
 
-            return new AABB
-            {
-                MinX = minX,
-                MinY = minY,
-                MaxX = maxX,
-                MaxY = maxY
-            };
+            return new AABB(minX, minY, maxX, maxY);
         }
 
         public override bool Equals(object obj)
@@ -63,23 +87,13 @@ namespace MoonTools.Core.Bonk
 
         public bool Equals(AABB other)
         {
-            return MinX == other.MinX &&
-                   MinY == other.MinY &&
-                   MaxX == other.MaxX &&
-                   MaxY == other.MaxY;
+            return Min == other.Min &&
+                   Max == other.Max;
         }
 
         public override int GetHashCode()
         {
-            return HashCode.Combine(MinX, MinY, MaxX, MaxY);
-        }
-
-        public AABB(float minX, float minY, float maxX, float maxY)
-        {
-            MinX = minX;
-            MinY = minY;
-            MaxX = maxX;
-            MaxY = maxY;
+            return HashCode.Combine(Min, Max);
         }
 
         public static bool operator ==(AABB left, AABB right)
diff --git a/Bonk/BroadPhase/SpatialHash.cs b/Bonk/BroadPhase/SpatialHash.cs
index 05f0099..526cd52 100644
--- a/Bonk/BroadPhase/SpatialHash.cs
+++ b/Bonk/BroadPhase/SpatialHash.cs
@@ -1,5 +1,6 @@
-using System;
+using System;
 using System.Collections.Generic;
+using System.Numerics;
 using MoonTools.Core.Structs;
 
 namespace MoonTools.Core.Bonk
@@ -20,9 +21,9 @@ namespace MoonTools.Core.Bonk
             this.cellSize = cellSize;
         }
 
-        private (int, int) Hash(float x, float y)
+        private (int, int) Hash(Vector2 position)
         {
-            return ((int)Math.Floor(x / cellSize), (int)Math.Floor(y / cellSize));
+            return ((int)Math.Floor(position.X / cellSize), (int)Math.Floor(position.Y / cellSize));
         }
 
         /// <summary>
@@ -33,13 +34,13 @@ namespace MoonTools.Core.Bonk
         /// <param name="transform2D"></param>
         public void Insert(T id, IShape2D shape, Transform2D transform2D)
         {
-            var box = shape.AABB(transform2D);
-            var minHash = Hash(box.MinX, box.MinY);
-            var maxHash = Hash(box.MaxX, box.MaxY);
+            var box = shape.TransformedAABB(transform2D);
+            var minHash = Hash(box.Min);
+            var maxHash = Hash(box.Max);
 
-            for (int i = minHash.Item1; i <= maxHash.Item1; i++)
+            for (var i = minHash.Item1; i <= maxHash.Item1; i++)
             {
-                for (int j = minHash.Item2; j <= maxHash.Item2; j++)
+                for (var j = minHash.Item2; j <= maxHash.Item2; j++)
                 {
                     if (!hashDictionary.ContainsKey(i))
                     {
@@ -62,9 +63,9 @@ namespace MoonTools.Core.Bonk
         /// </summary>
         public IEnumerable<(T, IShape2D, Transform2D)> Retrieve(T id, IShape2D shape, Transform2D transform2D)
         {
-            AABB box = shape.AABB(transform2D);
-            var minHash = Hash(box.MinX, box.MinY);
-            var maxHash = Hash(box.MaxX, box.MaxY);
+            AABB box = shape.TransformedAABB(transform2D);
+            var minHash = Hash(box.Min);
+            var maxHash = Hash(box.Max);
 
             for (int i = minHash.Item1; i <= maxHash.Item1; i++)
             {
@@ -98,4 +99,4 @@ namespace MoonTools.Core.Bonk
             IDLookup.Clear();
         }
     }
-}
\ No newline at end of file
+}
diff --git a/Bonk/IShape2D.cs b/Bonk/IShape2D.cs
index 1a5f80e..05d7616 100644
--- a/Bonk/IShape2D.cs
+++ b/Bonk/IShape2D.cs
@@ -6,6 +6,8 @@ namespace MoonTools.Core.Bonk
 {
     public interface IShape2D : IEquatable<IShape2D>
     {
+        AABB AABB { get; }
+
         /// <summary>
         /// A Minkowski support function. Gives the farthest point on the edge of a shape along the given direction.
         /// </summary>
@@ -19,6 +21,6 @@ namespace MoonTools.Core.Bonk
         /// </summary>
         /// <param name="transform">A Transform for transforming the shape vertices.</param>
         /// <returns>Returns a bounding box based on the shape.</returns>
-        AABB AABB(Transform2D transform);
+        AABB TransformedAABB(Transform2D transform);
     }
 }
diff --git a/Bonk/Shapes/Circle.cs b/Bonk/Shapes/Circle.cs
index 6b2a93b..34aa9e2 100644
--- a/Bonk/Shapes/Circle.cs
+++ b/Bonk/Shapes/Circle.cs
@@ -10,10 +10,12 @@ namespace MoonTools.Core.Bonk
     public struct Circle : IShape2D, IEquatable<Circle>
     {
         public int Radius { get; }
+        public AABB AABB { get; }
 
         public Circle(int radius)
         {
             Radius = radius;
+            AABB = new AABB(-Radius, -Radius, Radius, Radius);
         }
 
         public Vector2 Support(Vector2 direction, Transform2D transform)
@@ -21,14 +23,9 @@ namespace MoonTools.Core.Bonk
             return Vector2.Transform(Vector2.Normalize(direction) * Radius, transform.TransformMatrix);
         }
 
-        public AABB AABB(Transform2D transform2D)
+        public AABB TransformedAABB(Transform2D transform2D)
         {
-            return new AABB(
-                transform2D.Position.X - (Radius * transform2D.Scale.X),
-                transform2D.Position.Y - (Radius * transform2D.Scale.Y),
-                transform2D.Position.X + (Radius * transform2D.Scale.X),
-                transform2D.Position.Y + (Radius * transform2D.Scale.Y)
-            );
+            return AABB.Transformed(AABB, transform2D);
         }
 
         public override bool Equals(object obj)
diff --git a/Bonk/Shapes/Line.cs b/Bonk/Shapes/Line.cs
index 6e9f3eb..7deb40c 100644
--- a/Bonk/Shapes/Line.cs
+++ b/Bonk/Shapes/Line.cs
@@ -10,36 +10,40 @@ namespace MoonTools.Core.Bonk
     /// </summary>
     public struct Line : IShape2D, IEquatable<Line>
     {
-        private Position2D v0;
-        private Position2D v1;
+        private Position2D _v0;
+        private Position2D _v1;
+
+        public AABB AABB { get; }
 
         public IEnumerable<Position2D> Vertices
         {
             get
             {
-                yield return v0;
-                yield return v1;
+                yield return _v0;
+                yield return _v1;
             }
         }
 
         public Line(Position2D start, Position2D end)
         {
-            v0 = start;
-            v1 = end;
+            _v0 = start;
+            _v1 = end;
+
+            AABB = new AABB(Math.Min(_v0.X, _v1.X), Math.Min(_v0.Y, _v1.Y), Math.Max(_v0.X, _v1.X), Math.Max(_v0.Y, _v1.Y));
         }
 
         public Vector2 Support(Vector2 direction, Transform2D transform)
         {
-            var TransformedStart = Vector2.Transform(v0, transform.TransformMatrix);
-            var TransformedEnd = Vector2.Transform(v1, transform.TransformMatrix);
-            return Vector2.Dot(TransformedStart, direction) > Vector2.Dot(TransformedEnd, direction) ?
-                TransformedStart :
-                TransformedEnd;
+            var transformedStart = Vector2.Transform(_v0, transform.TransformMatrix);
+            var transformedEnd = Vector2.Transform(_v1, transform.TransformMatrix);
+            return Vector2.Dot(transformedStart, direction) > Vector2.Dot(transformedEnd, direction) ?
+                transformedStart :
+                transformedEnd;
         }
 
-        public AABB AABB(Transform2D Transform2D)
+        public AABB TransformedAABB(Transform2D transform)
         {
-            return Bonk.AABB.FromTransformedVertices(Vertices, Transform2D);
+            return AABB.Transformed(AABB, transform);
         }
 
         public override bool Equals(object obj)
@@ -54,12 +58,12 @@ namespace MoonTools.Core.Bonk
 
         public bool Equals(Line other)
         {
-            return (v0 == other.v0 && v1 == other.v1) || (v1 == other.v0 && v0 == other.v1);
+            return (_v0 == other._v0 && _v1 == other._v1) || (_v1 == other._v0 && _v0 == other._v1);
         }
 
         public override int GetHashCode()
         {
-            return HashCode.Combine(v0, v1);
+            return HashCode.Combine(_v0, _v1);
         }
 
         public static bool operator ==(Line a, Line b)
diff --git a/Bonk/Shapes/Point.cs b/Bonk/Shapes/Point.cs
index c1fbc0b..eca1eee 100644
--- a/Bonk/Shapes/Point.cs
+++ b/Bonk/Shapes/Point.cs
@@ -1,5 +1,4 @@
-using System;
-using System.Linq;
+using System;
 using System.Numerics;
 using MoonTools.Core.Structs;
 
@@ -7,26 +6,29 @@ namespace MoonTools.Core.Bonk
 {
     public struct Point : IShape2D, IEquatable<Point>
     {
-        private Position2D position;
+        private Position2D _position;
+        public AABB AABB { get; }
 
         public Point(Position2D position)
         {
-            this.position = position;
+            _position = position;
+            AABB = new AABB(position, position);
         }
 
         public Point(int x, int y)
         {
-            this.position = new Position2D(x, y);
+            _position = new Position2D(x, y);
+            AABB = new AABB(x, y, x, y);
         }
 
-        public AABB AABB(Transform2D transform)
+        public AABB TransformedAABB(Transform2D transform)
         {
-            return Bonk.AABB.FromTransformedVertices(Enumerable.Repeat<Position2D>(position, 1), transform);
+            return AABB.Transformed(AABB, transform);
         }
 
         public Vector2 Support(Vector2 direction, Transform2D transform)
         {
-            return Vector2.Transform(position.ToVector2(), transform.TransformMatrix);
+            return Vector2.Transform(_position.ToVector2(), transform.TransformMatrix);
         }
 
         public override bool Equals(object obj)
@@ -41,12 +43,12 @@ namespace MoonTools.Core.Bonk
 
         public bool Equals(Point other)
         {
-            return position == other.position;
+            return _position == other._position;
         }
 
         public override int GetHashCode()
         {
-            return HashCode.Combine(position);
+            return HashCode.Combine(_position);
         }
 
         public static bool operator ==(Point a, Point b)
diff --git a/Bonk/Shapes/Polygon.cs b/Bonk/Shapes/Polygon.cs
index f3aab33..e68663b 100644
--- a/Bonk/Shapes/Polygon.cs
+++ b/Bonk/Shapes/Polygon.cs
@@ -13,27 +13,30 @@ namespace MoonTools.Core.Bonk
     /// </summary>
     public struct Polygon : IShape2D, IEquatable<Polygon>
     {
-        private ImmutableArray<Position2D> vertices;
+        private ImmutableArray<Position2D> _vertices;
+        public AABB AABB { get; }
 
-        public IEnumerable<Position2D> Vertices { get { return vertices; } }
+        public IEnumerable<Position2D> Vertices { get { return _vertices; } }
 
-        public int VertexCount { get { return vertices.Length; } }
+        public int VertexCount { get { return _vertices.Length; } }
 
         // vertices are local to the origin
-        public Polygon(IEnumerable<Position2D> vertices) // TODO: remove this, params is bad because it allocates an array
+        public Polygon(IEnumerable<Position2D> vertices)
         {
-            this.vertices = vertices.ToImmutableArray();
+            _vertices = vertices.ToImmutableArray();
+            AABB = AABB.FromVertices(vertices);
         }
 
         public Polygon(ImmutableArray<Position2D> vertices)
         {
-            this.vertices = vertices;
+            _vertices = vertices;
+            AABB = AABB.FromVertices(vertices);
         }
 
         public Vector2 Support(Vector2 direction, Transform2D transform)
         {
             var maxDotProduct = float.NegativeInfinity;
-            var maxVertex = vertices[0].ToVector2();
+            var maxVertex = _vertices[0].ToVector2();
             foreach (var vertex in Vertices)
             {
                 var transformed = Vector2.Transform(vertex, transform.TransformMatrix);
@@ -47,9 +50,9 @@ namespace MoonTools.Core.Bonk
             return maxVertex;
         }
 
-        public AABB AABB(Transform2D Transform2D)
+        public AABB TransformedAABB(Transform2D transform)
         {
-            return Bonk.AABB.FromTransformedVertices(Vertices, Transform2D);
+            return AABB.Transformed(AABB, transform);
         }
 
         public override bool Equals(object obj)
@@ -64,11 +67,11 @@ namespace MoonTools.Core.Bonk
 
         public bool Equals(Polygon other)
         {
-            var q = from a in vertices
+            var q = from a in _vertices
                     join b in other.Vertices on a equals b
                     select a;
 
-            return vertices.Length == other.VertexCount && q.Count() == vertices.Length;
+            return _vertices.Length == other.VertexCount && q.Count() == _vertices.Length;
         }
 
         public bool Equals(Rectangle rectangle)
diff --git a/Bonk/Shapes/Rectangle.cs b/Bonk/Shapes/Rectangle.cs
index a28ca7b..ebf9a98 100644
--- a/Bonk/Shapes/Rectangle.cs
+++ b/Bonk/Shapes/Rectangle.cs
@@ -1,6 +1,5 @@
 using System;
 using System.Collections.Generic;
-using System.Linq;
 using System.Numerics;
 using MoonTools.Core.Structs;
 
@@ -16,6 +15,8 @@ namespace MoonTools.Core.Bonk
         public int MaxX { get; }
         public int MaxY { get; }
 
+        public AABB AABB { get; }
+
         public IEnumerable<Position2D> Vertices
         {
             get
@@ -33,6 +34,8 @@ namespace MoonTools.Core.Bonk
             MinY = minY;
             MaxX = maxX;
             MaxY = maxY;
+
+            AABB = new AABB(minX, minY, maxX, maxY);
         }
 
         public Vector2 Support(Vector2 direction, Transform2D transform)
@@ -52,9 +55,9 @@ namespace MoonTools.Core.Bonk
             return maxVertex;
         }
 
-        public AABB AABB(Transform2D Transform2D)
+        public AABB TransformedAABB(Transform2D transform)
         {
-            return Bonk.AABB.FromTransformedVertices(Vertices, Transform2D);
+            return AABB.Transformed(AABB, transform);
         }
 
         public override bool Equals(object obj)
diff --git a/Bonk/Shapes/Simplex.cs b/Bonk/Shapes/Simplex.cs
index b504079..e30c1fa 100644
--- a/Bonk/Shapes/Simplex.cs
+++ b/Bonk/Shapes/Simplex.cs
@@ -1,4 +1,4 @@
-using System.Linq;
+using System.Linq;
 using System.Collections.Generic;
 using System.Numerics;
 using MoonTools.Core.Structs;
@@ -9,7 +9,7 @@ namespace MoonTools.Core.Bonk
     /// <summary>
     /// A simplex is a shape with up to n - 2 vertices in the nth dimension.
     /// </summary>
-    public struct Simplex2D : IShape2D, IEquatable<Simplex2D>
+    public struct Simplex2D : IEquatable<Simplex2D>
     {
         private Vector2 a;
         private Vector2? b;
@@ -56,11 +56,6 @@ namespace MoonTools.Core.Bonk
             }
         }
 
-        public AABB AABB(Transform2D transform)
-        {
-            return Bonk.AABB.FromTransformedVertices(Vertices, transform);
-        }
-
         public Vector2 Support(Vector2 direction, Transform2D transform)
         {
             var maxDotProduct = float.NegativeInfinity;
@@ -80,12 +75,7 @@ namespace MoonTools.Core.Bonk
 
         public override bool Equals(object obj)
         {
-            return obj is IShape2D other && Equals(other);
-        }
-
-        public bool Equals(IShape2D other)
-        {
-            return other is Simplex2D otherSimplex && Equals(otherSimplex);
+            return obj is Simplex2D other && Equals(other);
         }
 
         public bool Equals(Simplex2D other)
diff --git a/Test/Equality.cs b/Test/Equality.cs
index 75af6ff..15fe666 100644
--- a/Test/Equality.cs
+++ b/Test/Equality.cs
@@ -1,4 +1,4 @@
-using NUnit.Framework;
+using NUnit.Framework;
 using FluentAssertions;
 
 using MoonTools.Core.Bonk;
@@ -18,7 +18,7 @@ namespace Tests
                 var a = new Point(1, 1);
                 var b = new Point(1, 1);
 
-                a.Should().BeEquivalentTo(b);
+                a.Equals(b).Should().BeTrue();
             }
 
             [Test]
@@ -27,7 +27,7 @@ namespace Tests
                 var a = new Point(1, 1);
                 var b = new Point(-1, 1);
 
-                a.Should().NotBeEquivalentTo(b);
+                a.Equals(b).Should().BeFalse();
             }
 
             [Test]
@@ -305,7 +305,7 @@ namespace Tests
 
                 var b = new Rectangle(-1, -1, 1, 1);
 
-                a.Should().BeEquivalentTo(b);
+                a.Equals(b).Should().BeTrue();
             }
 
             [Test]
@@ -320,7 +320,7 @@ namespace Tests
 
                 var b = new Rectangle(-1, -1, 1, 1);
 
-                a.Should().NotBeEquivalentTo(b);
+                a.Equals(b).Should().BeFalse();
             }
 
             [Test]
@@ -538,4 +538,4 @@ namespace Tests
             }
         }
     }
-}
\ No newline at end of file
+}