diff --git a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
index e4d80265d7728fa0eeea97fd070a982a8888ec7e..f25a5cf07452f128681bb4367b7dfc8f7fb09c0d 100644
--- a/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
+++ b/paddle/fluid/lite/kernels/arm/conv_compute_test.cc
@@ -45,7 +45,7 @@ void conv_compute_ref(const operators::ConvParam& param) {
     bias_data = param.bias->mutable_data<float>();
   }
   bool flag_bias = bias_data != nullptr;
-  bool flag_relu = false;  // TODO(hong19860320) param.relu
+  bool flag_relu = param.fuse_relu;
 
   int num = input_dims[0];
   int chout = output_dims[1];
@@ -183,7 +183,8 @@ TEST(conv_arm, compute) {
               auto* filter_data = filter.mutable_data<float>();
               auto* output_data = output.mutable_data<float>();
               for (int i = 0; i < input.dims().production(); i++) {
-                input_data[i] = static_cast<float>(i % 128);
+                float sign = i % 3 == 0 ? -1.0f : 1.0f;
+                input_data[i] = sign * static_cast<float>(i % 128);
               }
               for (int i = 0; i < filter.dims().production(); i++) {
                 filter_data[i] =
@@ -208,7 +209,7 @@ TEST(conv_arm, compute) {
                 }
                 param.bias = &bias;
               }
-              // TODO(hong19860320) param.relu = flag_relu;
+              param.fuse_relu = flag_relu;
              param.paddings = std::vector<int>({padding, padding});
               param.strides = std::vector<int>({stride, stride});
               param.dilations =
diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h
index ceb80456312f8bd208807f1c1ef74c055a709181..3d09d42241c7cbfcc6dd6893d50196550469d28c 100644
--- a/paddle/fluid/lite/operators/conv_op.h
+++ b/paddle/fluid/lite/operators/conv_op.h
@@ -73,6 +73,7 @@ class ConvOpLite : public OpLite {
         }
       }
     }
+    param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
     return true;
   }
 
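
For context (not part of the patch): a minimal, self-contained sketch of the epilogue that param.fuse_relu toggles in the reference path. After the convolution accumulates each output element, the bias is added and, when flag_relu is set, negative results are clamped to zero. This also shows why the test now feeds signed inputs: with all-positive data the ReLU branch would never change a result. Function and variable names below are illustrative assumptions, not the actual conv_compute_ref internals.

#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical epilogue: bias add plus optional fused ReLU, applied
// per output channel over the spatial elements of that channel.
void apply_bias_relu(std::vector<float>* out, const std::vector<float>& bias,
                     int channels, int spatial, bool flag_bias,
                     bool flag_relu) {
  for (int c = 0; c < channels; ++c) {
    for (int s = 0; s < spatial; ++s) {
      float v = (*out)[c * spatial + s];
      if (flag_bias) v += bias[c];
      if (flag_relu) v = std::max(v, 0.f);  // clamp negatives to zero
      (*out)[c * spatial + s] = v;
    }
  }
}

int main() {
  // Two channels, two spatial elements each; mixed signs exercise the ReLU.
  std::vector<float> out = {-2.f, 1.f, -0.5f, 3.f};
  std::vector<float> bias = {0.5f, -1.f};
  apply_bias_relu(&out, bias, /*channels=*/2, /*spatial=*/2,
                  /*flag_bias=*/true, /*flag_relu=*/true);
  for (float v : out) std::printf("%g ", v);  // prints: 0 1.5 0 2
  std::printf("\n");
  return 0;
}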