提交 3797f8fa 编写于 作者: H hong19860320

enable relu fuse in arm conv kernel

test=develop
上级 b8e6fefc
......@@ -45,7 +45,7 @@ void conv_compute_ref(const operators::ConvParam& param) {
bias_data = param.bias->mutable_data<float>();
}
bool flag_bias = bias_data != nullptr;
bool flag_relu = false; // TODO(hong19860320) param.relu
bool flag_relu = param.fuse_relu;
int num = input_dims[0];
int chout = output_dims[1];
......@@ -183,7 +183,8 @@ TEST(conv_arm, compute) {
auto* filter_data = filter.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
for (int i = 0; i < input.dims().production(); i++) {
input_data[i] = static_cast<float>(i % 128);
float sign = i % 3 == 0 ? -1.0f : 1.0f;
input_data[i] = sign * static_cast<float>(i % 128);
}
for (int i = 0; i < filter.dims().production(); i++) {
filter_data[i] =
......@@ -208,7 +209,7 @@ TEST(conv_arm, compute) {
}
param.bias = &bias;
}
// TODO(hong19860320) param.relu = flag_relu;
param.fuse_relu = flag_relu;
param.paddings = std::vector<int>({padding, padding});
param.strides = std::vector<int>({stride, stride});
param.dilations =
......
......@@ -73,6 +73,7 @@ class ConvOpLite : public OpLite {
}
}
}
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册