Commit e5ce9659 authored by tensor-tang

refine and add eltadd_relu unit test

Parent 7cb19a59
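For context: the eltadd_relu (VAddRelu) kernel exercised by the new unit test fuses an element-wise add with a ReLU, i.e. z[i] = max(x[i] + y[i], 0). A minimal plain-C++ sketch of these reference semantics (illustrative only, mirroring the vaddrelu_ref helper added below, not the JIT-generated code path):

#include <algorithm>
#include <cstdio>

// Reference semantics of the fused kernel: z[i] = max(x[i] + y[i], 0).
void vaddrelu_reference(int n, const float* x, const float* y, float* z) {
  for (int i = 0; i < n; ++i) {
    z[i] = std::max(x[i] + y[i], 0.0f);
  }
}

int main() {
  const float x[4] = {1.f, -2.f, 3.f, -4.f};
  const float y[4] = {0.5f, 1.f, -5.f, -1.f};
  float z[4];
  vaddrelu_reference(4, x, y, z);
  for (float v : z) std::printf("%g ", v);  // expected: 1.5 0 0 0
  return 0;
}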
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"  // TODO(TJ): add deps
+#include "paddle/fluid/operators/math/jit_kernel.h"

 DECLARE_int32(paddle_num_threads);
......
@@ -447,20 +447,17 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
 #ifdef __AVX__
 INTRI8_FLOAT(jit::avx);
 INTRI16_FLOAT(jit::avx);
-INTRI_COMMON_FLOAT(jit::avx, kGT8LT16);
 INTRI_COMMON_FLOAT(jit::avx, kGT16);
 #endif
 #ifdef __AVX2__
 INTRI8_FLOAT(jit::avx2);
 INTRI16_FLOAT(jit::avx2);
-INTRI_COMMON_FLOAT(jit::avx2, kGT8LT16);
 INTRI_COMMON_FLOAT(jit::avx2, kGT16);
 #endif
 #ifdef __AVX512F__
 // TODO(TJ): refine avx512
 INTRI8_FLOAT(jit::avx512f);
 INTRI16_FLOAT(jit::avx512f);
-INTRI_COMMON_FLOAT(jit::avx512f, kGT8LT16);
 INTRI_COMMON_FLOAT(jit::avx512f, kGT16);
 #endif
......
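The removed INTRI_COMMON_FLOAT(..., kGT8LT16) instantiations drop the dedicated VAddRelu registration for vector lengths greater than 8 and less than 16; the fixed-size 8 and 16 specializations and the kGT16 common path remain. For readers unfamiliar with these macros, below is a rough sketch of what a fixed-size-8 AVX add+relu computes, written with standard intrinsics; it is an illustration only, not Paddle's actual macro expansion.

#include <immintrin.h>

// Illustrative size-8 AVX add+relu, similar in spirit to what
// INTRI8_FLOAT(jit::avx) specializes. Assumes x, y, z each point to 8 floats.
// Build with -mavx (or a compiler flag enabling AVX).
inline void vaddrelu8_avx(const float* x, const float* y, float* z) {
  __m256 vx = _mm256_loadu_ps(x);
  __m256 vy = _mm256_loadu_ps(y);
  __m256 sum = _mm256_add_ps(vx, vy);
  __m256 zero = _mm256_setzero_ps();
  _mm256_storeu_ps(z, _mm256_max_ps(sum, zero));  // relu = max(v, 0)
}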
@@ -712,6 +712,63 @@ TEST(JitKernel, vadd) {
   }
 }
+
+void vaddrelu_ref(const int n, const float* x, const float* y, float* z) {
+  for (int i = 0; i < n; ++i) {
+    z[i] = x[i] + y[i];
+    z[i] = z[i] > 0 ? z[i] : 0;
+  }
+}
+
+void vaddrelu_better(
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
+    const std::shared_ptr<
+        const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
+    const float* x, const float* y, float* z) {
+  vadd->Compute(x, y, z);
+  vrelu->Compute(z, z);
+}
+
+TEST(JitKernel, vaddrelu) {
+  namespace jit = paddle::operators::math::jitkernel;
+  for (int d : {7, 8, 15, 16, 30, 256, 512}) {
+    std::vector<float> x(d), y(d);
+    std::vector<float> zref(d), ztgt(d);
+    RandomVec<float>(d, x.data());
+    RandomVec<float>(d, y.data());
+    const auto& ker =
+        jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(d);
+    const auto& vadd =
+        jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
+    const auto& vrelu =
+        jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
+    const float* x_data = x.data();
+    const float* y_data = y.data();
+    float* ztgt_data = ztgt.data();
+    float* zref_data = zref.data();
+    auto trefs = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vaddrelu_ref(d, x_data, y_data, zref_data);
+    }
+    auto trefe = GetCurrentUS();
+    auto tmkls = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data);
+    }
+    auto tmkle = GetCurrentUS();
+    auto ttgts = GetCurrentUS();
+    for (int i = 0; i < repeat; ++i) {
+      ker->Compute(x_data, y_data, ztgt_data);
+    }
+    auto ttgte = GetCurrentUS();
+    VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
+            << " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
+            << "tgt takes: " << (ttgte - ttgts) / repeat;
+    for (int i = 0; i < d; ++i) {
+      EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
+    }
+  }
+}

 TEST(JitKernel, pool) {
   namespace jit = paddle::operators::math::jitkernel;
   const int frame_size = 4;
......
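Outside the test, the fused kernel is obtained and invoked the same way. As a hedged illustration, the helper below applies a bias add followed by ReLU to each row of a matrix, using only the calls that appear in the diff (KernelPool::Instance().Get<VAddReluKernel<float>>(d) and Compute(x, y, z)); the helper name, signature, and row loop are assumptions made for this sketch, and it presumes VAddReluKernel is declared in the jit_kernel.h header included above.

#include "paddle/fluid/operators/math/jit_kernel.h"

// Hypothetical helper: out_row = relu(bias + in_row) for each of `rows` rows
// of length `cols`, computed with the fused VAddRelu kernel.
void bias_add_relu_rows(int rows, int cols, const float* bias, const float* in,
                        float* out) {
  namespace jit = paddle::operators::math::jitkernel;
  const auto& vaddrelu =
      jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(cols);
  for (int r = 0; r < rows; ++r) {
    // Fused element-wise add + ReLU over one row.
    vaddrelu->Compute(bias, in + r * cols, out + r * cols);
  }
}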