diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index ac6ddebb813fab2bc5d1c1faaaa8d96bbc22dbd4..17e1e1958346155af32cf75b5e9fc25cdbdd91eb 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -943,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { const std::string key = platform::CreateKey( src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); - const std::string key_conv_pd = key + "@forward_pd"; + const std::string key_conv_pd = key + "@fwd_pd"; std::vector pipeline; // Create user memory descriptors diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 6dc495fb009a316d09bcd02b647cf5ca18ab0f47..5d7143f56b3f394bb1a99c1b3802b7c20138dfb7 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -54,7 +54,7 @@ class MKLDNNHandlerT { } std::shared_ptr AcquireForwardPrimitive() { - const std::string key_p = key_ + "@forward_p"; + const std::string key_p = key_ + "@fwd_p"; auto forward_p = std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); if (forward_p == nullptr) { @@ -65,7 +65,7 @@ class MKLDNNHandlerT { } std::shared_ptr AcquireBackwardPrimitive() { - const std::string key_p = key_ + "@backward_p"; + const std::string key_p = key_ + "@bwd_p"; auto backward_p = std::static_pointer_cast(dev_ctx_.GetBlob(key_p)); if (backward_p == nullptr) { @@ -112,11 +112,11 @@ class MKLDNNHandlerT { protected: bool isCached() { - const std::string key_pd = key_common_ + "@forward_pd"; + const std::string key_pd = key_common_ + "@fwd_pd"; fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); - const std::string key_p = key_ + "@forward_p"; + const std::string key_p = key_ + "@fwd_p"; return (dev_ctx_.GetBlob(key_p) != nullptr); } @@ -129,7 +129,7 @@ class MKLDNNHandlerT { // Forward PD has to be passed to Grad op that // may be executed by diffrent thread, hence // for that one we use key that does not contain TID - const std::string key_pd = key_common_ + "@forward_pd"; + const std::string key_pd = key_common_ + "@fwd_pd"; fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); if (fwd_pd_ == nullptr) { @@ -169,13 +169,13 @@ class MKLDNNHandlerT { template void AcquireBackwardPrimitiveDescriptor(Args&&... args) { - const std::string key_fwd_pd = key_common_ + "@forward_pd"; + const std::string key_fwd_pd = key_common_ + "@fwd_pd"; fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_fwd_pd)); PADDLE_ENFORCE_NOT_NULL( fwd_pd_, platform::errors::Unavailable( "Get MKLDNN Forward primitive %s failed.", key_fwd_pd)); - const std::string key_pd = key_ + "@backward_pd"; + const std::string key_pd = key_ + "@bwd_pd"; bwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); if (bwd_pd_ == nullptr) {