From 75efc59433f295c038d647248beae76d8baaa018 Mon Sep 17 00:00:00 2001
From: xiebaiyuan
Date: Tue, 21 May 2019 20:18:46 +0800
Subject: [PATCH] Add leakyrelu arm cpu op close #1636 (#1638)

---
 src/common/types.cpp                      |  2 +
 src/common/types.h                        |  1 +
 src/operators/activation_op.cpp           |  8 ++
 src/operators/activation_op.h             |  4 +
 src/operators/kernel/activation_kernel.h  |  3 +
 .../kernel/arm/activation_kernel.cpp      | 54 +++++++++++++
 src/operators/math/activation.h           | 22 +++++
 src/operators/op_param.h                  | 25 ++++++
 test/CMakeLists.txt                       |  7 ++
 test/operators/test_leaky_relu_op.cpp     | 80 +++++++++++++++++++
 tools/op.cmake                            |  5 ++
 11 files changed, 211 insertions(+)
 create mode 100644 test/operators/test_leaky_relu_op.cpp

diff --git a/src/common/types.cpp b/src/common/types.cpp
index 20656acb20..8bec90d547 100644
--- a/src/common/types.cpp
+++ b/src/common/types.cpp
@@ -43,6 +43,7 @@ const char *G_OP_TYPE_POOL2D = "pool2d";
 const char *G_OP_TYPE_PRIOR_BOX = "prior_box";
 const char *G_OP_TYPE_RELU = "relu";
 const char *G_OP_TYPE_RELU6 = "relu6";
+const char *G_OP_TYPE_LEAKY_RELU = "leaky_relu";
 const char *G_OP_TYPE_RESHAPE = "reshape";
 const char *G_OP_TYPE_RESHAPE2 = "reshape2";
 const char *G_OP_TYPE_SIGMOID = "sigmoid";
@@ -126,6 +127,7 @@ std::unordered_map<
         {G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}},
         {G_OP_TYPE_RELU, {{"X"}, {"Out"}}},
         {G_OP_TYPE_RELU6, {{"X"}, {"Out"}}},
+        {G_OP_TYPE_LEAKY_RELU, {{"X"}, {"Out"}}},
         {G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}},
         {G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}},
         {G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
diff --git a/src/common/types.h b/src/common/types.h
index f8e1fd26ea..e7b3e3b9a9 100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -142,6 +142,7 @@ extern const char *G_OP_TYPE_POOL2D;
 extern const char *G_OP_TYPE_PRIOR_BOX;
 extern const char *G_OP_TYPE_RELU;
 extern const char *G_OP_TYPE_RELU6;
+extern const char *G_OP_TYPE_LEAKY_RELU;
 extern const char *G_OP_TYPE_RESHAPE;
 extern const char *G_OP_TYPE_SIGMOID;
 extern const char *G_OP_TYPE_SOFTMAX;
diff --git a/src/operators/activation_op.cpp b/src/operators/activation_op.cpp
index 7bc78ef77f..952e317261 100644
--- a/src/operators/activation_op.cpp
+++ b/src/operators/activation_op.cpp
@@ -45,6 +45,10 @@ DEFINE_ACTIVATION_INFERSHAPE(Tanh);
 DEFINE_ACTIVATION_INFERSHAPE(Log);
 #endif  // LOG_OP
 
+#ifdef LEAKY_RELU_OP
+DEFINE_ACTIVATION_INFERSHAPE(LeakyRelu);
+#endif  // LEAKY_RELU_OP
+
 }  // namespace operators
 }  // namespace paddle_mobile
 
@@ -83,3 +87,7 @@ REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
 #ifdef LOG_OP
 REGISTER_OPERATOR_CPU(log, ops::LogOp);
 #endif  // LOG_OP
+
+#ifdef LEAKY_RELU_OP
+REGISTER_OPERATOR_CPU(leaky_relu, ops::LeakyReluOp);
+#endif  // LEAKY_RELU_OP
diff --git a/src/operators/activation_op.h b/src/operators/activation_op.h
index cecf22c225..d248da51fc 100644
--- a/src/operators/activation_op.h
+++ b/src/operators/activation_op.h
@@ -39,5 +39,9 @@ DECLARE_OPERATOR(Tanh, TanhParam, TanhKernel);
 DECLARE_OPERATOR(Log, ReluParam, LogKernel);
 #endif
 
+#ifdef LEAKY_RELU_OP
+DECLARE_OPERATOR(LeakyRelu, LeakyReluParam, LeakyReluKernel);
+#endif
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/kernel/activation_kernel.h b/src/operators/kernel/activation_kernel.h
index 9eaf3fd967..34be4b3d16 100644
--- a/src/operators/kernel/activation_kernel.h
+++ b/src/operators/kernel/activation_kernel.h
@@ -37,5 +37,8 @@ DECLARE_KERNEL(Tanh, TanhParam);
 DECLARE_KERNEL(Log, ReluParam);
 #endif
 
+#ifdef LEAKY_RELU_OP
+DECLARE_KERNEL(LeakyRelu, LeakyReluParam);
+#endif
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/kernel/arm/activation_kernel.cpp b/src/operators/kernel/arm/activation_kernel.cpp
index 73018c8868..37c31f6ac0 100644
--- a/src/operators/kernel/arm/activation_kernel.cpp
+++ b/src/operators/kernel/arm/activation_kernel.cpp
@@ -25,6 +25,7 @@ namespace operators {
 template <typename Dtype, ActivationType Act>
 struct ActivationCompute {
   void operator()(const Tensor *input, Tensor *output) {}
+  void operator()(const Tensor *input, Tensor *output, float alpha) {}
 };
 
 template <ActivationType Act>
@@ -61,6 +62,44 @@ struct ActivationCompute<float, Act> {
       y[i] = math::Active<Act>(x[i]);
     }
   }
+
+  void operator()(const Tensor *input, Tensor *output, float falpha) {
+    const float *x = input->data<float>();
+    float *y = output->mutable_data<float>();
+    size_t remain = input->numel();
+    float alphas[4] = {falpha, falpha, falpha, falpha};
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+    size_t loop = remain >> 4;
+    remain = remain & 0xF;
+
+#pragma omp parallel for
+    for (size_t i = 0; i < loop; ++i) {
+      const float *local_x = x + (i << 4);
+      float *local_y = y + (i << 4);
+      float32x4_t r0 = vld1q_f32(local_x);
+      float32x4_t r1 = vld1q_f32(local_x + 4);
+      float32x4_t r2 = vld1q_f32(local_x + 8);
+      float32x4_t r3 = vld1q_f32(local_x + 12);
+      float32x4_t a_r0 = vld1q_f32(alphas);
+      float32x4_t a_r1 = vld1q_f32(alphas);
+      float32x4_t a_r2 = vld1q_f32(alphas);
+      float32x4_t a_r3 = vld1q_f32(alphas);
+      r0 = math::vActiveq_f32<Act>(r0, a_r0);
+      r1 = math::vActiveq_f32<Act>(r1, a_r1);
+      r2 = math::vActiveq_f32<Act>(r2, a_r2);
+      r3 = math::vActiveq_f32<Act>(r3, a_r3);
+      vst1q_f32(local_y, r0);
+      vst1q_f32(local_y + 4, r1);
+      vst1q_f32(local_y + 8, r2);
+      vst1q_f32(local_y + 12, r3);
+    }
+    x += (loop << 4);
+    y += (loop << 4);
+#endif
+    for (size_t i = 0; i < remain; ++i) {
+      y[i] = math::Active<Act>(x[i], falpha);
+    }
+  }
 };
 
 #ifdef RELU_OP
@@ -136,5 +175,20 @@ void LogKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
 }
 #endif
 
+#ifdef LEAKY_RELU_OP
+template <>
+bool LeakyReluKernel<CPU, float>::Init(LeakyReluParam<CPU> *param) {
+  return true;
+}
+
+template <>
+void LeakyReluKernel<CPU, float>::Compute(const LeakyReluParam<CPU> &param) {
+  const LoDTensor *input = param.InputX();
+  LoDTensor *output = param.Out();
+  ActivationCompute<float, LEAKY_RELU>()(input, output, param.Alpha());
+  output->set_lod(input->lod());
+}
+#endif
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h
index fb90a35516..5210a9f650 100644
--- a/src/operators/math/activation.h
+++ b/src/operators/math/activation.h
@@ -104,6 +104,18 @@ template <>
 inline float32x4_t vActiveq_f32<LOG>(const float32x4_t &x) {
   return log_ps(x);
 }
+
+template <ActivationType Act>
+inline float32x4_t vActiveq_f32(const float32x4_t &x,
+                                const float32x4_t &alpha) {
+  return x;
+}
+
+template <>
+inline float32x4_t vActiveq_f32<LEAKY_RELU>(const float32x4_t &x,
+                                            const float32x4_t &alpha) {
+  return vmaxq_f32(x, vmulq_f32(x, alpha));
+}
 #endif
 
 template <ActivationType Act>
@@ -142,6 +154,16 @@ inline float Active<LOG>(const float &x) {
   return log(x);
 }
 
+template <ActivationType Act>
+inline float Active(const float &x, const float &alpha) {
+  return x;
+}
+
+template <>
+inline float Active<LEAKY_RELU>(const float &x, const float &alpha) {
+  return std::max(x, alpha * x);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 97dbd091a1..0beea15994 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -1752,6 +1752,31 @@ class PReluParam : public OpParam {
 };
 #endif
 
+#ifdef LEAKY_RELU_OP
+template <typename Dtype>
+class LeakyReluParam : public OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  LeakyReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+                 const AttributeMap &attrs, Scope *scope)
+      : OpParam(inputs, outputs, attrs, scope) {
+    input_x_ = InputXFrom<GType>(inputs, *scope);
+    out_ = OutFrom<GType>(outputs, *scope);
+    alpha_ = GetAttr<float>("alpha", attrs);
+  }
+  const GType *InputX() const { return input_x_; }
+  const float Alpha() const { return alpha_; }
+  GType *Out() const { return out_; }
+
+ private:
+  GType *input_x_;
+  GType *out_;
+  float alpha_;
+};
+#endif
+
 template <typename Dtype>
 class FusionFcParam : public OpParam {
   typedef typename DtypeTensorTrait<Dtype>::gtype GType;
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 024b221b85..2e576741ba 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -183,6 +183,9 @@ if (CON GREATER -1)
     ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h)
     target_link_libraries(test-sigmoid paddle-mobile)
 
+    # gen test leakyrelu
+    ADD_EXECUTABLE(test-leakyrelu operators/test_leaky_relu_op.cpp)
+    target_link_libraries(test-leakyrelu paddle-mobile)
     set(FOUND_MATCH ON)
 endif ()
 
@@ -384,6 +387,10 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-sigmoid-op operators/test_sigmoid_op.cpp test_include.h)
     target_link_libraries(test-sigmoid-op paddle-mobile)
 
+    # gen test leakyrelu
+    ADD_EXECUTABLE(test-leakyrelu operators/test_leaky_relu_op.cpp)
+    target_link_libraries(test-leakyrelu paddle-mobile)
+
    # gen test
    ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
    target_link_libraries(test-depthwise-conv-op paddle-mobile)
diff --git a/test/operators/test_leaky_relu_op.cpp b/test/operators/test_leaky_relu_op.cpp
new file mode 100644
index 0000000000..3349fbd92c
--- /dev/null
+++ b/test/operators/test_leaky_relu_op.cpp
@@ -0,0 +1,80 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cmath>
+#include <iostream>
+#include "../test_include.h"
+#include "operators/activation_op.h"
+
+namespace paddle_mobile {
+
+void LeakyRelu(const framework::Tensor *X, framework::Tensor *Y, float alpha) {
+  const float *x = X->data<float>();
+  float *y = Y->mutable_data<float>();
+
+  for (int i = 0; i < X->numel(); ++i) {
+    y[i] = std::max(x[i], x[i] * alpha);
+  }
+}
+
+int TestLeakyReluOp(const std::vector<int> input_shape, float alpha) {
+  framework::DDim dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"input"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+  auto input_var = scope.get()->Var("input");
+  auto input = input_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(input, dims, -100.0, 100.0);
+  auto output_var = scope.get()->Var("output");
+  framework::AttributeMap attrs;
+  attrs["alpha"].Set<float>(alpha);
+
+  auto *op = new operators::LeakyReluOp<CPU, float>(
+      "leaky_relu", inputs, outputs, attrs, scope.get());
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+
+  framework::Tensor output_cmp;
+  float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
+  LeakyRelu(input, &output_cmp, alpha);
+
+  const float *output_data = output->data<float>();
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (gap > 1e-5 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+  delete op;
+  return 0;
+}
+
+}  // namespace paddle_mobile
+
+int main() {
+  paddle_mobile::TestLeakyReluOp({1, 1, 2, 3}, 0.2f);
+  paddle_mobile::TestLeakyReluOp({1, 3, 11, 22}, 0.3f);
+  paddle_mobile::TestLeakyReluOp({1, 32, 112, 112}, 0.4f);
+  std::cout << "test leaky_relu op pass." << std::endl;
+  return 0;
+}
diff --git a/tools/op.cmake b/tools/op.cmake
index eb6501de22..a24c3f3597 100755
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -273,6 +273,7 @@ list(FIND NET "op" CON)
 if (CON GREATER -1)
     message("op enabled")
     set(SIGMOID_OP ON)
+    set(LEAKY_RELU_OP ON)
     set(FOUND_MATCH ON)
 endif()
 
@@ -362,6 +363,7 @@ if(NOT FOUND_MATCH)
     set(PAD2D_OP ON)
     set(ONE_HOT_OP ON)
     set(ASSIGN_VALUE_OP ON)
+    set(LEAKY_RELU_OP ON)
 endif()
 
 # option(BATCHNORM_OP "" ON)
@@ -691,3 +693,6 @@ endif()
 if (ASSIGN_VALUE_OP)
     add_definitions(-DASSIGN_VALUE_OP)
 endif()
+if (LEAKY_RELU_OP)
+    add_definitions(-DLEAKY_RELU_OP)
+endif()
\ No newline at end of file
--
GitLab