From c50858ee1359d8c548ea881ec4f079cf4ae0726a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Oct 2021 15:06:44 +0800 Subject: [PATCH] fix(dnn): specialize pow to make it consistent GitOrigin-RevId: cff3bbbadd28297e6704e19d27b1c1882bd88a66 --- dnn/src/common/elemwise/kern_defs.cuh | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index bde0341de..643d2a6cd 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -26,6 +26,8 @@ #include using std::max; using std::min; + +#define rsqrtf(x) (1.f / sqrt(x)) #endif #ifndef MEGDNN_ELEMWISE_MODE_ENABLE @@ -93,6 +95,30 @@ __device__ __host__ inline float gelu_grad(float x, float dy) { return dy * (normcdf_v + x * phi); } +__device__ __host__ inline bool feq(float a, float b) { + return fabsf(a - b) < 1e-6; +} + +__device__ __host__ inline float dispatch_powf(float x, float y) { +#define CALL_IF(_v, _stmt) \ + do { \ + if (feq(y, _v)) { \ + return _stmt; \ + } \ + } while (0) + + CALL_IF(2.f, x * x); + CALL_IF(0.5f, sqrtf(x)); + CALL_IF(-0.5f, rsqrtf(x)); + CALL_IF(0.f, 1.f); + CALL_IF(1.f, x); + CALL_IF(3.f, x * x * x); + CALL_IF(-1.f, 1.f / x); + CALL_IF(-2.f, 1.f / (x * x)); +#undef CALL_IF + return powf(x, y); +} + #include "src/common/elemwise/each_mode.inl" template -- GitLab