Unverified commit bfaa2c93, authored by HappyAngel, committed by GitHub

[arm] add elu activation op (#3978)

* add elu act

* fix elu act not find error, test=develop

* fix format.test=develop

* fix int8 model opt error in conv+conv fusion, test=develop

* fix format. test=develop

* test=develop
Parent 4ddd9a98
......@@ -55,7 +55,8 @@ const std::string& ActivationTypeToStr(ActivationType act) {
"Tanh",
"Swish",
"Exp",
"ThresholdedRelu"};
"ThresholdedRelu",
"Elu"};
auto x = static_cast<int>(act);
CHECK_LT(x, static_cast<int>(ActivationType::NUM));
return act2string[x];
......
......@@ -108,7 +108,8 @@ enum class ActivationType : int {
kHardSwish = 10,
kReciprocal = 11,
kThresholdedRelu = 12,
NUM = 13,
kElu = 13,
NUM = 14,
};
static size_t PrecisionTypeLength(PrecisionType type) {
......
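The two hunks above have to stay in sync: ActivationTypeToStr indexes its string table with the enum value, so the new "Elu" entry must sit at index 13 (kElu) and NUM must grow to 14. A minimal sanity-check sketch, not part of the patch; the header path and namespace are assumptions:

```cpp
#include <iostream>
// #include "lite/api/paddle_place.h"  // assumed header providing ActivationType / ActivationTypeToStr

int main() {
  using paddle::lite_api::ActivationType;
  // With kElu = 13 and NUM = 14, the last string-table entry is "Elu".
  std::cout << paddle::lite_api::ActivationTypeToStr(ActivationType::kElu) << std::endl;
  return 0;
}
```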
......@@ -763,6 +763,91 @@ void act_thresholded_relu<float>(
}
}
// elu: out = max(0, x) + min(0, alpha * (exp(x) - 1))
template <>
void act_elu<float>(
const float* din, float* dout, int size, float alpha, int threads) {
int nums_per_thread = size / threads;
int thread_remain = size % threads;
int neon_loop_cnt_dim16 = nums_per_thread >> 4;
int neon_loop_remain_dim16 = nums_per_thread & 15;
float32x4_t valpha = vdupq_n_f32(alpha);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vone = vdupq_n_f32(1.f);
int cnt = neon_loop_remain_dim16 >> 2;
int remain = neon_loop_remain_dim16 & 3;
#pragma omp parallel for
for (int i = 0; i < threads; i++) {
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim16; ++k) {
float32x4_t va = vld1q_f32(ptr_in_thread);
float32x4_t vb = vld1q_f32(ptr_in_thread + 4);
float32x4_t vc = vld1q_f32(ptr_in_thread + 8);
float32x4_t vd = vld1q_f32(ptr_in_thread + 12);
float32x4_t va_exp = exp_ps(va);
float32x4_t va_max = vmaxq_f32(va, vzero);
float32x4_t vb_exp = exp_ps(vb);
float32x4_t vb_max = vmaxq_f32(vb, vzero);
float32x4_t vc_exp = exp_ps(vc);
float32x4_t vc_max = vmaxq_f32(vc, vzero);
float32x4_t vd_exp = exp_ps(vd);
float32x4_t vd_max = vmaxq_f32(vd, vzero);
float32x4_t va_sub = vsubq_f32(va_exp, vone);
float32x4_t vb_sub = vsubq_f32(vb_exp, vone);
float32x4_t vc_sub = vsubq_f32(vc_exp, vone);
float32x4_t vd_sub = vsubq_f32(vd_exp, vone);
va_sub = vmulq_f32(va_sub, valpha);
vb_sub = vmulq_f32(vb_sub, valpha);
vc_sub = vmulq_f32(vc_sub, valpha);
vd_sub = vmulq_f32(vd_sub, valpha);
float32x4_t va_min = vminq_f32(va_sub, vzero);
float32x4_t vb_min = vminq_f32(vb_sub, vzero);
float32x4_t vc_min = vminq_f32(vc_sub, vzero);
float32x4_t vd_min = vminq_f32(vd_sub, vzero);
float32x4_t va_rst = vaddq_f32(va_max, va_min);
float32x4_t vb_rst = vaddq_f32(vb_max, vb_min);
float32x4_t vc_rst = vaddq_f32(vc_max, vc_min);
float32x4_t vd_rst = vaddq_f32(vd_max, vd_min);
vst1q_f32(ptr_out_thread, va_rst);
vst1q_f32(ptr_out_thread + 4, vb_rst);
vst1q_f32(ptr_out_thread + 8, vc_rst);
vst1q_f32(ptr_out_thread + 12, vd_rst);
ptr_out_thread += 16;
ptr_in_thread += 16;
}
for (int j = 0; j < cnt; j++) {
float32x4_t va = vld1q_f32(ptr_in_thread);
float32x4_t va_exp = exp_ps(va);
float32x4_t va_max = vmaxq_f32(va, vzero);
float32x4_t va_sub = vsubq_f32(va_exp, vone);
va_sub = vmulq_f32(va_sub, valpha);
float32x4_t va_min = vminq_f32(va_sub, vzero);
float32x4_t va_rst = vaddq_f32(va_max, va_min);
vst1q_f32(ptr_out_thread, va_rst);
ptr_out_thread += 4;
ptr_in_thread += 4;
}
for (int j = 0; j < remain; j++) {
float beta = alpha * (expf(ptr_in_thread[0]) - 1);
float max = ptr_in_thread[0] >= 0.f ? ptr_in_thread[0] : 0.f;
float min = beta <= 0.f ? beta : 0.f;
ptr_out_thread[0] = min + max;
ptr_in_thread++;
ptr_out_thread++;
}
}
float* ptr_out = dout + threads * nums_per_thread;
const float* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < thread_remain; j++) {
float beta = alpha * (expf(ptr_in[0]) - 1);
float max = ptr_in[0] >= 0.f ? ptr_in[0] : 0.f;
float min = beta <= 0.f ? beta : 0.f;
ptr_out[0] = max + min;
ptr_in++;
ptr_out++;
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
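For clarity, the vectorized kernel above computes, per element, out = max(0, x) + min(0, alpha * (exp(x) - 1)); the work is split across `threads` OpenMP threads, each handling `size / threads` elements in blocks of 16, then 4, then a scalar tail, with the final `size % threads` elements handled on the main thread. A minimal scalar reference sketch that can be used to cross-check the NEON path (the name `elu_ref` is illustrative, not part of the patch):

```cpp
#include <algorithm>
#include <cmath>

// Scalar reference for act_elu: same math as the NEON kernel, no threading.
static void elu_ref(const float* din, float* dout, int size, float alpha) {
  for (int i = 0; i < size; ++i) {
    float x = din[i];
    float pos = std::max(x, 0.f);
    float neg = std::min(alpha * (std::exp(x) - 1.f), 0.f);
    dout[i] = pos + neg;
  }
}
```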
......@@ -90,6 +90,9 @@ template <typename T>
void act_thresholded_relu(
const T* din, T* dout, int size, float threshold, int threads);
template <typename T>
void act_elu(const T* din, T* dout, int size, float alpha, int threads);
} // namespace math
} // namespace arm
} // namespace lite
......
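A minimal usage sketch of the routine declared above, assuming the backends header path shown in the comment; threads = 1 keeps the example simple, and the kernel handles the 4-wide and scalar tails internally:

```cpp
#include <vector>
// #include "lite/backends/arm/math/activation.h"  // assumed header path

void elu_example() {
  std::vector<float> in = {-2.f, -0.5f, 0.f, 0.5f, 2.f};
  std::vector<float> out(in.size());
  const float alpha = 1.0f;
  paddle::lite::arm::math::act_elu<float>(
      in.data(), out.data(), static_cast<int>(in.size()), alpha, /*threads=*/1);
}
```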
......@@ -27,21 +27,29 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialize fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
bool has_arm = false;
bool has_fp32 = false;
bool has_int8 = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
has_arm = true;
break;
if (place.target == TARGET(kARM)) {
if (place.precision == PRECISION(kFloat)) {
has_fp32 = true;
}
if (place.precision == PRECISION(kInt8)) {
has_int8 = true;
}
} else {
return;
}
}
if (!has_arm) {
// only support arm-fp32
if (has_int8 || (has_fp32 && has_int8)) {
return;
}
// only support fp32 fusion
for (auto conv_has_bias0 : conv_has_bias_cases) {
for (auto conv_has_bias1 : conv_has_bias_cases) {
for (auto conv_type0 : conv_type_cases) {
for (auto conv_type1 : conv_type_cases) {
for (auto conv_type1 : {"conv2d"}) {  // it must be a 1x1s1p0 conv
VLOG(4) << "conv_has_bias0:" << conv_has_bias0
<< " conv_type0:" << conv_type0;
VLOG(4) << "conv_has_bias1:" << conv_has_bias1
......
......@@ -106,7 +106,7 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
int kw = weight1_t->dims()[2];
int kh = weight1_t->dims()[3];
if (!(kw == 1 && kh == 1)) {
return;
LOG(FATAL) << "The kernel size of the second conv must be 1x1";
}
CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same";
CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1";
......@@ -117,11 +117,11 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
<< " must be 1";
}
for (int i = 0; i < paddings1.size(); i++) {
CHECK_EQ(paddings1[i], 0) << "paddings1[" << i << "]: " << paddings1[i]
<< " must be 0";
CHECK_EQ(paddings1[i], 1) << "paddings[" << i << "]: " << paddings1[i]
<< " must be 1";
}
for (int i = 0; i < dilations1.size(); i++) {
CHECK_EQ(dilations1[i], 1) << "dilations1[" << i << "]: " << dilations1[i]
CHECK_EQ(dilations1[i], 1) << "dilations[" << i << "]: " << dilations1[i]
<< " must be 1";
}
// compute new_weight and new_bias
......@@ -140,8 +140,7 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
// new_bias = az + b
///////////////////////////////////////////////////////////////////////////////
if (enable0_int8) {
LOG(FATAL) << "it doesn't support";
return;
LOG(FATAL) << "it doesn't support int8";
} else {
// compute new conv_weight
Tensor weight_tensor;
......
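The "compute new_weight and new_bias" step above relies on the fact that a 1x1, stride-1, group-1 convolution applied after another convolution is a linear map over the first conv's output channels, so the two can be collapsed into a single conv (matching the `new_bias = az + b` comment). A naive sketch of that arithmetic under those assumptions; the names and layouts are illustrative, not the fuser's actual code:

```cpp
// W0: conv0 weights flattened to [oc0, ic*kh*kw], b0: conv0 bias [oc0] or nullptr
// W1: 1x1 conv1 weights [oc1, oc0],               b1: conv1 bias [oc1] or nullptr
// Fused: W[o][j] = sum_c W1[o][c] * W0[c][j],     b[o] = sum_c W1[o][c] * b0[c] + b1[o]
void fuse_conv_conv_ref(const float* W0, const float* b0,
                        const float* W1, const float* b1,
                        float* W, float* b, int oc1, int oc0, int spatial) {
  for (int o = 0; o < oc1; ++o) {
    for (int j = 0; j < spatial; ++j) {
      float acc = 0.f;
      for (int c = 0; c < oc0; ++c) acc += W1[o * oc0 + c] * W0[c * spatial + j];
      W[o * spatial + j] = acc;
    }
    float bias = b1 ? b1[o] : 0.f;
    for (int c = 0; c < oc0; ++c) bias += W1[o * oc0 + c] * (b0 ? b0[c] : 0.f);
    b[o] = bias;
  }
}
```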
......@@ -228,6 +228,17 @@ void ThresholdedReluCompute::Run() {
x_data, output_data, x_dims.production(), threshold, ctx.threads());
}
void EluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
float alpha = param.Elu_alpha;
lite::arm::math::act_elu<float>(
x_data, output_data, x_dims.production(), alpha, ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -356,3 +367,8 @@ REGISTER_LITE_KERNEL(thresholded_relu,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
elu, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::EluCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -185,6 +185,15 @@ class ThresholdedReluCompute
virtual ~ThresholdedReluCompute() = default;
};
class EluCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~EluCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -85,6 +85,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
} else if (opdesc.Type() == "thresholded_relu") {
param_.active_type = lite_api::ActivationType::kThresholdedRelu;
param_.relu_threshold = opdesc.GetAttr<float>("threshold");
} else if (opdesc.Type() == "elu") {
param_.active_type = lite_api::ActivationType::kElu;
param_.Elu_alpha = opdesc.GetAttr<float>("alpha");
}
VLOG(4) << "opdesc.Type():" << opdesc.Type();
......@@ -105,3 +108,4 @@ REGISTER_LITE_OP(leaky_relu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(prelu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(thresholded_relu, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(elu, paddle::lite::operators::ActivationOp);
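With the op registered above, a model op of type "elu" carrying an `alpha` attribute is routed through ActivationOp::AttachImpl, which reads the attribute into `param_.Elu_alpha`. A hedged sketch of building such an op description, following the pattern the activation unit tests use; the variable names are placeholders and the namespace is an assumption:

```cpp
// Illustrative only: constructs a cpp::OpDesc for the new elu op.
void build_elu_op_desc(paddle::lite::cpp::OpDesc* op_desc) {
  op_desc->SetType("elu");
  op_desc->SetInput("X", {"x_var"});
  op_desc->SetOutput("Out", {"out_var"});
  op_desc->SetAttr("alpha", 1.0f);  // consumed by AttachImpl via GetAttr<float>("alpha")
}
```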
......@@ -83,6 +83,9 @@ class ActivationOp : public OpLite {
case lite_api::ActivationType::kThresholdedRelu:
ch->macs = param_.X->numel();
break;
case lite_api::ActivationType::kElu:
ch->macs = param_.X->numel();
break;
default:
LOG(FATAL) << "This Type of Activation:"
<< static_cast<int>(param_.active_type)
......
......@@ -359,6 +359,8 @@ struct ActivationParam : ParamBase {
float hard_swish_offset{3.0};
// thresholded_relu
float relu_threshold{1.0f};
// elu
float Elu_alpha{1.0f};
};
struct ActivationGradParam : ParamBase {
......
......@@ -39,7 +39,8 @@ enum activation_type_test {
SQUARE,
HARD_SWISH,
RECIPROCAL,
THRESHOLDED_RELU
THRESHOLDED_RELU,
ELU
};
class ActivationComputeTester : public arena::TestCase {
......@@ -56,6 +57,7 @@ class ActivationComputeTester : public arena::TestCase {
float hard_swish_scale = 6.0;
float hard_swish_offset = 3.0;
float relu_threshold_ = 1.0;
float elu_alpha_ = 1.0;
DDim dims_{{1}};
std::string type_ = "";
activation_type_test act_type_ = RELU;
......@@ -67,6 +69,7 @@ class ActivationComputeTester : public arena::TestCase {
float relu_clipped_coef,
std::string prelu_mode,
float swish_beta,
float elu_alpha,
DDim dims,
std::string type,
activation_type_test act_type)
......@@ -75,6 +78,7 @@ class ActivationComputeTester : public arena::TestCase {
relu_clipped_coef_(relu_clipped_coef),
prelu_mode_(prelu_mode),
swish_beta_(swish_beta),
elu_alpha_(elu_alpha),
dims_(dims),
type_(type),
act_type_(act_type) {}
......@@ -87,6 +91,7 @@ class ActivationComputeTester : public arena::TestCase {
auto* x = scope->FindTensor(input_);
const auto* x_data = x->data<float>();
LOG(INFO) << act_type_;
switch (act_type_) {
case RELU: {
for (int i = 0; i < dims_.production(); i++) {
......@@ -226,8 +231,17 @@ class ActivationComputeTester : public arena::TestCase {
}
break;
}
case ELU: {
for (int i = 0; i < dims_.production(); i++) {
float tmp = std::exp(x_data[i]) - 1;
float max = x_data[i] > 0.f ? x_data[i] : 0.f;
float min = x_data[i] < 0.f ? elu_alpha_ * tmp : 0.f;
output_data[i] = min + max;
}
break;
}
default:
LOG(INFO) << "the type of activation is unknow.";
LOG(INFO) << "the type of activation " << act_type_ << " is unknow.";
}
}
......@@ -256,6 +270,9 @@ class ActivationComputeTester : public arena::TestCase {
if (act_type_ == THRESHOLDED_RELU) {
op_desc->SetAttr("threshold", relu_threshold_);
}
if (act_type_ == ELU) {
op_desc->SetAttr("alpha", elu_alpha_);
}
}
void PrepareData() override {
......@@ -312,7 +329,7 @@ TEST(Activation_relu, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "relu", RELU));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "relu", RELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -344,6 +361,7 @@ TEST(Activation_leaky_relu, precision) {
6.,
"all",
0.,
1.0,
DDim(dims),
"leaky_relu",
LEAKY_RELU));
......@@ -376,6 +394,7 @@ TEST(Activation_relu_clipped, precision) {
coef,
"all",
0.,
1.0,
DDim(dims),
"relu_clipped",
RELU_CLIPPED));
......@@ -393,7 +412,7 @@ TEST(Activation_prelu, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{{1, 3, 2, 4}}) {
for (auto mode : {"all", "channel", "element"}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6, mode, 0., DDim(dims), "prelu", PRELU));
place, "def", 0.01, 6, mode, 0., 1.0, DDim(dims), "prelu", PRELU));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
......@@ -419,8 +438,17 @@ TEST(Activation_sigmoid, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "sigmoid", SIGMOID));
std::unique_ptr<arena::TestCase> tester(
new ActivationComputeTester(place,
"def",
0.01,
6.,
"all",
0.,
1.0,
DDim(dims),
"sigmoid",
SIGMOID));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -447,7 +475,7 @@ TEST(Activation_tanh, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "tanh", TANH));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "tanh", TANH));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -462,7 +490,7 @@ TEST(Activation_swish, precision) {
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
for (auto coef : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6, "all", coef, DDim(dims), "swish", SWISH));
place, "def", 0.01, 6, "all", coef, 1.0, DDim(dims), "swish", SWISH));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
......@@ -489,7 +517,7 @@ TEST(Activation_relu6, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "relu6", RELU6));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -511,7 +539,7 @@ TEST(Activation_log, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "log", LOG));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "log", LOG));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -525,7 +553,7 @@ TEST(Activation_exp, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "exp", EXP));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "exp", EXP));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
......@@ -539,7 +567,7 @@ TEST(Activation_floor, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "floor", FLOOR));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "floor", FLOOR));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
......@@ -554,7 +582,7 @@ TEST(Activation_rsqrt, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "rsqrt", RSQRT));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "rsqrt", RSQRT));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
......@@ -577,7 +605,7 @@ TEST(Activation_square, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "square", SQUARE));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "square", SQUARE));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -596,7 +624,7 @@ TEST(Activation_gelu, precision) {
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "gelu", GELU));
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "gelu", GELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
......@@ -622,6 +650,7 @@ TEST(activation_hard_swish, precision) {
6.,
"all",
0.,
1.0,
DDim(dims),
"hard_swish",
HARD_SWISH));
......@@ -650,6 +679,7 @@ TEST(activation_reciprocal, precision) {
6.,
"all",
0.,
1.0,
DDim(dims),
"reciprocal",
RECIPROCAL));
......@@ -680,6 +710,7 @@ TEST(Activation_thresholded_relu, precision) {
6.,
"all",
0.,
1.0,
DDim(dims),
"thresholded_relu",
THRESHOLDED_RELU));
......@@ -688,5 +719,20 @@ TEST(Activation_thresholded_relu, precision) {
}
}
TEST(Activation_elu, precision) {
LOG(INFO) << "test elu op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "elu", ELU));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
#endif
}
} // namespace lite
} // namespace paddle
......@@ -268,6 +268,8 @@ function make_all_tests {
cmake $root_dir \
${CMAKE_COMMON_OPTIONS} \
-DWITH_TESTING=ON \
-DLITE_WITH_PROFILE=OFF \
-DLITE_WITH_PRECISION_PROFILE=OFF \
-DLITE_BUILD_EXTRA=$BUILD_EXTRA \
-DLITE_WITH_CV=$BUILD_CV \
-DLITE_WITH_NPU=$BUILD_NPU \
......