From 3797f8fafbb5e9bb92522c0e2462540c66d553ea Mon Sep 17 00:00:00 2001
From: hong19860320 <9973393+hong19860320@users.noreply.github.com>
Date: Sat, 15 Jun 2019 05:17:49 +0000
Subject: [PATCH] enable relu fuse in arm conv kernel test=develop

---
 paddle/fluid/lite/kernels/arm/conv_compute_test.cc | 7 ++++---
 paddle/fluid/lite/operators/conv_op.h              | 1 +
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
index e4d80265d77..f25a5cf0745 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
@@ -45,7 +45,7 @@ void conv_compute_ref(const operators::ConvParam& param) {
     bias_data = param.bias->mutable_data<float>();
   }
   bool flag_bias = bias_data != nullptr;
-  bool flag_relu = false;  // TODO(hong19860320) param.relu
+  bool flag_relu = param.fuse_relu;
 
   int num = input_dims[0];
   int chout = output_dims[1];
@@ -183,7 +183,8 @@ TEST(conv_arm, compute) {
                   auto* filter_data = filter.mutable_data<float>();
                   auto* output_data = output.mutable_data<float>();
                   for (int i = 0; i < input.dims().production(); i++) {
-                    input_data[i] = static_cast<float>(i % 128);
+                    float sign = i % 3 == 0 ? -1.0f : 1.0f;
+                    input_data[i] = sign * static_cast<float>(i % 128);
                   }
                   for (int i = 0; i < filter.dims().production(); i++) {
                     filter_data[i] =
@@ -208,7 +209,7 @@ TEST(conv_arm, compute) {
                     }
                     param.bias = &bias;
                   }
-                  // TODO(hong19860320) param.relu = flag_relu;
+                  param.fuse_relu = flag_relu;
                   param.paddings = std::vector<int>({padding, padding});
                   param.strides = std::vector<int>({stride, stride});
                   param.dilations =
diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h
index ceb80456312..3d09d42241c 100644
--- a/paddle/fluid/lite/operators/conv_op.h
+++ b/paddle/fluid/lite/operators/conv_op.h
@@ -73,6 +73,7 @@ class ConvOpLite : public OpLite {
         }
       }
     }
+    param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
     return true;
   }
 
-- 
GitLab