From 87a44b114965518417dff7734fd5c7d3f526354a Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Mon, 15 Apr 2019 18:29:18 +0200 Subject: [PATCH] [MKL-DNN] Added reusing of primitive descriptors (fp32) (#16667) * - Reuse of conv PD - conv transpose pd reused - Added PD reusing of softmax and Batch Norm - Refactoring and removal of not needed routines of mkl-dnn ops test=develop - Fix to reusing conv test=develop - Lint fixes test=develop - Further lint fixes test=develop - Lint fixes test=develop - lint fixes test=develop - Lint workaround test=develop * - Fix after review on including boost as third party header test=develop * - Fix after review. Name change to something more descriptive test=develop --- .../operators/mkldnn/batch_norm_mkldnn_op.cc | 40 ++++++--- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 89 ++----------------- .../mkldnn/conv_transpose_mkldnn_op.cc | 82 ++--------------- .../operators/mkldnn/softmax_mkldnn_op.cc | 38 +++++--- paddle/fluid/platform/mkldnn_reuse.h | 87 ++++++++++++++++++ 5 files changed, 153 insertions(+), 183 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc index bddca232e..911c4d22e 100644 --- a/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc @@ -39,13 +39,9 @@ struct bn_type_traits { class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { public: - BatchNormMKLDNNHandler( - std::shared_ptr batch_norm_pd, - const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine, - const std::string &base_key) - : platform::MKLDNNHandler(dev_ctx, engine, base_key) { - batch_norm_pd_ = batch_norm_pd; - } + BatchNormMKLDNNHandler(const platform::MKLDNNDeviceContext &dev_ctx, + mkldnn::engine engine, const std::string &base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) {} std::shared_ptr AcquireScaleshiftMemoryFromPrimitive(void *ptr) { return this->AcquireMemoryFromPrimitive( @@ -62,6 +58,26 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler { batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p"); } + std::shared_ptr + AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc, + const mkldnn::engine &engine) { + const std::string key_batch_norm_fwd_pd = key_ + "@bn_fwd_pd"; + auto batch_norm_pd = + std::static_pointer_cast( + dev_ctx_.GetBlob(key_batch_norm_fwd_pd)); + + if (batch_norm_pd == nullptr) { + batch_norm_pd_.reset( + new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine)); + dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_); + } else { + batch_norm_pd_ = batch_norm_pd; + is_reusing_ = true; + } + + return batch_norm_pd_; + } + std::shared_ptr AcquireTestTrainingBatchNormFwd( std::shared_ptr src_memory, std::shared_ptr scaleshift_memory, @@ -213,7 +229,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { const std::string key = BatchNormMKLDNNHandler::GetHash( src_tz, epsilon, flags, global_stats, input_format, ctx.op().Output("SavedMean")); - const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd"; + BatchNormMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); auto user_src_md = platform::MKLDNNMemDesc( {src_tz}, platform::MKLDNNGetDataType(), input_format); @@ -222,13 +238,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel { using bn_fwd_types = bn_type_traits; auto batch_norm_fwd_desc = bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags}; - auto batch_norm_fwd_pd = std::make_shared( - batch_norm_fwd_desc, mkldnn_engine); - // Save conv_pd/src_memory/weights_memory for backward pass - dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd); - BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine, - key); + auto batch_norm_fwd_pd = handler.AcquireBatchNormPrimitiveDescriptor( + batch_norm_fwd_desc, mkldnn_engine); auto src_memory = handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data)); diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 5e4d79f1c..faf518005 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -144,7 +144,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::string key = platform::ConvMKLDNNHandler::GetHash( src_tz, weights_tz, strides, paddings, dilations, groups, ctx.op().Input("Input") + ctx.op().Input("Filter")); - const std::string key_conv_pd = key + "@conv_pd"; std::vector pipeline; @@ -183,6 +182,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); + // create a conv primitive descriptor and save it for usage in backward std::shared_ptr conv_pd; auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference @@ -191,18 +192,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_md = platform::MKLDNNMemDesc( bias_tz, platform::MKLDNNGetDataType(), memory::format::x); - conv_pd = ConvFwdPrimitiveDesc( + conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind); } else { - conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, - paddings, mkldnn_engine, fuse_relu, - fuse_residual_conn, fwd_prop_kind); + conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( + src_md, weights_md, boost::none, dst_md, strides, paddings, + mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind); } - // Save conv_pd/src_memory/weights_memory for backward pass - if (!is_test) dev_ctx.SetBlob(key_conv_pd, conv_pd); - - platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key); // create mkldnn memory from input tensors (data/weights) auto user_src_memory_p = @@ -633,31 +630,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { } private: - mkldnn::primitive_attr CreatePostOps(bool fuse_relu, - bool fuse_residual_conn) const { - mkldnn::primitive_attr conv_attr; - mkldnn::post_ops post_operations; - // Fusion with Elementwise layer relies on adding a sum post-operation with - // the scale parameter. It is assumed that when fuse_residual_connection is - // true, the output tensor contains the data coming from residual - // connection. The result of this post_op is: - // Output = scale * Output + Conv_Out. - if (fuse_residual_conn) { - post_operations.append_sum(1.0f); - } - // Fusion with ReLU layer is executed through the PostOps feature. Create a - // PostOps object and configure it to execute an eltwise relu operation. - if (fuse_relu) { - constexpr float scale = 1.0f; - constexpr float negative_slope = 0.0f; - constexpr float placeholder = 0.0f; - post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, - negative_slope, placeholder); - } - conv_attr.set_post_ops(post_operations); - return conv_attr; - } - mkldnn::primitive_attr CreatePostOps( bool fuse_relu, bool fuse_residual_conn, const std::vector output_shift_scale, float sum_scale) const { @@ -679,30 +651,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { return conv_attr; } - std::unique_ptr - ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, - const memory::desc& dst, const std::vector& strides, - const std::vector& paddings, - const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn, - mkldnn::prop_kind fwd_prop_kind) const { - memory::dims stride_dims = strides; - memory::dims padding_dims = paddings; - - auto conv_desc = mkldnn::convolution_forward::desc( - fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst, - stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - - mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_relu, fuse_residual_conn); - - auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( - conv_desc, conv_attr, engine); - - return std::unique_ptr( - p_conv_pd); - } - std::unique_ptr ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, const memory::desc& dst, const std::vector& strides, @@ -731,31 +679,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { p_conv_pd); } - std::unique_ptr - ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, - const memory::desc& bias, const memory::desc& dst, - const std::vector& strides, - const std::vector& paddings, - const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn, - mkldnn::prop_kind fwd_prop_kind) const { - memory::dims stride_dims = strides; - memory::dims padding_dims = paddings; - - auto conv_desc = mkldnn::convolution_forward::desc( - fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst, - stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - - mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_relu, fuse_residual_conn); - - auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( - conv_desc, conv_attr, engine); - - return std::unique_ptr( - p_conv_pd); - } - std::unique_ptr ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights, const memory::desc& bias, const memory::desc& dst, diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index 317d4cebe..30d2469ee 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -12,6 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "boost/optional.hpp" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/malloc.h" @@ -124,7 +125,6 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { const std::string key = platform::ConvTransposeMKLDNNHandler::GetHash( src_tz, weights_tz, strides, paddings, dilations, groups, ctx.op().Output("Output")); - const std::string key_conv_transpose_pd = key + "@conv_transpose_pd"; std::vector pipeline; @@ -153,6 +153,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_md = platform::MKLDNNMemDesc( dst_tz, platform::MKLDNNGetDataType(), chosen_memory_format); + platform::ConvTransposeMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); // create a deconv(conv transpose) primitive descriptor and save it for // usage in backward std::shared_ptr @@ -163,19 +164,14 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_md = platform::MKLDNNMemDesc( bias_tz, platform::MKLDNNGetDataType(), mkldnn::memory::format::x); - conv_transpose_pd = ConvTransposeFwdPrimitiveDesc( + conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fwd_prop_kind); + fuse_relu, false, fwd_prop_kind); } else { - conv_transpose_pd = ConvTransposeFwdPrimitiveDesc( - src_md, weights_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fwd_prop_kind); + conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( + src_md, weights_md, boost::none, dst_md, strides, paddings, + mkldnn_engine, fuse_relu, false, fwd_prop_kind); } - // Save conv_pd/src_memory/weights_memory for backward pass - if (!is_test) dev_ctx.SetBlob(key_conv_transpose_pd, conv_transpose_pd); - - platform::ConvTransposeMKLDNNHandler handler(conv_transpose_pd, dev_ctx, - mkldnn_engine, key); // create mkldnn memory from input tensors (data/weights) auto user_src_memory_p = handler.AcquireSrcMemory( @@ -224,70 +220,6 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { output->set_layout(DataLayout::kMKLDNN); output->set_format(platform::GetMKLDNNFormat(*dst_memory_p)); } - - private: - mkldnn::primitive_attr CreatePostOps(bool fuse_relu) const { - mkldnn::primitive_attr conv_attr; - mkldnn::post_ops post_operations; - // Fusion with ReLU layer is executed through the PostOps feature. Create a - // PostOps object and configure it to execute an eltwise relu operation. - if (fuse_relu) { - constexpr float scale = 1.0f; - constexpr float negative_slope = 0.0f; - constexpr float placeholder = 0.0f; - post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, - negative_slope, placeholder); - } - conv_attr.set_post_ops(post_operations); - return conv_attr; - } - - std::unique_ptr - ConvTransposeFwdPrimitiveDesc( - const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights, - const mkldnn::memory::desc& dst, const std::vector& strides, - const std::vector& paddings, const mkldnn::engine& engine, - const bool fuse_relu, mkldnn::prop_kind fwd_prop_kind) const { - mkldnn::memory::dims stride_dims = {strides[0], strides[1]}; - mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]}; - - auto deconv_desc = mkldnn::deconvolution_forward::desc( - fwd_prop_kind, mkldnn::deconvolution_direct, src, weights, dst, - stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - - mkldnn::primitive_attr deconv_attr = CreatePostOps(fuse_relu); - - auto p_conv_transpose_pd = - new mkldnn::deconvolution_forward::primitive_desc(deconv_desc, - deconv_attr, engine); - - return std::unique_ptr( - p_conv_transpose_pd); - } - - std::unique_ptr - ConvTransposeFwdPrimitiveDesc( - const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights, - const mkldnn::memory::desc& bias, const mkldnn::memory::desc& dst, - const std::vector& strides, const std::vector& paddings, - const mkldnn::engine& engine, const bool fuse_relu, - mkldnn::prop_kind fwd_prop_kind) const { - mkldnn::memory::dims stride_dims = {strides[0], strides[1]}; - mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]}; - - auto deconv_desc = mkldnn::deconvolution_forward::desc( - fwd_prop_kind, mkldnn::deconvolution_direct, src, weights, bias, dst, - stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - - mkldnn::primitive_attr deconv_attr = CreatePostOps(fuse_relu); - - auto p_conv_transpose_pd = - new mkldnn::deconvolution_forward::primitive_desc(deconv_desc, - deconv_attr, engine); - - return std::unique_ptr( - p_conv_transpose_pd); - } }; } // namespace operators diff --git a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc index dc1176f08..1b3f33d34 100644 --- a/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc @@ -34,12 +34,9 @@ using platform::to_void_cast; class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { public: - SoftmaxMKLDNNHandler( - std::shared_ptr softmax_pd, - const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, - const std::string& base_key) - : platform::MKLDNNHandler(dev_ctx, engine, base_key), - softmax_pd_(softmax_pd) {} + SoftmaxMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx, + mkldnn::engine engine, const std::string& base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) {} SoftmaxMKLDNNHandler( std::shared_ptr softmax_pd, @@ -54,6 +51,26 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler { key_ += "-BWD"; } + std::shared_ptr + AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc, + const mkldnn::engine& engine) { + const std::string key_softmax_pd = key_ + "@softmax_pd"; + + auto softmax_pd = std::static_pointer_cast( + dev_ctx_.GetBlob(key_softmax_pd)); + + if (softmax_pd == nullptr) { + softmax_pd_.reset( + new softmax_forward::primitive_desc(softmax_desc, engine)); + dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_); + } else { + softmax_pd_ = softmax_pd; + is_reusing_ = true; + } + + return softmax_pd_; + } + std::shared_ptr AcquireSoftmax( std::shared_ptr dst_memory_p, std::shared_ptr src_memory_p) { @@ -138,19 +155,18 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { // Generate keys for storing/retriving primitives for this operator const std::string key = platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out")); - const std::string key_softmax_pd = key + "@softmax_pd"; + SoftmaxMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); // Currently only NC data format is supported auto softmax_md = MKLDNNMemDesc( {softmax_tz}, platform::MKLDNNGetDataType(), memory::format::nc); // Normalization is made after innermost dimension eg. C out of NC auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring, softmax_md, 1 /*dim: C*/); - auto softmax_pd = std::make_shared( - softmax_desc, mkldnn_engine); - dev_ctx.SetBlob(key_softmax_pd, softmax_pd); - SoftmaxMKLDNNHandler handler(softmax_pd, dev_ctx, mkldnn_engine, key); + auto softmax_pd = + handler.AcquireSoftmaxPrimitiveDescriptor(softmax_desc, mkldnn_engine); + auto softmax_src_memory_p = handler.AcquireSrcMemory(softmax_md, to_void_cast(input_data)); auto softmax_dst_memory_p = diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index ecaad4ec0..ba3a82b4b 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include #include +#include "boost/optional.hpp" #include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/mkldnn_helper.h" @@ -395,9 +396,28 @@ class TransposeMKLDNNHandler : public MKLDNNHandler { std::vector logical_axis_; }; +template +struct convolutional_algorithm; + +template <> +struct convolutional_algorithm { + static constexpr mkldnn::algorithm T = mkldnn::algorithm::convolution_direct; +}; + +template <> +struct convolutional_algorithm { + static constexpr mkldnn::algorithm T = + mkldnn::algorithm::deconvolution_direct; +}; + template class ConvMKLDNNTemplateHandler : public MKLDNNHandler { public: + ConvMKLDNNTemplateHandler(const platform::MKLDNNDeviceContext& dev_ctx, + mkldnn::engine engine, const std::string& base_key) + : platform::MKLDNNHandler(dev_ctx, engine, base_key) {} + + // TODO(jczaja): remove after conv int8 is adapted ConvMKLDNNTemplateHandler( std::shared_ptr conv_pd, const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, @@ -542,6 +562,73 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { scale_data, mask); } + mkldnn::primitive_attr CreatePostOps(bool fuse_relu, + bool fuse_residual_conn = false) const { + mkldnn::primitive_attr conv_attr; + mkldnn::post_ops post_operations; + // Fusion with Elementwise layer relies on adding a sum post-operation with + // the scale parameter. It is assumed that when fuse_residual_connection is + // true, the output tensor contains the data coming from residual + // connection. The result of this post_op is: + // Output = scale * Output + Conv_Out. + if (fuse_residual_conn) { + post_operations.append_sum(1.0f); + } + // Fusion with ReLU layer is executed through the PostOps feature. Create a + // PostOps object and configure it to execute an eltwise relu operation. + if (fuse_relu) { + constexpr float scale = 1.0f; + constexpr float negative_slope = 0.0f; + constexpr float placeholder = 0.0f; + post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, + negative_slope, placeholder); + } + conv_attr.set_post_ops(post_operations); + return conv_attr; + } + + std::shared_ptr + AcquireConvolutionPrimitiveDescriptor( + const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights, + boost::optional bias, + const mkldnn::memory::desc& dst, const std::vector& strides, + const std::vector& paddings, const mkldnn::engine& engine, + const bool fuse_relu, const bool fuse_residual_conn, + mkldnn::prop_kind fwd_prop_kind) { + const std::string key_conv_pd = key_ + "@conv_pd"; + + auto conv_pd = std::static_pointer_cast( + dev_ctx_.GetBlob(key_conv_pd)); + + if (conv_pd == nullptr) { + mkldnn::memory::dims stride_dims = strides; + mkldnn::memory::dims padding_dims = paddings; + + auto conv_desc = + bias ? typename forward_t::desc( + fwd_prop_kind, convolutional_algorithm::T, src, + weights, *bias, dst, stride_dims, padding_dims, + padding_dims, mkldnn::padding_kind::zero) + : typename forward_t::desc( + fwd_prop_kind, convolutional_algorithm::T, src, + weights, dst, stride_dims, padding_dims, padding_dims, + mkldnn::padding_kind::zero); + + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_relu, fuse_residual_conn); + + conv_pd_.reset( + new typename forward_t::primitive_desc(conv_desc, conv_attr, engine)); + // Save conv_pd/src_memory/weights_memory for backward pass + dev_ctx_.SetBlob(key_conv_pd, conv_pd_); + } else { + conv_pd_ = conv_pd; + is_reusing_ = true; + } + + return conv_pd_; + } + std::shared_ptr AcquireConvolution( std::shared_ptr src_memory_p, std::shared_ptr weights_memory_p, -- GitLab