未验证 提交 df82fd35 编写于 作者: Y YuanRisheng 提交者: GitHub

[BugFix]Fix OneDNN Kernels Bug when use pass (#48364)

* Fix onednn kernel bugs

* fix gpu bugs
上级 b4b926f4
...@@ -3233,6 +3233,29 @@ void OperatorWithKernel::BuildPhiKernelContext( ...@@ -3233,6 +3233,29 @@ void OperatorWithKernel::BuildPhiKernelContext(
} }
VLOG(4) << "Done attributes"; VLOG(4) << "Done attributes";
// Clear All old attrs before add new attrs,
// because sometimes old attrs may be misused.
#if defined(PADDLE_WITH_MKLDNN)
if (phi::OneDNNContext::classof(dev_ctx)) {
phi::OneDNNContext* one_dnn_ctx = static_cast<phi::OneDNNContext*>(dev_ctx);
one_dnn_ctx->ClearDnnAttr();
}
#endif
// Note(YuanRisheng): Now, we can't open code below.
// Because some unittest run OLD dygraph and ExtraAttr is not supported in OLD
// dygraph. So, here we use trick that dev_ctx is a global object. We can
// store ExtraAttr in static graph and when unittest run OLD dygraph, it can
// obtain these ExtraAttr. We can open this code when OLD dygraph is no longer
// used.
/*
#if defined(PADDLE_WITH_CUDA)
if(phi::GPUContext::classof(dev_ctx)) {
phi::GPUContext* gpu_dnn_ctx = static_cast<phi::GPUContext*>(dev_ctx);
gpu_dnn_ctx->ClearDnnAttr();
}
#endif
*/
// For compatible with Op with extra attrs for specific backend // For compatible with Op with extra attrs for specific backend
#if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_MKLDNN) || defined(PADDLE_WITH_CUDA)
auto& runtime_attrs = RuntimeAttrs(); auto& runtime_attrs = RuntimeAttrs();
......
...@@ -740,6 +740,8 @@ struct GPUContext::Impl { ...@@ -740,6 +740,8 @@ struct GPUContext::Impl {
dnn_attrs_[attr_name] = attr; dnn_attrs_[attr_name] = attr;
} }
void ClearDnnAttr() { dnn_attrs_.clear(); }
// use one flag for all handles? // use one flag for all handles?
// they should be accessed consistently // they should be accessed consistently
bool owned_{false}; bool owned_{false};
...@@ -1042,4 +1044,6 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) { ...@@ -1042,4 +1044,6 @@ void GPUContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
return impl_->SetDnnAttr(attr_name, std::move(attr)); return impl_->SetDnnAttr(attr_name, std::move(attr));
} }
void GPUContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
} // namespace phi } // namespace phi
...@@ -172,6 +172,7 @@ class PADDLE_API GPUContext : public DeviceContext, ...@@ -172,6 +172,7 @@ class PADDLE_API GPUContext : public DeviceContext,
bool HasDnnAttr(const std::string& attr_name) const; bool HasDnnAttr(const std::string& attr_name) const;
const Attribute& GetDnnAttr(const std::string& attr_name) const; const Attribute& GetDnnAttr(const std::string& attr_name) const;
void SetDnnAttr(const std::string& attr_name, Attribute attr); void SetDnnAttr(const std::string& attr_name, Attribute attr);
void ClearDnnAttr();
static const char* name() { return "GPUContext"; } static const char* name() { return "GPUContext"; }
......
...@@ -301,6 +301,8 @@ struct OneDNNContext::Impl { ...@@ -301,6 +301,8 @@ struct OneDNNContext::Impl {
dnn_attrs_[attr_name] = attr; dnn_attrs_[attr_name] = attr;
} }
void ClearDnnAttr() { dnn_attrs_.clear(); }
bool HasDnnInput(const std::string& input_name) const { bool HasDnnInput(const std::string& input_name) const {
return dnn_inputs_.count(input_name) != 0UL; return dnn_inputs_.count(input_name) != 0UL;
} }
...@@ -425,6 +427,8 @@ void OneDNNContext::SetDnnAttr(const std::string& attr_name, Attribute attr) { ...@@ -425,6 +427,8 @@ void OneDNNContext::SetDnnAttr(const std::string& attr_name, Attribute attr) {
return impl_->SetDnnAttr(attr_name, std::move(attr)); return impl_->SetDnnAttr(attr_name, std::move(attr));
} }
void OneDNNContext::ClearDnnAttr() { return impl_->ClearDnnAttr(); }
bool OneDNNContext::HasDnnInput(const std::string& input_name) const { bool OneDNNContext::HasDnnInput(const std::string& input_name) const {
return impl_->HasDnnInput(input_name); return impl_->HasDnnInput(input_name);
} }
......
...@@ -146,6 +146,8 @@ class OneDNNContext : public CPUContext { ...@@ -146,6 +146,8 @@ class OneDNNContext : public CPUContext {
const DenseTensor* GetDnnInput(const std::string& input_name) const; const DenseTensor* GetDnnInput(const std::string& input_name) const;
void SetDnnInput(const std::string& input_name, const DenseTensor* input); void SetDnnInput(const std::string& input_name, const DenseTensor* input);
void ClearDnnAttr();
void SetInputsName(const TensorNameMap& inputs_name); void SetInputsName(const TensorNameMap& inputs_name);
void SetOutputsName(const TensorNameMap& outputs_name); void SetOutputsName(const TensorNameMap& outputs_name);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册