CyclicBuffers SIMD

This commit is contained in:
wreng 2024-02-21 16:06:00 +03:00 committed by Ikpil
parent fa837a84ed
commit c47cc79552
3 changed files with 128 additions and 26 deletions

View File

@ -262,6 +262,11 @@ namespace DotRecast.Core.Buffers
return new Span<T>(_buffer, 0, _end); return new Span<T>(_buffer, 0, _end);
} }
internal ReadOnlySpan<T> GetBufferSpan()
{
return _buffer;
}
public Enumerator GetEnumerator() => new Enumerator(this); public Enumerator GetEnumerator() => new Enumerator(this);
IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator(); IEnumerator<T> IEnumerable<T>.GetEnumerator() => GetEnumerator();

View File

@ -1,20 +1,31 @@
namespace DotRecast.Core.Buffers using System;
using System.Numerics;
using System.Runtime.InteropServices;
namespace DotRecast.Core.Buffers
{ {
public static class RcCyclicBuffers public static class RcCyclicBuffers
{ {
public static long Sum(this RcCyclicBuffer<long> source) public static long Sum(this RcCyclicBuffer<long> source)
{ {
long sum = 0; var buffer = source.GetBufferSpan();
checked var result = 0L;
if (Vector.IsHardwareAccelerated)
{ {
// NOTE: SIMD would be nice here var vectors = MemoryMarshal.Cast<long, Vector<long>>(buffer);
foreach (var x in source) var vecSum = Vector<long>.Zero;
{ foreach (var vec in vectors)
sum += x; vecSum += vec;
}
}
return sum; result = Vector.Dot(vecSum, Vector<long>.One);
var remainder = source.Size % Vector<long>.Count;
buffer = buffer[^remainder..];
}
foreach (var val in buffer)
result += val;
return result;
} }
public static double Average(this RcCyclicBuffer<long> source) public static double Average(this RcCyclicBuffer<long> source)
@ -27,32 +38,54 @@
public static long Min(this RcCyclicBuffer<long> source) public static long Min(this RcCyclicBuffer<long> source)
{ {
if (0 >= source.Size) var buffer = source.GetBufferSpan();
return 0; var result = long.MaxValue;
long minValue = long.MaxValue; if (Vector.IsHardwareAccelerated)
foreach (var x in source)
{ {
if (x < minValue) var vectors = MemoryMarshal.Cast<long, Vector<long>>(buffer);
minValue = x; var vecMin = Vector<long>.One * result;
}
foreach (var vec in vectors)
vecMin = Vector.Min(vecMin, vec);
return minValue; for (int i = 0; i < Vector<long>.Count; i++)
result = Math.Min(result, vecMin[i]);
var remainder = source.Size % Vector<long>.Count;
buffer = buffer[^remainder..];
}
foreach (var val in buffer)
result = Math.Min(result, val);
return result;
} }
public static long Max(this RcCyclicBuffer<long> source) public static long Max(this RcCyclicBuffer<long> source)
{ {
if (0 >= source.Size) var buffer = source.GetBufferSpan();
return 0; var result = long.MinValue;
long maxValue = long.MinValue; if (Vector.IsHardwareAccelerated)
foreach (var x in source)
{ {
if (x > maxValue) var vectors = MemoryMarshal.Cast<long, Vector<long>>(buffer);
maxValue = x; var vecMax = Vector<long>.One * result;
}
foreach (var vec in vectors)
vecMax = Vector.Max(vecMax, vec);
return maxValue; for (int i = 0; i < Vector<long>.Count; i++)
result = Math.Max(result, vecMax[i]);
var remainder = source.Size % Vector<long>.Count;
buffer = buffer[^remainder..];
}
foreach (var val in buffer)
result = Math.Max(result, val);
return result;
} }
} }
} }

View File

@ -330,4 +330,68 @@ public class RcCyclicBufferTests
Assert.That(enumerator.Current, Is.EqualTo(refValues[index++])); Assert.That(enumerator.Current, Is.EqualTo(refValues[index++]));
} }
} }
[Test]
public void RcCyclicBuffers_Sum()
{
var refValues = Enumerable.Range(-100, 211).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Sum(buffer), Is.EqualTo(refValues.Sum()));
}
[Test]
public void RcCyclicBuffers_Average()
{
var refValues = Enumerable.Range(-100, 211).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Average(buffer), Is.EqualTo(refValues.Average()));
}
[Test]
public void RcCyclicBuffers_Min()
{
var refValues = Enumerable.Range(-100, 211).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Min(buffer), Is.EqualTo(refValues.Min()));
}
[Test]
public void RcCyclicBuffers_Max()
{
var refValues = Enumerable.Range(-100, 211).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Max(buffer), Is.EqualTo(refValues.Max()));
}
[Test]
public void RcCyclicBuffers_SumUnaligned()
{
var refValues = Enumerable.Range(-1, 3).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Sum(buffer), Is.EqualTo(refValues.Sum()));
}
[Test]
public void RcCyclicBuffers_AverageUnaligned()
{
var refValues = Enumerable.Range(-1, 3).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Average(buffer), Is.EqualTo(refValues.Average()));
}
[Test]
public void RcCyclicBuffers_MinUnaligned()
{
var refValues = Enumerable.Range(5, 3).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Min(buffer), Is.EqualTo(refValues.Min()));
}
[Test]
public void RcCyclicBuffers_MaxUnaligned()
{
var refValues = Enumerable.Range(-5, 3).Select(x => (long)x).ToArray();
var buffer = new RcCyclicBuffer<long>(refValues.Length, refValues);
Assert.That(RcCyclicBuffers.Max(buffer), Is.EqualTo(refValues.Max()));
}
} }