diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index 758a8f4a41f686aa9ad5ee965d0586df5b89476f..2dd67c3108095e9c3b3152d67701c267330320e3 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -378,20 +378,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::shared_ptr conv_p; std::shared_ptr src_memory_p; + std::shared_ptr user_src_memory_p; std::shared_ptr dst_memory_p; std::vector pipeline; auto prim_key = key + "@conv_p"; auto dst_key = key + "@dst_mem_p"; auto src_key = key + "@src_mem_p"; + auto user_src_key = key + "@user_src_mem_p"; + auto src_reorder_key = key + "@src_mem_p" + "reorder_p"; conv_p = std::static_pointer_cast(dev_ctx.GetBlob(prim_key)); + auto src_memory_reorder_p = std::static_pointer_cast(dev_ctx.GetBlob(src_reorder_key)); src_memory_p = std::static_pointer_cast(dev_ctx.GetBlob(src_key)); - dst_memory_p = std::static_pointer_cast(dev_ctx.GetBlob(dst_key)); - - if (src_memory_p) { + if(src_memory_reorder_p){ + user_src_memory_p = std::static_pointer_cast(dev_ctx.GetBlob(user_src_key)); + user_src_memory_p->set_data_handle(to_void_cast(input_data)); + } else if(src_memory_p){ src_memory_p->set_data_handle(to_void_cast(input_data)); } + dst_memory_p = std::static_pointer_cast(dev_ctx.GetBlob(dst_key)); + std::shared_ptr conv_pd; conv_pd = std::static_pointer_cast(dev_ctx.GetBlob(key_conv_pd)); std::shared_ptr handler; @@ -414,7 +421,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { residual_data_tz, residual_data_type, residual_param->format()); auto user_residual_memory_p = handler->AcquireResidualDataMemory( user_residual_md, to_void_cast(residual_param_data)); - dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory( user_residual_memory_p, to_void_cast(output_data), pipeline); } else { @@ -462,11 +468,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), memory::format::x); conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn); + fuse_relu, fuse_residual_conn, is_test); } else { conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn); + mkldnn_engine, fuse_relu, fuse_residual_conn, is_test); } // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); @@ -474,7 +480,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key)); // create mkldnn memory from input tensors (data/weights) - auto user_src_memory_p = + user_src_memory_p = handler->AcquireSrcMemory(user_src_md, to_void_cast(input_data)); auto user_weights_memory_p = handler->AcquireWeightsMemory( user_weights_md, to_void_cast(filter_data)); @@ -508,7 +514,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { residual_data_tz, residual_data_type, residual_param->format()); auto user_residual_memory_p = handler->AcquireResidualDataMemory( user_residual_md, to_void_cast(residual_param_data)); - dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory( user_residual_memory_p, to_void_cast(output_data), pipeline); } else { @@ -546,10 +551,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { output->set_layout(DataLayout::kMKLDNN); output->set_format(GetMKLDNNFormat(*dst_memory_p)); - } else { + } else { + if(src_memory_reorder_p){ + pipeline.push_back(*src_memory_reorder_p); + } pipeline.push_back(*conv_p); stream(stream::kind::eager).submit(pipeline).wait(); - + output->set_layout(DataLayout::kMKLDNN); output->set_format(GetMKLDNNFormat(*dst_memory_p)); }