From 7dbc441eab258b23769b446e8c2f1179f682e8d2 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Fri, 17 Jul 2020 15:58:21 +0200 Subject: [PATCH] [oneDNN] cache cosmetics improvement (#25576) --- paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc | 2 +- paddle/fluid/platform/mkldnn_reuse.h | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index ac6ddebb81..17e1e19583 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -943,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { const std::string key = platform::CreateKey( src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); - const std::string key_conv_pd = key + "@forward_pd"; + const std::string key_conv_pd = key + "@fwd_pd"; std::vector pipeline; // Create user memory descriptors diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 6dc495fb00..5d7143f56b 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -54,7 +54,7 @@ class MKLDNNHandlerT { } std::shared_ptr AcquireForwardPrimitive() { - const std::string key_p = key_ + "@forward_p"; + const std::string key_p = key_ + "@fwd_p"; auto forward_p = std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); if (forward_p == nullptr) { @@ -65,7 +65,7 @@ class MKLDNNHandlerT { } std::shared_ptr AcquireBackwardPrimitive() { - const std::string key_p = key_ + "@backward_p"; + const std::string key_p = key_ + "@bwd_p"; auto backward_p = std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); if (backward_p == nullptr) { @@ -112,11 +112,11 @@ class MKLDNNHandlerT { protected: bool isCached() { - const std::string key_pd = key_common_ + "@forward_pd"; + const std::string key_pd = key_common_ + "@fwd_pd"; fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); - const std::string key_p = key_ + "@forward_p"; + const std::string key_p = key_ + "@fwd_p"; return (dev_ctx_.GetBlob(key_p) != nullptr); } @@ -129,7 +129,7 @@ class MKLDNNHandlerT { // Forward PD has to be passed to Grad op that // may be executed by diffrent thread, hence // for that one we use key that does not contain TID - const std::string key_pd = key_common_ + "@forward_pd"; + const std::string key_pd = key_common_ + "@fwd_pd"; fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); if (fwd_pd_ == nullptr) { @@ -169,13 +169,13 @@ class MKLDNNHandlerT { template void AcquireBackwardPrimitiveDescriptor(Args&&... args) { - const std::string key_fwd_pd = key_common_ + "@forward_pd"; + const std::string key_fwd_pd = key_common_ + "@fwd_pd"; fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_fwd_pd)); PADDLE_ENFORCE_NOT_NULL( fwd_pd_, platform::errors::Unavailable( "Get MKLDNN Forward primitive %s failed.", key_fwd_pd)); - const std::string key_pd = key_ + "@backward_pd"; + const std::string key_pd = key_ + "@bwd_pd"; bwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); if (bwd_pd_ == nullptr) { -- GitLab