From 7a395f69f3076dc063d92d8e124d022f7886cc62 Mon Sep 17 00:00:00 2001 From: zp7 <9678873+ForceDaryl@users.noreply.github.com> Date: Wed, 3 Jul 2019 12:43:20 +0800 Subject: [PATCH] add relu6 threshold param (#1724) --- src/operators/activation_op.h | 2 +- src/operators/kernel/activation_kernel.h | 2 +- src/operators/kernel/arm/activation_kernel.cpp | 7 ++++--- src/operators/kernel/cl/cl_kernel/relu6.cl | 5 +++-- src/operators/kernel/cl/relu6_kernel.cpp | 6 ++++-- src/operators/math/activation.h | 13 +++++++++++++ src/operators/op_param.h | 14 ++++++++++++++ test/operators/test_relu6_op.cpp | 1 + 8 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/operators/activation_op.h b/src/operators/activation_op.h index d248da51fc..cd250080e5 100644 --- a/src/operators/activation_op.h +++ b/src/operators/activation_op.h @@ -24,7 +24,7 @@ namespace operators { #ifdef RELU_OP DECLARE_OPERATOR(Relu, ReluParam, ReluKernel); -DECLARE_OPERATOR(Relu6, ReluParam, Relu6Kernel); +DECLARE_OPERATOR(Relu6, Relu6Param, Relu6Kernel); #endif #ifdef SIGMOID_OP diff --git a/src/operators/kernel/activation_kernel.h b/src/operators/kernel/activation_kernel.h index 34be4b3d16..b27691d521 100644 --- a/src/operators/kernel/activation_kernel.h +++ b/src/operators/kernel/activation_kernel.h @@ -22,7 +22,7 @@ namespace operators { #ifdef RELU_OP DECLARE_KERNEL(Relu, ReluParam); -DECLARE_KERNEL(Relu6, ReluParam); +DECLARE_KERNEL(Relu6, Relu6Param); #endif #ifdef SIGMOID_OP diff --git a/src/operators/kernel/arm/activation_kernel.cpp b/src/operators/kernel/arm/activation_kernel.cpp index d5343e5a04..be8ebc532f 100644 --- a/src/operators/kernel/arm/activation_kernel.cpp +++ b/src/operators/kernel/arm/activation_kernel.cpp @@ -38,15 +38,16 @@ void ReluKernel::Compute(const ReluParam ¶m) { } template <> -bool Relu6Kernel::Init(ReluParam *param) { +bool Relu6Kernel::Init(Relu6Param *param) { return true; } template <> -void Relu6Kernel::Compute(const ReluParam ¶m) { +void 
Relu6Kernel::Compute(const Relu6Param ¶m) { const LoDTensor *input = param.InputX(); LoDTensor *output = param.Out(); - ActivationCompute()(input, output); + float threshold = param.getThreshold(); + ActivationCompute()(input, output, threshold); output->set_lod(input->lod()); } #endif diff --git a/src/operators/kernel/cl/cl_kernel/relu6.cl b/src/operators/kernel/cl/cl_kernel/relu6.cl index 5040dd6bba..7a2f0e022f 100644 --- a/src/operators/kernel/cl/cl_kernel/relu6.cl +++ b/src/operators/kernel/cl/cl_kernel/relu6.cl @@ -15,7 +15,8 @@ limitations under the License. */ #pragma OPENCL EXTENSION cl_khr_fp16 : enable __kernel void relu6(__read_only image2d_t input, - __write_only image2d_t output){ + __write_only image2d_t output, + __private const float threshold){ const int x = get_global_id(0); const int y = get_global_id(1); @@ -26,6 +27,6 @@ __kernel void relu6(__read_only image2d_t input, half4 in = read_imageh(input, sampler, (int2)(x, y)); in = max((half4)(0.0f, 0.0f, 0.0f, 0.0f), in); - in = min((half4)(6.0f, 6.0f, 6.0f, 6.0f), in); + in = min((half4)(threshold, threshold, threshold, threshold), in); write_imageh(output, (int2)(x, y), in); } diff --git a/src/operators/kernel/cl/relu6_kernel.cpp b/src/operators/kernel/cl/relu6_kernel.cpp index f918f8ee04..06167e8075 100644 --- a/src/operators/kernel/cl/relu6_kernel.cpp +++ b/src/operators/kernel/cl/relu6_kernel.cpp @@ -19,21 +19,23 @@ namespace paddle_mobile { namespace operators { template <> -bool Relu6Kernel::Init(ReluParam* param) { +bool Relu6Kernel::Init(Relu6Param* param) { this->cl_helper_.AddKernel("relu6", "relu6.cl"); return true; } template <> -void Relu6Kernel::Compute(const ReluParam& param) { +void Relu6Kernel::Compute(const Relu6Param& param) { auto kernel = this->cl_helper_.KernelAt(0); const auto* input = param.InputX(); auto* output = param.Out(); + float threshold = param.getThreshold(); auto default_work_size = this->cl_helper_.DefaultWorkSize(*output); auto inputImage = input->GetCLImage(); 
auto outputImage = output->GetCLImage(); clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); + clSetKernelArg(kernel, 2, sizeof(float), &threshold); const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()}; clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL, diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h index 5210a9f650..d5a4cba4ce 100644 --- a/src/operators/math/activation.h +++ b/src/operators/math/activation.h @@ -116,6 +116,14 @@ inline float32x4_t vActiveq_f32(const float32x4_t &x, const float32x4_t &alpha) { return vmaxq_f32(x, vmulq_f32(x, alpha)); } + +template <> +inline float32x4_t vActiveq_f32(const float32x4_t &x, + const float32x4_t &alpha) { + float32x4_t __zero = vdupq_n_f32(0.f); + float32x4_t __threshold = vdupq_n_f32(vgetq_lane_f32(alpha, 0)); + return vminq_f32(vmaxq_f32(x, __zero), __threshold); +} #endif template @@ -164,6 +172,11 @@ inline float Active(const float &x, const float &alpha) { return std::max(x, alpha * x); } +template <> +inline float Active(const float &x, const float &alpha) { + return std::min(std::max(x, 0.f), alpha); +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 751c7b87ff..d7de42cb08 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1675,6 +1675,20 @@ class ReluParam : public ReluParamBase { using ReluParamBase::ReluParamBase; }; +template +class Relu6Param : public ReluParamBase { + public: + Relu6Param(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, Scope *scope) + : ReluParamBase(inputs, outputs, attrs, scope) { + threshold = OpParam::GetAttr("threshold", attrs); + } + float getThreshold() const { return threshold; } + + private: + float threshold; +}; + #ifdef PADDLE_MOBILE_CL template <> class ReluParam : public 
ReluParamBase { diff --git a/test/operators/test_relu6_op.cpp b/test/operators/test_relu6_op.cpp index fcbaa0ba89..8681c4155d 100644 --- a/test/operators/test_relu6_op.cpp +++ b/test/operators/test_relu6_op.cpp @@ -44,6 +44,7 @@ int TestRelu6Op(const std::vector input_shape) { auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; + attrs["threshold"].Set(6.f); auto *op = new operators::Relu6Op("relu6", inputs, outputs, attrs, scope.get()); op->InferShape(); -- GitLab