未验证 提交 7dbc441e 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] cache cosmetics improvement (#25576)

上级 1a5d3def
......@@ -943,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@forward_pd";
const std::string key_conv_pd = key + "@fwd_pd";
std::vector<primitive> pipeline;
// Create user memory descriptors
......
......@@ -54,7 +54,7 @@ class MKLDNNHandlerT {
}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@forward_p";
const std::string key_p = key_ + "@fwd_p";
auto forward_p =
std::static_pointer_cast<TForward>(dev_ctx_.GetBlob(key_p));
if (forward_p == nullptr) {
......@@ -65,7 +65,7 @@ class MKLDNNHandlerT {
}
std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
const std::string key_p = key_ + "@backward_p";
const std::string key_p = key_ + "@bwd_p";
auto backward_p =
std::static_pointer_cast<TBackward>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
......@@ -112,11 +112,11 @@ class MKLDNNHandlerT {
protected:
bool isCached() {
const std::string key_pd = key_common_ + "@forward_pd";
const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
const std::string key_p = key_ + "@forward_p";
const std::string key_p = key_ + "@fwd_p";
return (dev_ctx_.GetBlob(key_p) != nullptr);
}
......@@ -129,7 +129,7 @@ class MKLDNNHandlerT {
// Forward PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_pd = key_common_ + "@forward_pd";
const std::string key_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (fwd_pd_ == nullptr) {
......@@ -169,13 +169,13 @@ class MKLDNNHandlerT {
template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
const std::string key_fwd_pd = key_common_ + "@forward_pd";
const std::string key_fwd_pd = key_common_ + "@fwd_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_, platform::errors::Unavailable(
"Get MKLDNN Forward primitive %s failed.", key_fwd_pd));
const std::string key_pd = key_ + "@backward_pd";
const std::string key_pd = key_ + "@bwd_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_pd_ == nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册