From 7da5368dd52405867da5dac835cec2828b582524 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 23 Dec 2021 12:24:28 +0100 Subject: [PATCH] Make GetBlob assuming elements are cached (#38336) * First set of fixes * - Make more likely to GetBlob find a blobs * - Lint --- paddle/fluid/framework/data_transform.cc | 4 ++-- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 5 ++--- paddle/fluid/platform/device_context.cc | 18 +++++++++++++++--- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index 70693a5df26..16c1923ce18 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -43,8 +43,8 @@ void TransformData(const OpKernelType &expected_kernel_type, Tensor in; in.ShareDataWith(input_tensor); Tensor out; - DataLayout lin = kernel_type_for_var.data_layout_; - DataLayout lout = expected_kernel_type.data_layout_; + const DataLayout lin = kernel_type_for_var.data_layout_; + const DataLayout lout = expected_kernel_type.data_layout_; // do layout transform if (NeedTransformLayout(lout, lin)) { diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index d584da72393..6bfa8032fdc 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -806,11 +806,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { ctx.Attr>("Scale_weights"); const bool is_multi_channel = scale_weights_data.size() > 1; const int& groups = ctx.Attr("groups"); - const bool& is_test = ctx.Attr("is_test"); int mask_reorder = is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, groups, false, is_test, scale_weights_data, mask_reorder); + filter, groups, false, true, scale_weights_data, mask_reorder); std::shared_ptr dst_memory_p; if (fuse_residual_conn) { @@ -842,7 +841,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel { auto p_scales_tuple = handler.get_int8_bias_scales(ctx); auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( - bias, is_test, std::get<1>(*p_scales_tuple), + bias, true, std::get<1>(*p_scales_tuple), std::get<0>(*p_scales_tuple)); args.insert({DNNL_ARG_BIAS, *bias_memory_p}); } diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 07508da703d..23c4f216ba9 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -869,6 +869,15 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const { return num_entries; } +// TODO(jczaja): Replace with C++20 equivalents when applicable +#ifdef _WIN32 +#define likely(expr) (expr) +#define unlikely(expr) (expr) +#else +#define likely(expr) (__builtin_expect(!!(expr), 1)) +#define unlikely(expr) (__builtin_expect(!!(expr), 0)) +#endif + MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( const std::string& name) const { BlobMap* pMap = p_blobmap_.get(); @@ -881,7 +890,10 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( // Find ShapeBlob for current mkldnn session id firstly auto map_it = pMap->find(sid); - if (map_it == pMap->end()) { + // (jczaja): After first iteration of model's execution we + // should have all elements cached (mostly) so failures are unlikely (less + // likely for dynamic shapes) + if (unlikely(map_it == pMap->end())) { VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n"; return nullptr; } @@ -889,7 +901,7 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( // Find KeyBlob for current input shape secondly auto sBlob_it = sBlob->find(tls().cur_input_shape_str); - if (sBlob_it == sBlob->end()) { + if (unlikely(sBlob_it == sBlob->end())) { VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str << ", miss input_shape_str\n"; return nullptr; @@ -899,7 +911,7 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( // Find Blob via name auto key_it = pBlob->find(name); - if (key_it == pBlob->end()) { + if (unlikely(key_it == pBlob->end())) { VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; return nullptr; } -- GitLab