Commit 41f3d78f authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: output and elementwise parameter share data in the conv primitive. Output is properly allocated.

Parent 07a62ddc
@@ -118,6 +118,7 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
     if (same != std::end(node.inputs)) {
       LinkNodes(to, &node);
+      node.Op()->SetInput("X", {to->Name()});
     }
   }
 }
@@ -145,7 +146,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     op_desc.SetInput("Input", {conv_input->Name()});
     op_desc.SetInput("Filter", {conv_filter->Name()});
-    op_desc.SetInput("ElementwiseParameter", {elementwise_add_x->Name()});
+    op_desc.SetInput("EltwiseParameter", {elementwise_add_x->Name()});
     op_desc.SetOutput("Output", {conv_output->Name()});
     op_desc.SetAttr("use_mkldnn", true);
@@ -155,6 +156,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     patterns::LinkNodes(conv_input, fused_conv_op);
     patterns::LinkNodes(conv_filter, fused_conv_op);
+    patterns::LinkNodes(elementwise_add_x, fused_conv_op);
     patterns::LinkNodes(fused_conv_op, conv_output);
   };
...
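The CorrectGraphEdges hunk above now fixes two separate pieces of bookkeeping when consumers of the old elementwise_add output are redirected to the fused conv output: the graph-level edge (LinkNodes) and the consumer's op description, whose "X" parameter must be repointed to the new variable name. The following is a minimal standalone sketch of that two-level update in plain C++; ToyNode, ToyOpDesc and Redirect are illustrative stand-ins, not PaddlePaddle's ir::Node or OpDesc.

#include <map>
#include <string>
#include <vector>

// Toy stand-ins: a node carries graph-level edges, while the op it wraps
// names its inputs by parameter ("X") mapped to variable names.
struct ToyOpDesc {
  std::map<std::string, std::vector<std::string>> inputs;
  void SetInput(const std::string& param, const std::vector<std::string>& args) {
    inputs[param] = args;
  }
};

struct ToyNode {
  std::string name;
  std::vector<ToyNode*> inputs;  // graph edges (which nodes feed this one)
  ToyOpDesc op;                  // parameter name -> variable names
};

// Redirect 'node' to read from 'to': patch the edge list *and* the op's "X"
// argument, mirroring LinkNodes(to, &node) plus node.Op()->SetInput("X", ...).
void Redirect(ToyNode* to, ToyNode* node) {
  node->inputs.push_back(to);
  node->op.SetInput("X", {to->name});
}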
@@ -396,7 +396,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       PADDLE_ENFORCE(eltwise_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
       PADDLE_ENFORCE_EQ(output->dims(), eltwise_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
-      output_data = const_cast<T*>(eltwise_param_data);
+      output_data = output->mutable_data<T>(ctx.GetPlace());
+      output->ShareDataWith(*eltwise_param);
     } else {
       output_data =
           output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
...
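The kernel change above drops the const_cast on the elementwise operand's data: the output tensor is allocated through mutable_data and made to share storage with the operand via ShareDataWith, so the convolution primitive accumulates onto the residual data in place instead of it being copied around. Below is a minimal, self-contained sketch of that aliasing idea in plain C++. ToyTensor is an illustrative stand-in for the framework Tensor, the final loop stands in for the fused convolution's accumulation, and the sketch aliases the buffers before taking the raw pointer purely so the pointer refers to the shared storage; it models the idea, not the kernel line for line.

#include <cstdio>
#include <memory>
#include <vector>

// Toy stand-in for a tensor whose storage can be shared with another tensor.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> buf;
  void ShareDataWith(const ToyTensor& other) { buf = other.buf; }  // alias storage
  float* mutable_data(std::size_t n) {  // allocate on first use
    if (!buf) buf = std::make_shared<std::vector<float>>(n, 0.f);
    return buf->data();
  }
};

int main() {
  ToyTensor eltwise_param;  // residual operand of elementwise_add
  float* residual = eltwise_param.mutable_data(4);
  for (int i = 0; i < 4; ++i) residual[i] = static_cast<float>(i + 1);  // 1 2 3 4

  ToyTensor output;
  output.ShareDataWith(eltwise_param);          // output now aliases the residual
  float* output_data = output.mutable_data(4);  // non-const pointer, no const_cast

  // Stand-in for the fused conv primitive: accumulate convolution results onto
  // whatever already sits in the (shared) output buffer.
  const float conv_result[4] = {10.f, 20.f, 30.f, 40.f};
  for (int i = 0; i < 4; ++i) output_data[i] += conv_result[i];

  for (float v : *output.buf) std::printf("%g ", v);  // prints: 11 22 33 44
  std::printf("\n");
  return 0;
}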
@@ -134,7 +134,8 @@ void Conv2DOpMaker::Make() {
       .Reuse("Input");
   AddInput("EltwiseParameter",
            "(Tensor) Tensor to which convolution output will be added."
-           "Used on with fuse_eltwise fusion.");
+           "Used on with fuse_eltwise fusion.")
+      .AsDispensable();
   AddAttr<std::vector<int>>("strides",
                             "(vector<int> default:{1, 1}), the "
                             "strides(h_stride, w_stride) of "
...
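Marking EltwiseParameter with .AsDispensable() makes the input optional, so plain conv2d instances that were not rewritten by the fusion pass do not have to supply it, and the kernel should only take the fused branch when the operand is actually there. The fragment below is a hedged sketch of such a guard, not part of this commit: the fuse_eltwise attribute name is taken from the input description above, and the Attr/HasInput/Input calls are assumed from the fluid ExecutionContext API rather than shown in this diff.

// Sketch only: gate the fused path on both the attribute and the optional input.
// "fuse_eltwise", HasInput, Attr and Input are assumptions; the hunks above only
// show the bodies of the two branches.
bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
if (fuse_eltwise && ctx.HasInput("EltwiseParameter")) {
  auto* eltwise_param = ctx.Input<Tensor>("EltwiseParameter");
  const T* eltwise_param_data = eltwise_param->data<T>();
  // ... fused path: validate dims, allocate Output, share data with eltwise_param ...
} else {
  // ... regular path: allocate Output from the destination memory size ...
}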