diff --git a/lite/api/paddle_place.cc b/lite/api/paddle_place.cc
index 59603e25f3b7e4942a6be4d7af008c4a9dd6772b..3567f986d2ff753b558fec4d40c478972addd0e6 100644
--- a/lite/api/paddle_place.cc
+++ b/lite/api/paddle_place.cc
@@ -55,7 +55,8 @@ const std::string& ActivationTypeToStr(ActivationType act) {
                                        "Tanh",
                                        "Swish",
                                        "Exp",
-                                       "ThresholdedRelu"};
+                                       "ThresholdedRelu",
+                                       "Elu"};
   auto x = static_cast<int>(act);
   CHECK_LT(x, static_cast<int>(ActivationType::NUM));
   return act2string[x];
diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h
index a43e74cd3a13b2e4fecd95428b9fd3fe8579d4d3..1aa62683d3a252769ee76c8f287e37828d0d0595 100644
--- a/lite/api/paddle_place.h
+++ b/lite/api/paddle_place.h
@@ -107,7 +107,8 @@ enum class ActivationType : int {
   kHardSwish = 10,
   kReciprocal = 11,
   kThresholdedRelu = 12,
-  NUM = 13,
+  kElu = 13,
+  NUM = 14,
 };
 
 static size_t PrecisionTypeLength(PrecisionType type) {
diff --git a/lite/backends/arm/math/activation.cc b/lite/backends/arm/math/activation.cc
index 01f25cbd36d327f7a3c252fdc675262d39748318..805b87da09a3f07edbb0591ac9d1f9eb4488ab89 100644
--- a/lite/backends/arm/math/activation.cc
+++ b/lite/backends/arm/math/activation.cc
@@ -763,6 +763,86 @@ void act_thresholded_relu(
   }
 }
 
+// elu: out = max(0, x) + min(0, alpha * (exp(x) - 1))
+template <>
+void act_elu<float>(
+    const float* din, float* dout, int size, float alpha, int threads) {
+  int nums_per_thread = size / threads;
+  int thread_remain = size % threads;
+  int neon_loop_cnt_dim16 = nums_per_thread >> 4;
+  int neon_loop_remain_dim16 = nums_per_thread & 15;
+  float32x4_t valpha = vdupq_n_f32(alpha);
+  float32x4_t vzero = vdupq_n_f32(0.f);
+  float32x4_t vone = vdupq_n_f32(1.f);
+  int cnt = neon_loop_remain_dim16 >> 2;
+  int remain = neon_loop_remain_dim16 & 3;
+#pragma omp parallel for
+  for (int i = 0; i < threads; i++) {
+    const float* ptr_in_thread = din + i * nums_per_thread;
+    float* ptr_out_thread = dout + i * nums_per_thread;
+    for (int k = 0; k < neon_loop_cnt_dim16; ++k) {
+      float32x4_t va = vld1q_f32(ptr_in_thread);  // x
+      float32x4_t vb = vld1q_f32(ptr_in_thread + 4);
+      float32x4_t vc = vld1q_f32(ptr_in_thread + 8);
+      float32x4_t vd = vld1q_f32(ptr_in_thread + 12);
+      float32x4_t va_exp = exp_ps(va);
+      float32x4_t va_max = vmaxq_f32(va, vzero);
+      float32x4_t vb_exp = exp_ps(vb);
+      float32x4_t vb_max = vmaxq_f32(vb, vzero);
+      float32x4_t vc_exp = exp_ps(vc);
+      float32x4_t vc_max = vmaxq_f32(vc, vzero);
+      float32x4_t vd_exp = exp_ps(vd);
+      float32x4_t vd_max = vmaxq_f32(vd, vzero);
+      float32x4_t va_sub = vmulq_f32(valpha, vsubq_f32(va_exp, vone));
+      float32x4_t vb_sub = vmulq_f32(valpha, vsubq_f32(vb_exp, vone));
+      float32x4_t vc_sub = vmulq_f32(valpha, vsubq_f32(vc_exp, vone));
+      float32x4_t vd_sub = vmulq_f32(valpha, vsubq_f32(vd_exp, vone));
+      float32x4_t va_min = vminq_f32(va_sub, vzero);
+      float32x4_t vb_min = vminq_f32(vb_sub, vzero);
+      float32x4_t vc_min = vminq_f32(vc_sub, vzero);
+      float32x4_t vd_min = vminq_f32(vd_sub, vzero);
+      float32x4_t va_rst = vaddq_f32(va_max, va_min);
+      float32x4_t vb_rst = vaddq_f32(vb_max, vb_min);
+      float32x4_t vc_rst = vaddq_f32(vc_max, vc_min);
+      float32x4_t vd_rst = vaddq_f32(vd_max, vd_min);
+      vst1q_f32(ptr_out_thread, va_rst);
+      vst1q_f32(ptr_out_thread + 4, vb_rst);
+      vst1q_f32(ptr_out_thread + 8, vc_rst);
+      vst1q_f32(ptr_out_thread + 12, vd_rst);
+      ptr_out_thread += 16;
+      ptr_in_thread += 16;
+    }
+    for (int j = 0; j < cnt; j++) {
+      float32x4_t va = vld1q_f32(ptr_in_thread);  // x
+      float32x4_t va_exp = exp_ps(va);
+      float32x4_t va_max = vmaxq_f32(va, vzero);
+      float32x4_t va_sub = vmulq_f32(valpha, vsubq_f32(va_exp, vone));
+      float32x4_t va_min = vminq_f32(va_sub, vzero);
+      float32x4_t va_rst = vaddq_f32(va_max, va_min);
+      vst1q_f32(ptr_out_thread, va_rst);
+      ptr_out_thread += 4;
+      ptr_in_thread += 4;
+    }
+    for (int j = 0; j < remain; j++) {
+      float beta = alpha * (expf(ptr_in_thread[0]) - 1);
+      float max = ptr_in_thread[0] >= 0.f ? ptr_in_thread[0] : 0.f;
+      float min = beta <= 0.f ? beta : 0.f;
+      ptr_out_thread[0] = min + max;
+      ptr_in_thread++;
+      ptr_out_thread++;
+    }
+  }
+  float* ptr_out = dout + threads * nums_per_thread;
+  const float* ptr_in = din + threads * nums_per_thread;
+  for (int j = 0; j < thread_remain; j++) {
+    float beta = alpha * (expf(ptr_in[0]) - 1);
+    float max = ptr_in[0] >= 0.f ? ptr_in[0] : 0.f;
+    float min = beta <= 0.f ? beta : 0.f;
+    ptr_out[0] = max + min;
+    ptr_in++;
+    ptr_out++;
+  }
+}
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/backends/arm/math/activation.h b/lite/backends/arm/math/activation.h
index b0147040cd11a888ec045948f0914a13aa932a2f..c4a002bbe7d8c8ebafc4d0415fbe84509bf14f93 100644
--- a/lite/backends/arm/math/activation.h
+++ b/lite/backends/arm/math/activation.h
@@ -90,6 +90,10 @@ template <typename T>
 void act_thresholded_relu(
     const T* din, T* dout, int size, float threshold, int threads);
 
+template <typename T>
+void act_elu(const T* din, T* dout, int size, float alpha, int threads);
+
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/kernels/arm/activation_compute.cc b/lite/kernels/arm/activation_compute.cc
index 5f3174edbbb53381db29bfa6b99f62a9e7094a4d..a5da9dfec8ba9d93db2c493a81ae08da536840c7 100644
--- a/lite/kernels/arm/activation_compute.cc
+++ b/lite/kernels/arm/activation_compute.cc
@@ -228,6 +228,17 @@ void ThresholdedReluCompute::Run() {
       x_data, output_data, x_dims.production(), threshold, ctx.threads());
 }
 
+void EluCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  auto x_dims = param.X->dims();
+  auto x_data = param.X->data<float>();
+  auto output_data = param.Out->mutable_data<float>();
+  float alpha = param.Elu_alpha;
+  lite::arm::math::act_elu<float>(
+      x_data, output_data, x_dims.production(), alpha, ctx.threads());
+}
+
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
@@ -356,3 +367,12 @@ REGISTER_LITE_KERNEL(thresholded_relu,
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
+REGISTER_LITE_KERNEL(elu,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::EluCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/activation_compute.h b/lite/kernels/arm/activation_compute.h
index a915937590ee8748ac419c5b33f82c81d8480852..fb0753768b0a4bde703f0016e75c24b35455168e 100644
--- a/lite/kernels/arm/activation_compute.h
+++ b/lite/kernels/arm/activation_compute.h
@@ -185,6 +185,16 @@ class ThresholdedReluCompute
   virtual ~ThresholdedReluCompute() = default;
 };
 
+class EluCompute
+    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::ActivationParam;
+
+  void Run() override;
+
+  virtual ~EluCompute() = default;
+};
+
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc
index 01e4116c94c75df3bd5360494c57419fe57c18ef..19a2134440d4d47adc5324e56e08f08457d1acf8 100644
--- a/lite/operators/activation_ops.cc
+++ b/lite/operators/activation_ops.cc
@@ -85,6 +85,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   } else if (opdesc.Type() == "thresholded_relu") {
     param_.active_type = lite_api::ActivationType::kThresholdedRelu;
     param_.relu_threshold = opdesc.GetAttr<float>("threshold");
+  } else if (opdesc.Type() == "elu") {
+    param_.active_type = lite_api::ActivationType::kElu;
+    param_.Elu_alpha = opdesc.GetAttr<float>("alpha");
   }
 
   VLOG(4) << "opdesc.Type():" << opdesc.Type();
@@ -105,3 +108,4 @@ REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp);
 REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
 REGISTER_LITE_OP(prelu, paddle::lite::operators::ActivationOp);
 REGISTER_LITE_OP(thresholded_relu, paddle::lite::operators::ActivationOp);
+REGISTER_LITE_OP(elu, paddle::lite::operators::ActivationOp);
diff --git a/lite/operators/activation_ops.h b/lite/operators/activation_ops.h
index 250a88de42b4004932f78b0490a844d4a8dbc6fe..aadfe8ba09514b9a0a37f18603eb47ec96427206 100644
--- a/lite/operators/activation_ops.h
+++ b/lite/operators/activation_ops.h
@@ -83,6 +83,9 @@ class ActivationOp : public OpLite {
       case lite_api::ActivationType::kThresholdedRelu:
         ch->macs = param_.X->numel();
         break;
+      case lite_api::ActivationType::kElu:
+        ch->macs = param_.X->numel();
+        break;
       default:
         LOG(FATAL) << "This Type of Activation:"
                    << static_cast<int>(param_.active_type)
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index f351e8e5344424d80fa79f8d7c83be3bf367441f..24f081f3751003d8e04ca4e6a0e57de542d9488d 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -359,6 +359,8 @@ struct ActivationParam : ParamBase {
   float hard_swish_offset{3.0};
   // thresholded_relu
   float relu_threshold{1.0f};
+  // elu
+  float Elu_alpha{1.0f};
 };
 
 struct ActivationGradParam : ParamBase {