提交 c50858ee 编写于 作者: M Megvii Engine Team

fix(dnn): specialize pow to make it consistent

GitOrigin-RevId: cff3bbbadd28297e6704e19d27b1c1882bd88a66
上级 d898838e
......@@ -26,6 +26,8 @@
#include <algorithm>
using std::max;
using std::min;
#define rsqrtf(x) (1.f / sqrt(x))
#endif
#ifndef MEGDNN_ELEMWISE_MODE_ENABLE
......@@ -93,6 +95,30 @@ __device__ __host__ inline float gelu_grad(float x, float dy) {
return dy * (normcdf_v + x * phi);
}
__device__ __host__ inline bool feq(float a, float b) {
return fabsf(a - b) < 1e-6;
}
__device__ __host__ inline float dispatch_powf(float x, float y) {
#define CALL_IF(_v, _stmt) \
do { \
if (feq(y, _v)) { \
return _stmt; \
} \
} while (0)
CALL_IF(2.f, x * x);
CALL_IF(0.5f, sqrtf(x));
CALL_IF(-0.5f, rsqrtf(x));
CALL_IF(0.f, 1.f);
CALL_IF(1.f, x);
CALL_IF(3.f, x * x * x);
CALL_IF(-1.f, 1.f / x);
CALL_IF(-2.f, 1.f / (x * x));
#undef CALL_IF
return powf(x, y);
}
#include "src/common/elemwise/each_mode.inl"
template <megcorePlatform_t plat, uint32_t mode, typename dtype>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册