提交 f926650c 编写于 作者: Z zhousiyi

if AbstractFunction comparison succeed in NewContext, then the evaluator...

if AbstractFunction comparison succeed in NewContext, then the evaluator should use the same one, otherwise one of the evaluator will not be evaluated.
if funcgraph or metafuncgraph call it recursively, then anf_node should be used as tracking_id to discriminate the first occurcance and the
recursive occurance.
add anf_node to PrimitiveAbstractClosure hash() to reduce cost of GetEvaluatorFor().

ignore the tracking_id to make cse work.
上级 8ff7c0b6
......@@ -36,6 +36,11 @@ BasePtr AbsOf(const AnfNodePtr &node) {
if (node_abs == nullptr) {
return kAnyValue;
}
// Ignore the tracking_id and prim pointer hash;
if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
auto prim_abs = node_abs->cast<abstract::PrimitiveAbstractClosurePtr>();
return prim_abs->prim();
}
return node_abs;
}
......
......@@ -470,7 +470,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
MS_LOG(DEBUG) << "The tracking_id: " << func->tracking_id()->DebugString();
}
MS_EXCEPTION_IF_NULL(func);
if (func->tracking_id() == nullptr) {
if (func->tracking_id() == nullptr || func->isa<abstract::MetaFuncGraphAbstractClosure>() ||
func->isa<abstract::FuncGraphAbstractClosure>()) {
EvaluatorPtr evaluator = _GetEvaluatorFor(func);
return evaluator;
}
......@@ -639,12 +640,12 @@ EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
}
abstract::AbstractBasePtr MakeAbstractClosure(const FuncGraphPtr &func_graph,
const abstract::AnalysisContextPtr &context) {
const abstract::AnalysisContextPtr &context, const AnfNodePtr &anf_node) {
AnalysisContextPtr temp_context = context;
if (temp_context == nullptr) {
temp_context = abstract::AnalysisContext::DummyContext();
}
return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context);
return std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, anf_node);
}
abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const AnfNodePtr &anf_node) {
......@@ -652,7 +653,8 @@ abstract::AbstractBasePtr MakeAbstractClosure(const MetaFuncGraphPtr &meta_func_
if (anf_node == nullptr) {
meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph);
} else {
meta_func_graph_fn = std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node->scope());
meta_func_graph_fn =
std::make_shared<abstract::MetaFuncGraphAbstractClosure>(meta_func_graph, anf_node, anf_node->scope());
}
return meta_func_graph_fn;
}
......@@ -663,14 +665,14 @@ abstract::AbstractBasePtr MakeAbstractClosure(const PrimitivePtr &primitive, con
}
AbstractBasePtr ToAbstract(const ValuePtr &value, const AnalysisContextPtr &context, const AnfNodeConfigPtr &conf) {
if (value->isa<FuncGraph>()) {
auto func_graph = value->cast<FuncGraphPtr>();
return MakeAbstractClosure(func_graph, context);
}
AnfNodePtr anf_node = nullptr;
if (conf != nullptr) {
anf_node = conf->node();
}
if (value->isa<FuncGraph>()) {
auto func_graph = value->cast<FuncGraphPtr>();
return MakeAbstractClosure(func_graph, context, anf_node);
}
if (value->isa<MetaFuncGraph>()) {
auto meta_func_graph = value->cast<MetaFuncGraphPtr>();
return MakeAbstractClosure(meta_func_graph, anf_node);
......
......@@ -232,7 +232,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
const PrimEvaluatorMap &prim_constructors_;
FuncGraphManagerPtr func_graph_manager_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr> constructors_;
std::unordered_map<AbstractFunctionPtr, EvaluatorPtr, AbstractFunctionHasher, AbstractFunctionEqual> constructors_;
AnfNodeConfigMap anfnode_config_map_;
// Use a list to trace multiple evaluators.
std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_;
......
......@@ -143,14 +143,23 @@ bool PrimitiveAbstractClosure::operator==(const AbstractFunction &other) const {
return false;
}
std::size_t PrimitiveAbstractClosure::hash() const { return hash_combine(tid(), prim_->hash()); }
std::size_t PrimitiveAbstractClosure::hash() const {
auto hash_value = hash_combine(tid(), prim_->hash());
// Keep in sync with operator==() which compares the prim_ pointer;
hash_value = hash_combine(hash_value, std::hash<Primitive *>{}(prim_.get()));
if (tracking_id() != nullptr) {
hash_value = hash_combine(hash_value, tracking_id()->hash());
}
return hash_value;
}
bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
if (!other.isa<FuncGraphAbstractClosure>()) {
return false;
}
auto other_fg = static_cast<const FuncGraphAbstractClosure *>(&other);
if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_) {
if (func_graph_ == other_fg->func_graph_ && context_ == other_fg->context_ &&
tracking_id() == other_fg->tracking_id()) {
return true;
}
return false;
......@@ -159,9 +168,11 @@ bool FuncGraphAbstractClosure::operator==(const AbstractFunction &other) const {
std::size_t FuncGraphAbstractClosure::hash() const {
auto hash_value = hash_combine(tid(), func_graph_->hash());
hash_value = hash_combine(hash_value, context_->hash());
if (tracking_id() != nullptr) {
hash_value = hash_combine(hash_value, tracking_id()->hash());
}
return hash_value;
}
std::string FuncGraphAbstractClosure::ToString() const {
std::stringstream ss;
ss << "FuncGraphAbstractClosure: "
......@@ -174,7 +185,7 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con
return false;
}
auto other_meta_fg = static_cast<const MetaFuncGraphAbstractClosure *>(&other);
if (meta_func_graph_ == other_meta_fg->meta_func_graph_) {
if (meta_func_graph_ == other_meta_fg->meta_func_graph_ && tracking_id() == other_meta_fg->tracking_id()) {
return true;
}
return false;
......@@ -182,6 +193,9 @@ bool MetaFuncGraphAbstractClosure::operator==(const AbstractFunction &other) con
std::size_t MetaFuncGraphAbstractClosure::hash() const {
auto hash_value = hash_combine(tid(), meta_func_graph_->hash());
if (tracking_id() != nullptr) {
hash_value = hash_combine(hash_value, tracking_id()->hash());
}
return hash_value;
}
......
......@@ -92,13 +92,15 @@ class PrimitiveAbstractClosure : public AbstractFuncAtom {
// one reference cycle example is Graph::set_output() input0 local variable.
AnfNodeWeakPtr tracking_id_;
};
using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>;
class FuncGraphAbstractClosure : public AbstractFuncAtom {
public:
// Represents a Graph in a certain Context.
// context: The context, or Context.empty()
FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context)
: func_graph_(func_graph), context_(context) {
FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const AnfNodePtr &tracking_id = nullptr)
: func_graph_(func_graph), context_(context), tracking_id_(AnfNodeWeakPtr(tracking_id)) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(context);
}
......@@ -109,8 +111,10 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
AnalysisContextPtr context() const override { return context_; }
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
AbstractFunctionPtr Copy() const override {
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_);
return std::make_shared<FuncGraphAbstractClosure>(func_graph_, context_, tracking_id());
}
bool operator==(const AbstractFunction &other) const override;
......@@ -121,13 +125,22 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom {
private:
FuncGraphPtr func_graph_;
AnalysisContextPtr context_;
// To discriminate different usage of same graph by using this tracking_id,
// so different tracking_id will produce different FuncGraphAbstractClosure,
// different FuncGraphEvaluator.
// Espcecially usefull for recursive func graph call, so it will not mess up
// the graph_context_ in FuncGraphEvaluator.
// Notes: Be careful to use nullptr for this variable.
// store it as weak_ptr to break reference cycle.
AnfNodeWeakPtr tracking_id_;
};
using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>;
class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
public:
explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, const ScopePtr &scope = kDefaultScope)
: meta_func_graph_(meta_func_graph), scope_(scope) {}
explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph,
const AnfNodePtr &tracking_id = nullptr, const ScopePtr &scope = kDefaultScope)
: meta_func_graph_(meta_func_graph), tracking_id_(AnfNodeWeakPtr(tracking_id)), scope_(scope) {}
~MetaFuncGraphAbstractClosure() override = default;
MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom)
......@@ -137,7 +150,11 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
ScopePtr GetScope() { return scope_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_); }
AnfNodePtr tracking_id() const override { return tracking_id_.lock(); }
AbstractFunctionPtr Copy() const override {
return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id());
}
bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override;
......@@ -145,6 +162,9 @@ class MetaFuncGraphAbstractClosure : public AbstractFuncAtom {
private:
MetaFuncGraphPtr meta_func_graph_;
// refer the comment in FuncGraphAbstractClosure;
// store it as weak_ptr to break reference cycle.
AnfNodeWeakPtr tracking_id_;
ScopePtr scope_;
};
using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>;
......
......@@ -67,3 +67,62 @@ def test_assign_in_while():
z = Tensor(np.random.randn(*input_shape).astype(np.float32))
net = Net(input_shape)
net(x, y, z)
def test_dup_context():
''' different func_with_fv in net1 and net2 should produce 2 different FuncGraphAbstractClosure and
Evaluator.
'''
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
def identity(f):
return f
def func_with_fv():
return x
def net1():
local_func = identity(func_with_fv)
out = local_func() + 20.0
return out
def net2():
local_func = identity(func_with_fv)
out = local_func() + 15.0
return out
return net1() + net2()
Net()(5.0)
def test_maybe_poly_func():
''' different func_with_fv in net1 and net2 may produce poly node. '''
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y, z):
def identity(f, inp):
return f(inp)
def func_with_fv(yy):
return (x, yy)
def make_call():
out1 = identity(func_with_fv, y)
out2 = identity(func_with_fv, z)
return (out1, out2)
return make_call()
y_input = Tensor(np.array([1, 2]).astype(np.int32))
z_input = Tensor(np.array([[2, 2], [3, 3]]).astype(np.int32))
Net()(1, y_input, z_input)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册