提交 dea10506 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!814 compare context pointer in AnfNodeConfig for performance

Merge pull request !814 from xychow/compare-with-context-ptr
......@@ -997,6 +997,9 @@ bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const Abstract
for (std::size_t i = 0; i < size; i++) {
MS_EXCEPTION_IF_NULL(lhs[i]);
MS_EXCEPTION_IF_NULL(rhs[i]);
if (lhs[i] == rhs[i]) {
continue;
}
if (!(*lhs[i] == *rhs[i])) {
return false;
}
......
......@@ -23,6 +23,24 @@
namespace mindspore {
namespace abstract {
AnalysisContextPtr AnalysisContext::NewContext(AnalysisContextPtr parent, FuncGraphPtr fg,
const AbstractBasePtrList &args_spec_list) {
auto children_context_map_iter = parent->children_cache_.find(fg);
if (children_context_map_iter != parent->children_cache_.end()) {
auto children_context_map = children_context_map_iter->second;
auto children_context_iter = children_context_map.find(args_spec_list);
if (children_context_iter != children_context_map.end()) {
return children_context_iter->second.lock();
}
}
AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
// Reference to myself, so use weak_ptr to break reference cycle.
auto weak_context = std::weak_ptr<AnalysisContext>(context_new);
context_new->parent_cache_[fg] = weak_context;
parent->children_cache_[fg][args_spec_list] = weak_context;
return context_new;
}
AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func_graph,
const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr graph_parent = func_graph->parent();
......@@ -89,6 +107,13 @@ AnalysisContextPtr AnalysisContext::DummyContext() {
return dummy_context;
}
bool AnalysisContext::IsDummyContext() {
if (parent_ == nullptr && func_graph_ == nullptr && args_spec_list_.empty()) {
return true;
}
return false;
}
const AnalysisContextPtr kDummyAnalysisContext =
std::make_shared<AnalysisContext>(nullptr, nullptr, AbstractBasePtrList());
......
......@@ -28,6 +28,11 @@
namespace mindspore {
namespace abstract {
class AnalysisContext;
using AnalysisContextWeakPtr = std::weak_ptr<AnalysisContext>;
using ArgsSpecToAnalysisContextMap =
std::unordered_map<AbstractBasePtrList, AnalysisContextWeakPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
// AnalysisContext will be stored in Config in AnalysisCache.
class AnalysisContext {
public:
......@@ -41,12 +46,7 @@ class AnalysisContext {
~AnalysisContext() = default;
// Helper function to wrapper constructor to save shared_ptr in parent_cache.
AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list) {
AnalysisContextPtr context_new = std::make_shared<AnalysisContext>(parent, fg, args_spec_list);
// Reference to myself, so use weak_ptr to break reference cycle.
context_new->parent_cache_[fg] = std::weak_ptr<AnalysisContext>(context_new);
return context_new;
}
AnalysisContextPtr NewContext(AnalysisContextPtr parent, FuncGraphPtr fg, const AbstractBasePtrList &args_spec_list);
// Extend this context with values for another graph.
AnalysisContextPtr NewFuncGraphContext(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
......@@ -56,6 +56,7 @@ class AnalysisContext {
bool operator==(const AnalysisContext &other) const;
std::size_t hash();
static AnalysisContextPtr DummyContext();
bool IsDummyContext();
FuncGraphPtr func_graph() const { return func_graph_; }
AnalysisContextPtr parent() const { return parent_; }
std::string ToString() const;
......@@ -66,7 +67,8 @@ class AnalysisContext {
AnalysisContextPtr parent_;
FuncGraphPtr func_graph_;
AbstractBasePtrList args_spec_list_;
std::unordered_map<FuncGraphPtr, std::weak_ptr<AnalysisContext>> parent_cache_;
std::unordered_map<FuncGraphPtr, AnalysisContextWeakPtr> parent_cache_;
std::unordered_map<FuncGraphPtr, ArgsSpecToAnalysisContextMap> children_cache_;
};
struct ContextHasher {
......
......@@ -87,7 +87,10 @@ AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(conf->node());
std::size_t hash_value = hash_combine(conf->node()->hash(), conf->context()->hash());
std::size_t hash_value = conf->node()->hash();
if (!conf->context()->IsDummyContext()) {
hash_value = hash_combine(hash_value, std::hash<AnalysisContext *>{}(conf->context().get()));
}
if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) {
MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString()
<< ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value;
......
......@@ -83,9 +83,12 @@ class AnfNodeConfig : public Config {
// used by unordered_map;
bool operator==(const AnfNodeConfig &other) const {
// compare node with pointer, context with content;
// compare node with pointer, context with pointer except DummyContext as it's created by make_shared;
// context should not be nullptr;
return (node_ == other.node_) && (*context_ == *other.context_);
if (context_->IsDummyContext() && other.context_->IsDummyContext()) {
return true;
}
return (node_ == other.node_) && (context_ == other.context_);
}
std::string ToString() const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册