Unverified commit e8200397, authored by xiebaiyuan, committed by GitHub

Add leaky_relu ARM CPU op, close #1636 (#1638)

Parent bd961137
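
For context, leaky_relu computes y = max(x, alpha * x) element-wise: positive inputs pass through unchanged and negative inputs are scaled by alpha rather than zeroed. A minimal scalar sketch of the operation this commit implements (the function name is illustrative; it mirrors the Active<LEAKY_RELU> specialization added below):

#include <algorithm>

// Reference leaky_relu for a single element: max(x, alpha * x).
inline float leaky_relu(float x, float alpha) { return std::max(x, alpha * x); }
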
......@@ -43,6 +43,7 @@ const char *G_OP_TYPE_POOL2D = "pool2d";
const char *G_OP_TYPE_PRIOR_BOX = "prior_box";
const char *G_OP_TYPE_RELU = "relu";
const char *G_OP_TYPE_RELU6 = "relu6";
const char *G_OP_TYPE_LEAKY_RELU = "leaky_relu";
const char *G_OP_TYPE_RESHAPE = "reshape";
const char *G_OP_TYPE_RESHAPE2 = "reshape2";
const char *G_OP_TYPE_SIGMOID = "sigmoid";
......@@ -126,6 +127,7 @@ std::unordered_map<
{G_OP_TYPE_FUSION_CONV_ADD, {{"Input"}, {"Out"}}},
{G_OP_TYPE_RELU, {{"X"}, {"Out"}}},
{G_OP_TYPE_RELU6, {{"X"}, {"Out"}}},
{G_OP_TYPE_LEAKY_RELU, {{"X"}, {"Out"}}},
{G_OP_TYPE_SOFTMAX, {{"X"}, {"Out"}}},
{G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}},
{G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
......
......@@ -142,6 +142,7 @@ extern const char *G_OP_TYPE_POOL2D;
extern const char *G_OP_TYPE_PRIOR_BOX;
extern const char *G_OP_TYPE_RELU;
extern const char *G_OP_TYPE_RELU6;
extern const char *G_OP_TYPE_LEAKY_RELU;
extern const char *G_OP_TYPE_RESHAPE;
extern const char *G_OP_TYPE_SIGMOID;
extern const char *G_OP_TYPE_SOFTMAX;
......
......@@ -45,6 +45,10 @@ DEFINE_ACTIVATION_INFERSHAPE(Tanh);
DEFINE_ACTIVATION_INFERSHAPE(Log);
#endif // LOG_OP
#ifdef LEAKY_RELU_OP
DEFINE_ACTIVATION_INFERSHAPE(LeakyRelu);
#endif // LEAKY_RELU_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -83,3 +87,7 @@ REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#ifdef LOG_OP
REGISTER_OPERATOR_CPU(log, ops::LogOp);
#endif // LOG_OP
#ifdef LEAKY_RELU_OP
REGISTER_OPERATOR_CPU(leaky_relu, ops::LeakyReluOp);
#endif // LEAKY_RELU_OP
......@@ -39,5 +39,9 @@ DECLARE_OPERATOR(Tanh, TanhParam, TanhKernel);
DECLARE_OPERATOR(Log, ReluParam, LogKernel);
#endif
#ifdef LEAKY_RELU_OP
DECLARE_OPERATOR(LeakyRelu, LeakyReluParam, LeakyReluKernel);
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -37,5 +37,8 @@ DECLARE_KERNEL(Tanh, TanhParam);
DECLARE_KERNEL(Log, ReluParam);
#endif
#ifdef LEAKY_RELU_OP
DECLARE_KERNEL(LeakyRelu, LeakyReluParam);
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -25,6 +25,7 @@ namespace operators {
template <typename Dtype, ActivationType Act>
struct ActivationCompute {
void operator()(const Tensor *input, Tensor *output) {}
void operator()(const Tensor *input, Tensor *output, float alpha) {}
};
template <ActivationType Act>
......@@ -61,6 +62,44 @@ struct ActivationCompute<float, Act> {
y[i] = math::Active<Act>(x[i]);
}
}
void operator()(const Tensor *input, Tensor *output, float falpha) {
const float *x = input->data<float>();
float *y = output->mutable_data<float>();
size_t remain = input->numel();
float alphas[4] = {falpha, falpha, falpha, falpha};
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
size_t loop = remain >> 4;
remain = remain & 0xF;
#pragma omp parallel for
for (size_t i = 0; i < loop; ++i) {
const float *local_x = x + (i << 4);
float *local_y = y + (i << 4);
float32x4_t r0 = vld1q_f32(local_x);
float32x4_t r1 = vld1q_f32(local_x + 4);
float32x4_t r2 = vld1q_f32(local_x + 8);
float32x4_t r3 = vld1q_f32(local_x + 12);
float32x4_t a_r0 = vld1q_f32(alphas);
float32x4_t a_r1 = vld1q_f32(alphas);
float32x4_t a_r2 = vld1q_f32(alphas);
float32x4_t a_r3 = vld1q_f32(alphas);
r0 = math::vActiveq_f32<Act>(r0, a_r0);
r1 = math::vActiveq_f32<Act>(r1, a_r1);
r2 = math::vActiveq_f32<Act>(r2, a_r2);
r3 = math::vActiveq_f32<Act>(r3, a_r3);
vst1q_f32(local_y, r0);
vst1q_f32(local_y + 4, r1);
vst1q_f32(local_y + 8, r2);
vst1q_f32(local_y + 12, r3);
}
x += (loop << 4);
y += (loop << 4);
#endif
for (size_t i = 0; i < remain; ++i) {
y[i] = math::Active<Act>(x[i], falpha);
}
}
};
#ifdef RELU_OP
......@@ -136,5 +175,20 @@ void LogKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
}
#endif
#ifdef LEAKY_RELU_OP
template <>
bool LeakyReluKernel<CPU, float>::Init(LeakyReluParam<CPU> *param) {
return true;
}
template <>
void LeakyReluKernel<CPU, float>::Compute(const LeakyReluParam<CPU> &param) {
const LoDTensor *input = param.InputX();
LoDTensor *output = param.Out();
ActivationCompute<float, LEAKY_RELU>()(input, output, param.Alpha());
output->set_lod(input->lod());
}
#endif
} // namespace operators
} // namespace paddle_mobile
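
The vectorized path above works in blocks of 16 floats per OpenMP iteration (four float32x4_t registers): loop = numel >> 4 counts the full blocks and remain = numel & 0xF is the scalar tail, so for numel = 37 the NEON loop covers 32 elements and 5 fall through to the scalar loop, which also serves as the complete fallback when NEON is unavailable. A hedged scalar sketch of that block/tail structure (names are illustrative, not from the patch):

#include <algorithm>
#include <cstddef>

void leaky_relu_blocked(const float *x, float *y, size_t n, float alpha) {
  size_t loop = n >> 4;            // full 16-element blocks (NEON path in the patch)
  size_t tail = n & 0xF;           // leftover elements, handled scalar
  size_t vectorized = loop << 4;   // elements covered by the blocked loop
  for (size_t i = 0; i < vectorized; ++i) y[i] = std::max(x[i], alpha * x[i]);
  x += vectorized;                 // the patch advances its pointers the same way
  y += vectorized;
  for (size_t i = 0; i < tail; ++i) y[i] = std::max(x[i], alpha * x[i]);
}
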
......@@ -104,6 +104,18 @@ template <>
inline float32x4_t vActiveq_f32<LOG>(const float32x4_t &x) {
return log_ps(x);
}
template <ActivationType Act = IDENTITY>
inline float32x4_t vActiveq_f32(const float32x4_t &x,
const float32x4_t &alpha) {
return x;
}
template <>
inline float32x4_t vActiveq_f32<LEAKY_RELU>(const float32x4_t &x,
const float32x4_t &alpha) {
return vmaxq_f32(x, vmulq_f32(x, alpha));
}
#endif
template <ActivationType Act = IDENTITY>
......@@ -142,6 +154,16 @@ inline float Active<LOG>(const float &x) {
return log(x);
}
template <ActivationType Act = IDENTITY>
inline float Active(const float &x, const float &alpha) {
return x;
}
template <>
inline float Active<LEAKY_RELU>(const float &x, const float &alpha) {
return std::max(x, alpha * x);
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
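
The NEON specialization vActiveq_f32<LEAKY_RELU> above computes max(x, alpha * x) per lane with vmulq_f32 and vmaxq_f32. A self-contained sketch of the same idea on one vector of four floats (assumption: this helper is illustrative only and broadcasts alpha with vdupq_n_f32, whereas the kernel loads it from a four-element array):

#include <arm_neon.h>

// Apply leaky_relu to four packed floats.
static inline void leaky_relu_f32x4(const float *x, float *y, float alpha) {
  float32x4_t v = vld1q_f32(x);                    // load 4 inputs
  float32x4_t a = vdupq_n_f32(alpha);              // broadcast alpha into all lanes
  float32x4_t r = vmaxq_f32(v, vmulq_f32(v, a));   // per-lane max(x, alpha * x)
  vst1q_f32(y, r);                                 // store 4 results
}
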
......@@ -1752,6 +1752,31 @@ class PReluParam : public OpParam {
};
#endif
#ifdef LEAKY_RELU_OP
template <typename Dtype>
class LeakyReluParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
LeakyReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, Scope *scope)
: OpParam(inputs, outputs, attrs, scope) {
input_x_ = InputXFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
alpha_ = GetAttr<float>("alpha", attrs);
}
const GType *InputX() const { return input_x_; }
const float Alpha() const { return alpha_; }
GType *Out() const { return out_; }
private:
GType *input_x_;
GType *out_;
float alpha_;
};
#endif
template <typename Dtype>
class FusionFcParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
......
......@@ -183,6 +183,9 @@ if (CON GREATER -1)
ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h)
target_link_libraries(test-sigmoid paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-leakyrelu operators/test_leaky_relu_op.cpp)
target_link_libraries(test-leakyrelu paddle-mobile)
set(FOUND_MATCH ON)
endif ()
......@@ -384,6 +387,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-sigmoid-op operators/test_sigmoid_op.cpp test_include.h)
target_link_libraries(test-sigmoid-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-leakyrelu operators/test_leaky_relu_op.cpp)
target_link_libraries(test-leakyrelu paddle-mobile)
# gen test
ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-depthwise-conv-op paddle-mobile)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/activation_op.h"
namespace paddle_mobile {
void LeakyRelu(const framework::Tensor *X, framework::Tensor *Y, float alpha) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = std::max(x[i], x[i] * alpha);
}
}
int TestLeakyReluOp(const std::vector<int> input_shape, float alpha) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
attrs["alpha"].Set<float>(alpha);
auto *op = new operators::LeakyReluOp<CPU, float>(
"leaky_relu", inputs, outputs, attrs, scope.get());
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
LeakyRelu(input, &output_cmp, alpha);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (gap > 1e-5 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestLeakyReluOp({1, 1, 2, 3}, 0.2f);
paddle_mobile::TestLeakyReluOp({1, 3, 11, 22}, 0.3f);
paddle_mobile::TestLeakyReluOp({1, 32, 112, 112}, 0.4f);
std::cout << "test leaky_relu op pass." << std::endl;
return 0;
}
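
As a usage note: this test builds as the test-leakyrelu target added in the test CMakeLists changes above, and it checks the op output against the scalar LeakyRelu reference for three input shapes and alpha values; the exact invocation path depends on the build setup and is not part of this patch.
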
......@@ -273,6 +273,7 @@ list(FIND NET "op" CON)
if (CON GREATER -1)
message("op enabled")
set(SIGMOID_OP ON)
set(LEAKY_RELU_OP ON)
set(FOUND_MATCH ON)
endif()
......@@ -362,6 +363,7 @@ if(NOT FOUND_MATCH)
set(PAD2D_OP ON)
set(ONE_HOT_OP ON)
set(ASSIGN_VALUE_OP ON)
set(LEAKY_RELU_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -691,3 +693,6 @@ endif()
if (ASSIGN_VALUE_OP)
add_definitions(-DASSIGN_VALUE_OP)
endif()
if (LEAKY_RELU_OP)
add_definitions(-DLEAKY_RELU_OP)
endif()
\ No newline at end of file