未验证 提交 7dbc441e 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] cache cosmetics improvement (#25576)

上级 1a5d3def
...@@ -943,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -943,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key = platform::CreateKey( const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@forward_pd"; const std::string key_conv_pd = key + "@fwd_pd";
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
// Create user memory descriptors // Create user memory descriptors
......
...@@ -54,7 +54,7 @@ class MKLDNNHandlerT { ...@@ -54,7 +54,7 @@ class MKLDNNHandlerT {
} }
std::shared_ptr<TForward> AcquireForwardPrimitive() { std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@forward_p"; const std::string key_p = key_ + "@fwd_p";
auto forward_p = auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p)); std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) { if (forward_p == nullptr) {
...@@ -65,7 +65,7 @@ class MKLDNNHandlerT { ...@@ -65,7 +65,7 @@ class MKLDNNHandlerT {
} }
std::shared_ptr<TBackward> AcquireBackwardPrimitive() { std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
const std::string key_p = key_ + "@backward_p"; const std::string key_p = key_ + "@bwd_p";
auto backward_p = auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p)); std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) { if (backward_p == nullptr) {
...@@ -112,11 +112,11 @@ class MKLDNNHandlerT { ...@@ -112,11 +112,11 @@ class MKLDNNHandlerT {
protected: protected:
bool isCached() { bool isCached() {
const std::string key_pd = key_common_ + "@forward_pd"; const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>( fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd)); dev_ctx_.GetBlob(key_pd));
const std::string key_p = key_ + "@forward_p"; const std::string key_p = key_ + "@fwd_p";
return (dev_ctx_.GetBlob(key_p) != nullptr); return (dev_ctx_.GetBlob(key_p) != nullptr);
} }
...@@ -129,7 +129,7 @@ class MKLDNNHandlerT { ...@@ -129,7 +129,7 @@ class MKLDNNHandlerT {
// Forward PD has to be passed to Grad op that // Forward PD has to be passed to Grad op that
// may be executed by diffrent thread, hence // may be executed by diffrent thread, hence
// for that one we use key that does not contain TID // for that one we use key that does not contain TID
const std::string key_pd = key_common_ + "@forward_pd"; const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>( fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd)); dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) { if (fwd_pd_ == nullptr) {
...@@ -169,13 +169,13 @@ class MKLDNNHandlerT { ...@@ -169,13 +169,13 @@ class MKLDNNHandlerT {
template <typename... Args> template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) { void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
const std::string key_fwd_pd = key_common_ + "@forward_pd"; const std::string key_fwd_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>( fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd)); dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable( fwd_pd_, platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd)); "Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
const std::string key_pd = key_ + "@backward_pd"; const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>( bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd)); dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) { if (bwd_pd_ == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册