未验证 提交 9adad42d 编写于 作者: H HongyuJia 提交者: GitHub

[Opt Kernel Selection] Opt CanMKLDNNBeUsed performance (#47563)

* opt CanMKLDNNBeUsed performance

* fix nullptr bug

* fix OpBase default_attrs=nullptr bug

* fix OpBase default_attrs=nullptr bug

* fix OpBase default_attrs=nullptr bug
上级 a2761308
......@@ -1414,16 +1414,10 @@ bool OperatorWithKernel::SupportsKernelType(
bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const {
// NOTE(jiahongyu): Only mkldnn kernels need to check "use_mkldnn" attribute,
// hence we first call function SupportsMKLDNN. If we check "use_mkldnn"
// attribute first, it will cause error because some codes add "use_mkldnn"
// attribute to non-mkldnn ops.
if (!this->SupportsMKLDNN(data_type)) {
return false;
}
const std::string use_mkldnn_attr = "use_mkldnn";
return ctx.HasAttr(use_mkldnn_attr) && ctx.Attr<bool>(use_mkldnn_attr) &&
platform::is_cpu_place(ctx.GetPlace());
platform::is_cpu_place(ctx.GetPlace()) &&
this->SupportsMKLDNN(data_type);
}
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
......
......@@ -209,7 +209,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap* default_attrs_;
const framework::AttributeMap* default_attrs_ = nullptr;
const std::map<std::string, std::string>& inplace_map_;
};
......
......@@ -103,7 +103,8 @@ class DygraphExecutionContext : public framework::ExecutionContext {
bool HasAttr(const std::string& name) const override {
if (attrs_.find(name) == attrs_.end()) {
return default_attrs_.find(name) != default_attrs_.end();
return &default_attrs_ != nullptr &&
default_attrs_.find(name) != default_attrs_.end();
}
return true;
}
......
......@@ -221,7 +221,7 @@ class OpBase {
NameVarMap<VariableWrapper> ins_;
NameVarMap<VariableWrapper> outs_;
framework::AttributeMap attrs_;
const framework::AttributeMap* default_attrs_;
const framework::AttributeMap* default_attrs_ = nullptr;
std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_;
size_t id_{-1UL};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册