未验证 提交 6afe03ef 编写于 作者: I Ivan Zlatanov 提交者: GitHub

Vector.Sum(Vector<T>) API implementation for horizontal add. (#53527)

* Vector.Sum(Vector<T>) API implementation for horizontal add.

* Fixed inccorrect referece to Arm64 AddAccross intrinsic function.

* Added implementation for hardware accelerated Vector<T>.Sum for long, ulong, float, double on ARM64.

* Fixed formatting issue.

* Correctness.

* Fixed compiler error for ARM64.

* Formatting issue.

* More explicit switch statement. Fixed wrong simd size for NI_Vector64_ToScalar.

* Fixed auto formatting issue.

* Use AddPairwiseScalar for double, long and ulong on ARM64 for VectorT128_Sum.

* Forgot ToScalar call after AddPairwiseScalar.

* Fixed wrong return type.
上级 b074b8db
......@@ -719,12 +719,118 @@ GenTree* Compiler::impSimdAsHWIntrinsicSpecial(NamedIntrinsic intrinsic,
}
break;
}
case NI_VectorT128_Sum:
{
if (compOpportunisticallyDependsOn(InstructionSet_SSSE3))
{
GenTree* tmp;
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength);
for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, NI_SSSE3_HorizontalAdd,
simdBaseJitType, simdSize);
}
return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType,
simdSize);
}
return nullptr;
}
case NI_VectorT256_Sum:
{
// HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together.
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength) - 1; // Minus 1 because for the last pass we split the vector
// to low / high and add them together.
GenTree* tmp;
NamedIntrinsic horizontalAdd = NI_AVX2_HorizontalAdd;
NamedIntrinsic add = NI_SSE2_Add;
if (simdBaseType == TYP_DOUBLE)
{
horizontalAdd = NI_AVX_HorizontalAdd;
}
else if (simdBaseType == TYP_FLOAT)
{
horizontalAdd = NI_AVX_HorizontalAdd;
add = NI_SSE_Add;
}
for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, horizontalAdd, simdBaseJitType, simdSize);
}
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, gtNewIconNode(0x01, TYP_INT),
NI_AVX_ExtractVector128, simdBaseJitType, simdSize);
tmp = gtNewSimdAsHWIntrinsicNode(simdType, tmp, NI_Vector256_GetLower, simdBaseJitType, simdSize);
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD16, op1, tmp, add, simdBaseJitType, 16);
return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType, 16);
}
#elif defined(TARGET_ARM64)
case NI_VectorT128_Abs:
{
assert(varTypeIsUnsigned(simdBaseType));
return op1;
}
case NI_VectorT128_Sum:
{
GenTree* tmp;
switch (simdBaseType)
{
case TYP_BYTE:
case TYP_UBYTE:
case TYP_SHORT:
case TYP_USHORT:
case TYP_INT:
case TYP_UINT:
{
tmp = gtNewSimdAsHWIntrinsicNode(simdType, op1, NI_AdvSimd_Arm64_AddAcross, simdBaseJitType,
simdSize);
return gtNewSimdAsHWIntrinsicNode(retType, tmp, NI_Vector64_ToScalar, simdBaseJitType, 8);
}
case TYP_FLOAT:
{
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
int haddCount = genLog2(vectorLength);
for (int i = 0; i < haddCount; i++)
{
op1 = impCloneExpr(op1, &tmp, clsHnd, (unsigned)CHECK_SPILL_ALL,
nullptr DEBUGARG("Clone op1 for Vector<T>.Sum"));
op1 = gtNewSimdAsHWIntrinsicNode(simdType, op1, tmp, NI_AdvSimd_Arm64_AddPairwise,
simdBaseJitType, simdSize);
}
return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector128_ToScalar, simdBaseJitType,
simdSize);
}
case TYP_DOUBLE:
case TYP_LONG:
case TYP_ULONG:
{
op1 = gtNewSimdAsHWIntrinsicNode(TYP_SIMD8, op1, NI_AdvSimd_Arm64_AddPairwiseScalar,
simdBaseJitType, simdSize);
return gtNewSimdAsHWIntrinsicNode(retType, op1, NI_Vector64_ToScalar, simdBaseJitType, 8);
}
default:
{
unreached();
}
}
}
#else
#error Unsupported platform
#endif // !TARGET_XARCH && !TARGET_ARM64
......
......@@ -132,6 +132,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Inequality,
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Multiply, 2, {NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_Illegal, NI_Illegal, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Subtraction, 2, {NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Subtract, NI_AdvSimd_Arm64_Subtract}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, SquareRoot, 1, {NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_AdvSimd_Arm64_Sqrt, NI_AdvSimd_Arm64_Sqrt}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, Sum, 1, {NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum}, SimdAsHWIntrinsicFlag::None)
#undef SIMD_AS_HWINTRINSIC_NM
#undef SIMD_AS_HWINTRINSIC_ID
......
......@@ -132,6 +132,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Inequality,
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Multiply, 2, {NI_Illegal, NI_Illegal, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply, NI_Illegal, NI_Illegal, NI_VectorT128_op_Multiply, NI_VectorT128_op_Multiply}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, op_Subtraction, 2, {NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE2_Subtract, NI_SSE_Subtract, NI_SSE2_Subtract}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, SquareRoot, 1, {NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_SSE_Sqrt, NI_SSE2_Sqrt}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT128, Sum, 1, {NI_Illegal, NI_Illegal, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_VectorT128_Sum, NI_Illegal, NI_Illegal, NI_VectorT128_Sum, NI_VectorT128_Sum}, SimdAsHWIntrinsicFlag::None)
// *************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
// ISA ID Name NumArg Instructions Flags
......@@ -170,6 +171,7 @@ SIMD_AS_HWINTRINSIC_ID(VectorT256, op_Inequality,
SIMD_AS_HWINTRINSIC_ID(VectorT256, op_Multiply, 2, {NI_Illegal, NI_Illegal, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply, NI_Illegal, NI_Illegal, NI_VectorT256_op_Multiply, NI_VectorT256_op_Multiply}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT256, op_Subtraction, 2, {NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX2_Subtract, NI_AVX_Subtract, NI_AVX_Subtract}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT256, SquareRoot, 1, {NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_Illegal, NI_AVX_Sqrt, NI_AVX_Sqrt}, SimdAsHWIntrinsicFlag::None)
SIMD_AS_HWINTRINSIC_ID(VectorT256, Sum, 1, {NI_Illegal, NI_Illegal, NI_VectorT256_Sum, NI_VectorT256_Sum, NI_VectorT256_Sum, NI_VectorT256_Sum, NI_Illegal, NI_Illegal, NI_VectorT256_Sum, NI_VectorT256_Sum}, SimdAsHWIntrinsicFlag::None)
#undef SIMD_AS_HWINTRINSIC_NM
#undef SIMD_AS_HWINTRINSIC_ID
......
......@@ -300,6 +300,7 @@ public static partial class Vector
[System.CLSCompliantAttribute(false)]
public static void Widen(System.Numerics.Vector<System.UInt32> source, out System.Numerics.Vector<System.UInt64> low, out System.Numerics.Vector<System.UInt64> high) { throw null; }
public static System.Numerics.Vector<T> Xor<T>(System.Numerics.Vector<T> left, System.Numerics.Vector<T> right) where T : struct { throw null; }
public static T Sum<T>(System.Numerics.Vector<T> value) where T : struct { throw null; }
}
public partial struct Vector2 : System.IEquatable<System.Numerics.Vector2>, System.IFormattable
{
......
......@@ -3137,6 +3137,49 @@ public void NarrowDouble()
}
#endregion
#region Sum
[Fact]
public void SumInt32() => TestSum<int>(x => x.Aggregate((a, b) => a + b));
[Fact]
public void SumInt64() => TestSum<long>(x => x.Aggregate((a, b) => a + b));
[Fact]
public void SumSingle() => TestSum<float>(x => x.Aggregate((a, b) => a + b));
[Fact]
public void SumDouble() => TestSum<double>(x => x.Aggregate((a, b) => a + b));
[Fact]
public void SumUInt32() => TestSum<uint>(x => x.Aggregate((a, b) => a + b));
[Fact]
public void SumUInt64() => TestSum<ulong>(x => x.Aggregate((a, b) => a + b));
[Fact]
public void SumByte() => TestSum<byte>(x => x.Aggregate((a, b) => (byte)(a + b)));
[Fact]
public void SumSByte() => TestSum<sbyte>(x => x.Aggregate((a, b) => (sbyte)(a + b)));
[Fact]
public void SumInt16() => TestSum<short>(x => x.Aggregate((a, b) => (short)(a + b)));
[Fact]
public void SumUInt16() => TestSum<ushort>(x => x.Aggregate((a, b) => (ushort)(a + b)));
private static void TestSum<T>(Func<T[], T> expected) where T : struct, IEquatable<T>
{
T[] values = GenerateRandomValuesForVector<T>();
Vector<T> vector = new(values);
T sum = Vector.Sum(vector);
AssertEqual(expected(values), sum, "Sum");
}
#endregion
#region Helper Methods
private static void AssertEqual<T>(T expected, T actual, string operation, int precision = -1) where T : IEquatable<T>
{
......
......@@ -1292,5 +1292,14 @@ internal static void ThrowInsufficientNumberOfElementsException(int requiredElem
return Unsafe.As<Vector<TFrom>, Vector<TTo>>(ref vector);
}
/// <summary>
/// Returns the sum of all elements inside the vector.
/// </summary>
[Intrinsic]
public static T Sum<T>(Vector<T> value) where T : struct
{
return Vector<T>.Sum(value);
}
}
}
......@@ -822,6 +822,19 @@ internal static T Dot(Vector<T> left, Vector<T> right)
return product;
}
[Intrinsic]
internal static T Sum(Vector<T> value)
{
T sum = default;
for (nint index = 0; index < Count; index++)
{
sum = ScalarAdd(sum, value.GetElement(index));
}
return sum;
}
[Intrinsic]
internal static unsafe Vector<T> SquareRoot(Vector<T> value)
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册