未验证 提交 9adad42d 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Kernel Selection] Opt CanMKLDNNBeUsed performance (#47563)

* opt CanMKLDNNBeUsed performance

* fix nullptr bug

* fix OpBase default_attrs=nullptr bug

* fix OpBase default_attrs=nullptr bug

* fix OpBase default_attrs=nullptr bug
上级 a2761308
...@@ -1414,16 +1414,10 @@ bool OperatorWithKernel::SupportsKernelType( ...@@ -1414,16 +1414,10 @@ bool OperatorWithKernel::SupportsKernelType(
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const { proto::VarType::Type data_type) const {
// NOTE(jiahongyu): Only mkldnn kernels need to check "use_mkldnn" attribute,
// hence we first call function SupportsMKLDNN. If we check "use_mkldnn"
// attribute first, it will cause error because some codes add "use_mkldnn"
// attribute to non-mkldnn ops.
if (!this->SupportsMKLDNN(data_type)) {
return false;
}
const std::string use_mkldnn_attr = "use_mkldnn"; const std::string use_mkldnn_attr = "use_mkldnn";
return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr<bool>(use_mkldnn_attr) && return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr<bool>(use_mkldnn_attr) &&
platform::is_cpu_place(ctx.GetPlace()); platform::is_cpu_place(ctx.GetPlace()) &&
this->SupportsMKLDNN(data_type);
} }
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const { void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
......
...@@ -209,7 +209,7 @@ class GradOpBaseMakerBase { ...@@ -209,7 +209,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_; const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_; const framework::AttributeMap& attrs_;
const framework::AttributeMap* default_attrs_; const framework::AttributeMap* default_attrs_ = nullptr;
const std::map<std::string, std::string>& inplace_map_; const std::map<std::string, std::string>& inplace_map_;
}; };
......
...@@ -103,7 +103,8 @@ class DygraphExecutionContext : public framework::ExecutionContext { ...@@ -103,7 +103,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
bool HasAttr(const std::string& name) const override { bool HasAttr(const std::string& name) const override {
if (attrs_.find(name) == attrs_.end()) { if (attrs_.find(name) == attrs_.end()) {
return default_attrs_.find(name) != default_attrs_.end(); return &default_attrs_ != nullptr &&
default_attrs_.find(name) != default_attrs_.end();
} }
return true; return true;
} }
......
...@@ -221,7 +221,7 @@ class OpBase { ...@@ -221,7 +221,7 @@ class OpBase {
NameVarMap<VariableWrapper> ins_; NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_; NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_; framework::AttributeMap attrs_;
const framework::AttributeMap* default_attrs_; const framework::AttributeMap* default_attrs_ = nullptr;
std::unique_ptr<framework::OperatorBase> op_; std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_; platform::Place place_;
size_t id_{-1UL}; size_t id_{-1UL};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册