diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index eae65968285703f5882d910e29bc5d8e1511cba6..d9666c1cedc56242d795a764c25483191e090a5b 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -386,8 +386,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto user_weights_memory_p = handler.AcquireWeightsMemory(
         user_weights_md, to_void_cast<T>(filter_data));
 
-    T* output_data =
+    T* output_data = nullptr;
+
+    if (fuse_eltwise) {
+      // With conv+elementwise_add fusion the convolution result is added to
+      // the tensor passed as "EltwiseParameter".
+      auto eltwise_param = ctx.Input<Tensor>("EltwiseParameter");
+      PADDLE_ENFORCE(eltwise_param != nullptr,
+                     "EltwiseParameter must be set when fuse_eltwise is on");
+      auto eltwise_param_data = eltwise_param->data<T>();
+
+      PADDLE_ENFORCE(eltwise_param_data != nullptr, "Provide data if you want MKLDNN conv+elementwise_add fusion");
+      PADDLE_ENFORCE_EQ(output->dims(), eltwise_param->dims(), "Output and elementwise parameter need to have the same dimension sizes");
+
+      // Share the parameter's buffer with the output tensor instead of
+      // const_cast-ing the input pointer, so that "Output" itself refers to
+      // the buffer the convolution writes to.
+      output->ShareDataWith(*eltwise_param);
+      output_data = output->mutable_data<T>(ctx.GetPlace());
+    } else {
+      output_data =
         output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
+    }
+
     // create reorder primitive if the input format is not the preferred one
     auto src_memory_p =
         handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 8f84bf71a7f77606bed6672f0830e3fc80165a42..efb8c62737593fae3811c451d3f3df594e31937f 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -132,6 +132,10 @@ void Conv2DOpMaker::Make() {
             "(Tensor) The output tensor of convolution operator. "
             "The format of output tensor is also NCHW.")
       .Reuse("Input");
+  AddInput("EltwiseParameter",
+           "(Tensor) Tensor to which convolution output will be added. "
+           "Used only with fuse_eltwise fusion.")
+      .AsDispensable();
   AddAttr<std::vector<int>>("strides",
                             "(vector<int> default:{1, 1}), the "
                             "strides(h_stride, w_stride) of "