From 1658958fe697f1b7a2c558e8bda06285826b058a Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Mon, 10 Sep 2018 14:57:10 +0200 Subject: [PATCH] Reusing converted weights --- paddle/fluid/operators/conv_mkldnn_op.cc | 9 ++++++--- paddle/fluid/operators/conv_op.cc | 1 + paddle/fluid/platform/mkldnn_helper.h | 6 +++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc index c5cbadc8929..1ccf2494f27 100644 --- a/paddle/fluid/operators/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/conv_mkldnn_op.cc @@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { std::shared_ptr AcquireWeightsMemoryFromPrimitive( const std::shared_ptr user_weights_memory_p, - std::vector& pipeline) { // NOLINT + const std::vector& pipeline, + bool is_test = false) { // NOLINT auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc(); return this->AcquireMemory(weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p", - pipeline); + pipeline, is_test); } std::shared_ptr AcquireBiasMemoryFromPrimitive( @@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); + const bool is_test = ctx.Attr("is_test"); + auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); @@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto src_memory_p = handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( - user_weights_memory_p, pipeline); + user_weights_memory_p, pipeline, is_test); auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data)); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 61ca80877a6..6070173ee2e 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } void Conv2DOpMaker::Make() { + AddAttr("is_test", "").SetDefault(false); AddInput( "Input", "(Tensor) The input tensor of convolution operator. " diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index f6e9a52b275..c64e5dafda6 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -191,8 +191,8 @@ class MKLDNNHandler { mkldnn::memory::primitive_desc& mpd, // NOLINT mkldnn::memory::primitive_desc& user_mpd, // NOLINT const std::shared_ptr user_memory_p, - const std::string& suffix, - std::vector& pipeline) { // NOLINT + const std::string& suffix, const std::vector& pipeline, + bool is_test = false) { // NOLINT // create reorder primitive if the input format is not the preferred one auto local_key = key_ + suffix; auto key_reorder_p = key_ + suffix + "reorder_p"; @@ -213,7 +213,7 @@ class MKLDNNHandler { pipeline.push_back(*reorder_p); } dev_ctx_.SetBlob(local_key, target_memory_p); - } else { + } else if (!is_test) { // Make reorder if needed auto reorder_p = std::static_pointer_cast( dev_ctx_.GetBlob(key_reorder_p)); -- GitLab