#include "./opr_impl.h"

#include "src/naive/handle.h"
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "src/arm_common/simd_macro/marm_neon.h"
#endif

#include <cmath>
#include <cstring>
#include <limits>

using namespace megdnn;
using namespace fallback;

namespace {

/*!
 * \brief x^exp for a compile-time integer exponent.
 *
 * Exponents 0..4 are explicit specializations implemented with plain
 * multiplications. Any other exponent uses the primary template (defined
 * below), which only accepts negative values and computes (1/x)^(-exp).
 */
template <int exp>
struct powci;

template <>
struct powci<0> {
    template <typename T>
    static T apply(T) {
        return static_cast<T>(1);
    }
};

template <>
struct powci<1> {
    template <typename T>
    static T apply(T x) {
        return x;
    }
};

template <>
struct powci<2> {
    template <typename T>
    static T apply(T x) {
        return x * x;
    }
};

template <>
struct powci<3> {
    template <typename T>
    static T apply(T x) {
        return x * x * x;
    }
};

template <>
struct powci<4> {
    template <typename T>
    static T apply(T x) {
        // two squarings instead of three multiplications
        x = x * x;
        return x * x;
    }
};

template <int exp>
struct powci {
    static_assert(exp < 0, "bad arg");
    template <typename T>
    static T apply(T x) {
        // x^(-n) == (1/x)^n; reuse the positive specialization
        return powci<-exp>::apply(static_cast<T>(1) / x);
    }
};

/*!
 * \brief Runtime even integer exponent: returns |x|^exp via std::pow.
 */
struct powci_general_even {
    int exp;
    explicit powci_general_even(int e) : exp{e} {}
    template <typename T>
    T apply(T x) {
        return static_cast<T>(std::pow(std::abs(x), static_cast<T>(exp)));
    }
};

/*!
 * \brief Unsigned integer type with the same size as a floating-point type,
 *      plus a mask selecting its top (sign) bit.
 *
 * Only 2-byte (when float16 is enabled) and 4-byte types are provided.
 */
template <size_t size>
struct float_itype;
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct float_itype<2> {
    using type = uint16_t;
    static constexpr uint16_t mask = 1u << 15;
};
#endif
template <>
struct float_itype<4> {
    using type = uint32_t;
    static constexpr uint32_t mask = 1u << 31;
};

/*!
 * \brief Runtime odd integer exponent: computes |x|^exp via std::pow and
 *      then ORs the sign bit of x into the result's bit pattern.
 *
 * The bit patterns are copied with std::memcpy. Reading an inactive union
 * member, which the previous version did, is undefined behaviour in C++.
 */
struct powci_general_odd {
    int exp;
    explicit powci_general_odd(int e) : exp{e} {}
    template <typename T>
    T apply(T x) {
        using itype = typename float_itype<sizeof(T)>::type;
        static_assert(
                sizeof(itype) == sizeof(T),
                "integer proxy must match the float type's size");
        T magnitude = static_cast<T>(std::pow(std::abs(x), static_cast<T>(exp)));
        itype mag_bits;
        itype x_bits;
        std::memcpy(&mag_bits, &magnitude, sizeof(T));
        std::memcpy(&x_bits, &x, sizeof(T));
        // the magnitude's sign bit is clear, so OR transfers x's sign bit
        mag_bits |= x_bits & float_itype<sizeof(T)>::mask;
        T ret;
        std::memcpy(&ret, &mag_bits, sizeof(T));
        return ret;
    }
};

//! exponent 0.5: std::sqrt(x)
struct powcf_sqrt {
    template <typename T>
    static T apply(T x) {
        return static_cast<T>(std::sqrt(x));
    }
};

//! exponent 1/3: std::cbrt(x)
struct powcf_cbrt {
    template <typename T>
    static T apply(T x) {
        return static_cast<T>(std::cbrt(x));
    }
};

//! exponent -0.5: std::sqrt(1 / x)
struct powcf_rep_sqrt {
    template <typename T>
    static T apply(T x) {
        return static_cast<T>(std::sqrt(static_cast<T>(1) / x));
    }
};

//! exponent -1/3: std::cbrt(1 / x)
struct powcf_rep_cbrt {
    template <typename T>
    static T apply(T x) {
        return static_cast<T>(std::cbrt(static_cast<T>(1) / x));
    }
};

/*!
 * \brief Any other float exponent: returns std::pow(|x|, exp).
 *
 * NOTE(review): the absolute value means a negative base with a fractional
 * exponent yields |x|^exp rather than NaN; confirm this matches the naive
 * reference implementation.
 */
template <typename T>
struct powcf_general {
    float exp;
    explicit powcf_general(float e) : exp{e} {}
    T apply(T x) { return static_cast<T>(std::pow(std::abs(x), exp)); }
};

/*!
 * \brief Apply expfunc elementwise: dst[i] = expfunc.apply(src[i]) for
 *      i in [0, size).
 *
 * The main loop is manually unrolled by 4; a scalar loop handles the
 * remaining 0-3 elements.
 */
template <typename T, typename ExpFunc>
void pow_invoke(const T* src, T* dst, size_t size, ExpFunc expfunc) {
    size_t i;
    for (i = 0; i + 4 <= size; i += 4) {
        T a0 = src[i], a1 = src[i + 1], a2 = src[i + 2], a3 = src[i + 3];
        T b0 = expfunc.apply(a0), b1 = expfunc.apply(a1), b2 = expfunc.apply(a2),
          b3 = expfunc.apply(a3);
        dst[i] = b0;
        dst[i + 1] = b1;
        dst[i + 2] = b2;
        dst[i + 3] = b3;
    }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
    for (; i < size; ++i) {
        dst[i] = expfunc.apply(src[i]);
    }
}

//! true if x and y differ by less than float epsilon (absolute tolerance)
bool float_eq(float x, float y) {
    return std::abs(x - y) < std::numeric_limits<float>::epsilon();
}

}  // anonymous namespace

/*!
 * \brief Run PowC on contiguous data of ctype T by dispatching a single
 *      kernel through the naive handle.
 *
 * If exp_f is non-null, the float exponent takes precedence. Values matching
 * 0.5, 1/3, -0.5 or -1/3 (via float_eq) use the sqrt/cbrt kernels; any other
 * value uses powcf_general. Otherwise *exp_i is used: [-4, 4] maps to the
 * multiplication-based powci<N>, and other values use std::pow, with the
 * result's sign chosen by exponent parity.
 *
 * NOTE(review): exp_i is dereferenced without a null check when exp_f is
 * null; presumably the caller guarantees one of them is set. Confirm.
 */
template <typename T>
void PowCImpl::do_exec_ct(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f,
        const int* exp_i) {
    auto handle = static_cast<naive::HandleImpl*>(this->handle());
    auto size = src.layout.total_nr_elems();
    // Build a kernel that captures the tensors and the functor by value,
    // dispatch it, and return from do_exec_ct.
#define CALL(_expfunc)                                              \
    do {                                                            \
        auto kern = [src, dst, size, expfunc = _expfunc]() {        \
            auto sptr = reinterpret_cast<const T*>(src.raw_ptr());  \
            auto dptr = reinterpret_cast<T*>(dst.raw_ptr());        \
            pow_invoke(sptr, dptr, size, expfunc);                  \
        };                                                          \
        handle->dispatch_kern(kern);                                \
        return;                                                     \
    } while (0)

    if (exp_f) {
        float fv = *exp_f;
#define CALL_IF(_v, _expfunc)   \
    if (float_eq(fv, _v)) {     \
        CALL(_expfunc);         \
        return;                 \
    }
        constexpr float croot = 1.f / 3.f;
        CALL_IF(.5f, powcf_sqrt{});
        CALL_IF(croot, powcf_cbrt{});
        CALL_IF(-.5f, powcf_rep_sqrt{});
        CALL_IF(-croot, powcf_rep_cbrt{});
        CALL(powcf_general<T>{fv});
#undef CALL_IF
    }

    int iv = *exp_i;
    switch (iv) {
#define CASE(n)             \
    case n:                 \
        CALL(powci<n>{});   \
        return
        CASE(0);
        CASE(1);
        CASE(2);
        CASE(3);
        CASE(4);
        CASE(-1);
        CASE(-2);
        CASE(-3);
        CASE(-4);
#undef CASE
    }

    if (iv & 1) {
        CALL(powci_general_odd{iv});
    } else {
        CALL(powci_general_even{iv});
    }
#undef CALL
}

/*!
 * \brief Entry point. Non-contiguous inputs are delegated to
 *      naive::PowCImpl. Contiguous Float32 and (unless disabled) Float16
 *      inputs go to do_exec_ct. Any other dtype throws.
 *
 * For Float16, the native __fp16 type is used when
 * __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is set; otherwise dt_float16 is used.
 */
void PowCImpl::do_exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f,
        const int* exp_i) {
    if (!src.layout.is_contiguous()) {
        naive::PowCImpl::do_exec(src, dst, exp_f, exp_i);
        return;
    }

    switch (src.layout.dtype.enumv()) {
#define cb(dt)                  \
    case DTypeTrait<dt>::enumv: \
        return do_exec_ct<DTypeTrait<dt>::ctype>(src, dst, exp_f, exp_i);
        cb(dtype::Float32);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
        case DTypeTrait<dtype::Float16>::enumv:
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
            return DNN_INC_FLOAT16(do_exec_ct<__fp16>(src, dst, exp_f, exp_i));
#else
            return DNN_INC_FLOAT16(do_exec_ct<dt_float16>(src, dst, exp_f, exp_i));
#endif
#endif
        default:
            megdnn_throw("unsupported dtype for PowC");
    }
}

// vim: syntax=cpp.doxygen