提交 41f3d78f 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: output and elemwise param share data in...

MKLDNN conv + elementwise_add fusion: the output and the elementwise parameter share data in the conv primitive, and the output is properly allocated.
上级 07a62ddc
......@@ -118,6 +118,7 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
if (same != std::end(node.inputs)) {
LinkNodes(to, &node);
node.Op()->SetInput("X", {to->Name()});
}
}
}
......@@ -145,7 +146,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
op_desc.SetInput("EltwiseParameter", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
op_desc.SetAttr("use_mkldnn", true);
......@@ -155,6 +156,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
patterns::LinkNodes(conv_input, fused_conv_op);
patterns::LinkNodes(conv_filter, fused_conv_op);
patterns::LinkNodes(elementwise_add_x, fused_conv_op);
patterns::LinkNodes(fused_conv_op, conv_output);
};
......
......@@ -396,7 +396,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(eltwise_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
PADDLE_ENFORCE_EQ(output->dims(), eltwise_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
output_data = const_cast<T*>(eltwise_param_data);
output_data = output->mutable_data<T>(ctx.GetPlace());
output->ShareDataWith(*eltwise_param);
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
......
......@@ -134,7 +134,8 @@ void Conv2DOpMaker::Make() {
.Reuse("Input");
AddInput("EltwiseParameter",
"(Tensor) Tensor to which convolution output will be added."
"Used on with fuse_eltwise fusion.");
"Used on with fuse_eltwise fusion.")
.AsDispensable();
AddAttr<std::vector<int>>("strides",
"(vector<int> default:{1, 1}), the "
"strides(h_stride, w_stride) of "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册