提交 b8e54ab5 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: parameter name changed to ResidualData

上级 27573ece
......@@ -184,7 +184,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("EltwiseParameter", {elementwise_add_x->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
op_desc.SetAttr("use_mkldnn", true);
......
......@@ -390,14 +390,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
T* output_data = nullptr;
if (fuse_eltwise) {
auto eltwise_param = ctx.Input<Tensor>("EltwiseParameter");
auto eltwise_param_data = eltwise_param->data<T>();
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE(eltwise_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), eltwise_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
PADDLE_ENFORCE(residual_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
output_data = output->mutable_data<T>(ctx.GetPlace());
output->ShareDataWith(*eltwise_param);
output->ShareDataWith(*residual_param);
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
......
......@@ -132,8 +132,9 @@ void Conv2DOpMaker::Make() {
"(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.")
.Reuse("Input");
AddInput("EltwiseParameter",
"(Tensor) Tensor to which convolution output will be added."
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used on with fuse_eltwise fusion.")
.AsDispensable();
AddAttr<std::vector<int>>("strides",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册