提交 5b9c145f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1383 keep different attributes for cnode evaluation

Merge pull request !1383 from amongo/KeepPrimAttrInCNode
......@@ -230,11 +230,11 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
auto ctx = node_cfg_->context();
auto engine = node_cfg_->engine();
auto cfg = engine->MakeConfig(node, ctx);
auto abs = engine->cache().GetValue(cfg);
if (abs == nullptr) {
auto eval_result = engine->cache().GetValue(cfg);
if (eval_result == nullptr || eval_result->abstract() == nullptr) {
return "Undefined";
}
auto abs = eval_result->abstract();
auto dtype = abs->BuildType();
auto shape = abs->BuildShape();
std::ostringstream oss;
......
......@@ -42,7 +42,11 @@ enum PrimType {
class Primitive : public Named {
public:
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name), is_base_(is_base), has_signature_(false), prim_type_(prim_type) {}
: Named(name),
is_base_(is_base),
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false) {}
Primitive(const Primitive &prim)
: Named(prim),
......@@ -50,14 +54,23 @@ class Primitive : public Named {
instance_name_(prim.instance_name_),
is_base_(prim.is_base_),
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_) {}
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false) {}
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
std::string ToString() const override { return name(); }
void BeginRecordAddAttr() {
evaluate_added_attrs_.clear();
record_evaluate_add_attr_ = true;
}
void EndRecordAddAttr() { record_evaluate_add_attr_ = false; }
Primitive &AddAttr(const std::string &name, const ValuePtr &attr) {
attrs_[name] = attr;
if (record_evaluate_add_attr_) {
evaluate_added_attrs_[name] = attr;
}
return *this;
}
......@@ -80,6 +93,7 @@ class Primitive : public Named {
py::function hook() const { return hook_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; }
// if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute.
bool HasAttr() const { return !attrs_.empty(); }
......@@ -106,6 +120,7 @@ class Primitive : public Named {
protected:
std::unordered_map<std::string, ValuePtr> attrs_;
std::unordered_map<std::string, ValuePtr> evaluate_added_attrs_;
private:
std::string instance_name_;
......@@ -113,6 +128,7 @@ class Primitive : public Named {
bool is_base_;
bool has_signature_;
PrimType prim_type_;
bool record_evaluate_add_attr_;
};
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
......
......@@ -377,10 +377,10 @@ AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const Primitiv
}
subargs.push_back(AbstractJoin(l_ptr->elements()));
}
AbstractBasePtr engin_exc = engine->Execute(fn, subargs);
EvalResultPtr engin_exc = engine->Execute(fn, subargs);
AbstractBasePtrList result;
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
result.push_back(engin_exc);
result.push_back(engin_exc->abstract());
}
return std::make_shared<AbstractList>(result);
}
......@@ -398,8 +398,9 @@ AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const Primi
AbstractBasePtr list_type = AbstractJoin(lst->elements());
auto result1 = engine->Execute(fn, lst->elements());
auto result2 = engine->Execute(fn, {dflt, list_type});
MS_EXCEPTION_IF_NULL(result1);
return result1->Join(result2);
MS_EXCEPTION_IF_NULL(result1->abstract());
MS_EXCEPTION_IF_NULL(result2->abstract());
return result1->abstract()->Join(result2->abstract());
}
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
......
......@@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
return sorted_nodes;
}
AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg);
std::size_t nargs = fg->parameters().size();
......@@ -106,7 +106,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const auto &arg = args_spec_list[i];
const auto &node = parameters[i];
AnfNodeConfigPtr conf = engine->MakeConfig(node, graph_context_);
engine->cache().set_value(conf, arg);
engine->cache().set_value(conf, std::make_shared<EvalResult>(arg, nullptr));
}
const AnfNodePtr &func_node = fg->get_return();
......@@ -118,14 +118,14 @@ AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abs
const auto &node = *it;
AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString();
ret_base = engine->GetEvaluatedValue(node_conf);
ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg->ToString() << ", node_conf: " << node_conf->ToString()
<< ", abstract: " << ret_base->ToString();
}
MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString();
return ret_base;
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString();
return std::make_shared<EvalResult>(ret_base, nullptr);
}
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
......@@ -236,15 +236,14 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons
return cloned_func_graph;
}
AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) {
const std::string &evaluator_name = ToString();
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
args_spec_list = NormalizeArgs(args_spec_list);
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
......@@ -254,79 +253,79 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
auto iter = cache_->find(args_spec_list);
if (iter == cache_->end()) {
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
AbstractBasePtr ret = Eval(engine, args_spec_list);
if (ret == nullptr) {
EvalResultPtr ret = Eval(engine, args_spec_list);
if (ret->abstract() == nullptr) {
EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
}
MS_EXCEPTION_IF_NULL(ret);
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->abstract()->ToString() << ".";
(*cache_)[args_spec_list] = ret;
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return ret;
} else {
MS_EXCEPTION_IF_NULL(iter->second);
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
MS_EXCEPTION_IF_NULL(iter->second->abstract());
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->abstract()->ToString() << ".";
trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return iter->second;
}
}
AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr) {
EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
EvalResultPtr ret = EvalPrim(engine, args_spec_list);
return ret;
}
AbstractBasePtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Size should greater than 0";
}
AbstractBasePtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
EvalResultPtr ret = EvalPrim(engine, args_spec_list, args_conf_list[0], out_conf);
// No need to cache.
return ret;
}
AbstractBasePtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
AbstractBasePtr ret = EvalPrim(args_conf_list);
EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
EvalResultPtr ret = EvalPrim(args_conf_list);
return ret;
}
AbstractBasePtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
AbstractBasePtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
EvalResultPtr ret = sub_evaluator_->Run(engine, args_conf_list, out_conf);
// Don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map_, like getattr primitive.
(*cache_)[args_spec_list] = ret;
return ret;
}
AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
MS_EXCEPTION_IF_NULL(cache_);
auto iter = cache_->find(args_spec_list);
......@@ -341,17 +340,18 @@ AbstractBasePtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigP
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(partial_args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
AbstractBasePtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
EvalResultPtr ret = evaluator_->Run(engine, partial_args_conf_list, out_conf);
(*cache_)[args_spec_list] = ret;
return ret;
}
AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
MS_EXCEPTION_IF_NULL(cache_);
auto iter = cache_->find(args_spec_list);
......@@ -360,7 +360,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
}
// Call the original evaluator, get the result: y = f(x)
AbstractBasePtr result = evaluator_->Run(engine, args_conf_list, nullptr);
EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
// Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList bparams;
......@@ -369,16 +369,18 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
[](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
AbstractFunctionPtr bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result), bparams_final);
AbstractFunctionPtr bprop =
std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
// J(f)(J(x)) return a tuple (y, bprop_f)
AbstractBasePtrList jargs = {result, bprop};
AbstractBasePtrList jargs = {result->abstract(), bprop};
AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
(*cache_)[args_spec_list] = jtuple;
return jtuple;
auto infer_reuslt = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = infer_reuslt;
return infer_reuslt;
}
AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() != args_spec_list_.size()) {
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
<< ", arguments no: " << args_spec_list.size();
......@@ -388,7 +390,7 @@ AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrL
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
(void)args_spec_list[i]->Join(args_spec_list_[i]);
}
return output_;
return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
}
} // namespace abstract
} // namespace mindspore
......@@ -29,21 +29,28 @@
namespace mindspore {
namespace abstract {
using EvaluatorCacheMap =
std::unordered_map<AbstractBasePtrList, AbstractBasePtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
std::unordered_map<AbstractBasePtrList, EvalResultPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorCacheMapPtr = std::shared_ptr<EvaluatorCacheMap>;
using EvaluatorAttrMap =
std::unordered_map<AbstractBasePtrList, AttrValueMapPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual>;
using EvaluatorAttrMapPtr = std::shared_ptr<EvaluatorAttrMap>;
class Evaluator : public Base {
public:
explicit Evaluator(const std::string &id) : cache_(std::make_shared<EvaluatorCacheMap>()), identifier_(id) {}
explicit Evaluator(const std::string &id)
: cache_(std::make_shared<EvaluatorCacheMap>()),
attr_cache_(std::make_shared<EvaluatorAttrMap>()),
identifier_(id) {}
~Evaluator() override = default;
MS_DECLARE_PARENT(Evaluator, Base);
// difference between Run() and Eval():
// Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
// Run() will modify cache_ member, so it cannot marked as const;
virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
......@@ -58,9 +65,10 @@ class Evaluator : public Base {
virtual void set_bound_node(const AnfNodePtr &node) { bound_node_ = AnfNodeWeakPtr(node); }
EvaluatorCacheMapPtr &cache() { return cache_; }
EvaluatorAttrMapPtr &attr_cache() { return attr_cache_; }
EvaluatorCacheMapPtr cache_;
EvaluatorAttrMapPtr attr_cache_;
std::string identifier_;
AnfNodeWeakPtr bound_node_;
......@@ -71,7 +79,7 @@ class PrimEvaluator : public Evaluator {
explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
~PrimEvaluator() override = default;
MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
};
......@@ -81,8 +89,8 @@ class TrivialPrimEvaluator : public PrimEvaluator {
explicit TrivialPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~TrivialPrimEvaluator() override = default;
MS_DECLARE_PARENT(TrivialPrimEvaluator, PrimEvaluator);
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list) = 0;
};
class TransitionPrimEvaluator : public PrimEvaluator {
......@@ -90,10 +98,10 @@ class TransitionPrimEvaluator : public PrimEvaluator {
explicit TransitionPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~TransitionPrimEvaluator() override = default;
MS_DECLARE_PARENT(TransitionPrimEvaluator, PrimEvaluator);
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
// Parameter in_conf0 : the first element in args_conf_list;
virtual AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
virtual EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) = 0;
};
class SymbolicPrimEvaluator : public PrimEvaluator {
......@@ -101,8 +109,8 @@ class SymbolicPrimEvaluator : public PrimEvaluator {
explicit SymbolicPrimEvaluator(const std::string &id) : PrimEvaluator(id) {}
~SymbolicPrimEvaluator() override = default;
MS_DECLARE_PARENT(SymbolicPrimEvaluator, PrimEvaluator);
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
virtual AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) final;
virtual EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) = 0;
};
// Evaluator will be stored in AnalysisEngine.constructors_
......@@ -113,7 +121,7 @@ class DummyEvaluator : public Evaluator {
DummyEvaluator() : Evaluator("dummy") {}
~DummyEvaluator() override = default;
MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
};
// Wrap another evaluator to track a subset of uses.
......@@ -139,11 +147,10 @@ class TrackedEvaluator : public Evaluator {
bound_node_ = AnfNodeWeakPtr(node);
}
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
std::string ToString() const override { return identifier_ + "_" + sub_evaluator_->ToString(); }
private:
......@@ -158,7 +165,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
~BaseFuncGraphEvaluator() override = default;
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
......@@ -238,12 +245,12 @@ class PartialAppEvaluator : public Evaluator {
}
bound_node_ = AnfNodeWeakPtr(node);
}
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private:
......@@ -258,7 +265,7 @@ class VirtualEvaluator : public Evaluator {
~VirtualEvaluator() override = default;
MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
std::string ToString() const override { return identifier_; }
private:
......@@ -285,11 +292,11 @@ class JEvaluator : public Evaluator {
}
bound_node_ = AnfNodeWeakPtr(node);
}
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
}
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf) override;
std::string ToString() const override { return identifier_ + "_" + evaluator_->ToString(); }
private:
......
......@@ -135,13 +135,17 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
using mindspore::parse::PyObjectWrapper;
AbstractBasePtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
prim_->BeginRecordAddAttr();
AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
return abs_base;
prim_->EndRecordAddAttr();
auto added_attrs = prim_->evaluate_added_attrs();
auto infer_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
return infer_result;
}
AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
AbstractBasePtrList args_spec_list;
if (!prim_->isa<prim::DoSignaturePrimitive>()) {
MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
......@@ -161,7 +165,7 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
ScopePtr scope = kDefaultScope;
if (out_conf != nullptr) {
......@@ -212,8 +216,8 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
return graph_specialize_args;
}
AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) {
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
}
......@@ -232,7 +236,7 @@ AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const Config
AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
// get the forward graph
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
AbstractFunctionPtr fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
......@@ -411,7 +415,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
}
} // end anonymous namespace
AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
const auto &iter = cache_->find(args);
......@@ -425,17 +429,20 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
}
auto infer_fuc = pyobj.attr("__infer__");
prim_py_->BeginRecordAddAttr();
py::dict output = infer_fuc(*py_args);
prim_py_->EndRecordAddAttr();
auto added_attrs = prim_py_->evaluate_added_attrs();
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
auto res_spec = PyInferRes2Abstract(prim_py_, output);
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
(*cache_)[args] = res_spec;
return res_spec;
auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
(*cache_)[args] = infer_result;
return infer_result;
}
AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
// if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
if (nargs_ != args.size()) {
MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
......@@ -476,7 +483,7 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const
}
AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
return abs_base;
return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
}
ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
......@@ -553,8 +560,8 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
manager->AddFuncGraph(func_graph);
}
AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &old_conf) {
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &old_conf) {
MS_EXCEPTION_IF_NULL(old_conf);
AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
......@@ -585,9 +592,9 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
return eng->ForwardConfig(old_conf, fn_conf);
}
AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) {
// args_spec_list: same as StaticGetter
if (args_spec_list.size() < 2) {
MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
......@@ -627,9 +634,9 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
return eng->ForwardConfig(out_conf, fn_conf);
}
AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list, const ValuePtr &item_v,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list is empty";
}
......@@ -646,7 +653,7 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
AbstractBasePtr attr = cls->GetAttribute(item_name);
if (attr != nullptr) {
return attr;
return std::make_shared<EvalResult>(attr, nullptr);
}
ValuePtr method = cls->GetMethod(item_name);
......@@ -660,9 +667,9 @@ AbstractBasePtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &e
return StaticGetterInferred(converted_v, data_conf, out_conf);
}
AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
const TypePtr &data_type, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
EvalResultPtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_v,
const TypePtr &data_type, const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(item_v);
MS_EXCEPTION_IF_NULL(data_type);
// The method maybe a Primitive or Composite
......@@ -689,8 +696,8 @@ AbstractBasePtr GetEvaluatedValueForBuiltinTypeMethod(const AnalysisEnginePtr &e
return StaticGetterInferred(converted_v, data_conf, out_conf);
}
AbstractBasePtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function
CheckArgsSize("StaticGetter", args_spec_list, 2);
......@@ -725,7 +732,7 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
~EmbedEvaluator() override = default;
MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
// arg: free variable to be embedded
if (args_conf_list.size() != 1) {
MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
......@@ -733,11 +740,11 @@ class EmbedEvaluator : public SymbolicPrimEvaluator {
AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
MS_EXCEPTION_IF_NULL(node_conf);
AbstractBasePtr x = node_conf->GetEvaluatedValue();
AbstractBasePtr x = node_conf->GetEvaluatedValue()->abstract();
x = SensitivityTransform(x);
SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
return abs_scalar;
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
}
};
......@@ -762,7 +769,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
~RefToEmbedEvaluator() override = default;
MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
AbstractBasePtr EvalPrim(const ConfigPtrList &args_conf_list) override {
EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
if (args_conf_list.size() != 1) {
MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
return nullptr;
......@@ -773,7 +780,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
return nullptr;
}
AbstractBasePtr abs = node_conf->GetEvaluatedValue();
AbstractBasePtr abs = node_conf->GetEvaluatedValue()->abstract();
AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
if (ref_abs == nullptr) {
MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref.";
......@@ -791,7 +798,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
}
auto refkey = key_value->cast<RefKeyPtr>();
if (refkey == nullptr) {
return std::make_shared<AbstractScalar>(type);
return std::make_shared<EvalResult>(std::make_shared<AbstractScalar>(type), std::make_shared<AttrValueMap>());
}
std::string name = refkey->tag();
......@@ -805,7 +812,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
x = SensitivityTransform(x);
std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
return abs_scalar;
return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
}
};
......@@ -814,13 +821,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
~GetAttrEvaluator() override = default;
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
// Inputs: data, item
if (args_spec_list.size() != 2) {
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
}
AbstractBasePtr ret = nullptr;
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
......@@ -840,13 +847,13 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
~ResolveEvaluator() override = default;
MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
// Inputs: namespace, symbol
if (args_spec_list.size() != 2) {
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
}
AbstractBasePtr ret = nullptr;
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {
TraceManager::DebugTrace(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
......@@ -863,8 +870,8 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
~CreateInstanceEvaluator() override = default;
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) override {
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
}
......@@ -915,8 +922,9 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
}
AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
(*cache_)[args_spec_list] = ret;
return ret;
auto infer_result = std::make_shared<EvalResult>(ret, nullptr);
(*cache_)[args_spec_list] = infer_result;
return infer_result;
}
pybind11::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
......@@ -942,23 +950,24 @@ class PartialEvaluator : public Evaluator {
public:
PartialEvaluator() : Evaluator("PartialEvaluator") {}
~PartialEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf = nullptr) override {
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf = nullptr) override {
if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
}
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
auto arg0_value = args_conf_list[0]->GetEvaluatedValue()->abstract();
AbstractBasePtrList args_spec_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if (arg0_value->isa<AbstractError>()) {
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString();
(*cache_)[args_spec_list] = ret;
return ret;
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = eval_result;
return eval_result;
}
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
// Sometimes, node[0] in out_conf becomes phi0;
......@@ -970,8 +979,9 @@ class PartialEvaluator : public Evaluator {
}
}
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); });
(void)std::transform(
args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); });
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
auto cnode = out_conf->node()->cast<CNodePtr>();
......@@ -989,16 +999,17 @@ class PartialEvaluator : public Evaluator {
func->Visit(build_partial);
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
(*cache_)[args_spec_list] = ret;
return ret;
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
(*cache_)[args_spec_list] = infer_result;
return infer_result;
}
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
const AnfNodeConfigPtr &out_conf = nullptr) const {
EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
const AnfNodeConfigPtr &out_conf = nullptr) const {
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto cnode = out_conf->node()->cast<CNodePtr>();
......
......@@ -45,7 +45,7 @@ class StandardPrimEvaluator : public TrivialPrimEvaluator {
: TrivialPrimEvaluator("StandardPrimEvaluator"), prim_(primitive), eval_impl_(eval_impl) {}
~StandardPrimEvaluator() override = default;
MS_DECLARE_PARENT(StandardPrimEvaluator, TrivialPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
PrimitivePtr prim() { return prim_; }
std::string ToString() const override { return identifier_ + prim_->name(); }
......@@ -63,7 +63,7 @@ class PythonPrimEvaluator : public TrivialPrimEvaluator {
: TrivialPrimEvaluator("PythonPrimEvaluator"), prim_py_(primitive) {}
~PythonPrimEvaluator() override = default;
MS_DECLARE_PARENT(PythonPrimEvaluator, TrivialPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
PrimitivePtr prim() { return dyn_cast<Primitive>(prim_py_); }
std::string ToString() const override { return identifier_ + prim_py_->name(); }
......@@ -76,10 +76,10 @@ class DoSignatureEvaluator : public Evaluator {
public:
explicit DoSignatureEvaluator(const PrimitivePtr primitive) : Evaluator("DoSignatureEvaluator"), prim_(primitive) {}
~DoSignatureEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
......@@ -91,10 +91,10 @@ class UnpackGraphEvaluator : public Evaluator {
public:
explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {}
~UnpackGraphEvaluator() override = default;
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
......@@ -131,7 +131,7 @@ class UniformPrimEvaluator : public TrivialPrimEvaluator {
~UniformPrimEvaluator() override = default;
MS_DECLARE_PARENT(UniformPrimEvaluator, TrivialPrimEvaluator);
AbstractBasePtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) override;
ValuePtr RunImpl(const ValuePtrList &args) const;
// If eval_value_ is False, return broadened arguments.
......
......@@ -36,7 +36,7 @@ inline AbstractBasePtr GetEvaluatedValueWrap(const AnfNodeConfigPtr &conf) {
if (conf->node()->intermediate_abstract()) {
return conf->node()->intermediate_abstract();
}
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
}
AnfNodePtr BuildValueNode(const ValuePtr &v, const AbstractBasePtr &abs_base) {
......@@ -212,7 +212,7 @@ void FuncGraphSpecializer::FirstPass() {
// Specialize CNode in func graphs
void FuncGraphSpecializer::SecondPass() {
for (auto &node : DeepLinkedGraphSearch(specialized_func_graph_->get_return())) {
for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) {
if (node->isa<CNode>()) {
ProcessCNode(node->cast<CNodePtr>());
}
......@@ -225,7 +225,6 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AnfNodeConfigPtr conf = MakeConfig(node);
AnfNodePtr new_node = GetReplicatedNode(node);
MS_EXCEPTION_IF_NULL(new_node);
if (new_node->func_graph() != specialized_func_graph_) {
MS_LOG(EXCEPTION) << "Error in specializer [A] node: " << node->DebugString()
<< ", new_node: " << new_node->DebugString()
......@@ -244,6 +243,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
if (node->isa<CNode>()) {
auto attrs = conf->GetEvaluatedValue()->attribute();
auto c_old = node->cast<CNodePtr>();
auto c_new = new_node->cast<CNodePtr>();
auto new_inputs = c_new->inputs();
......@@ -254,7 +254,7 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
AbstractBasePtr ival = GetEvaluatedValueWrap(iconf);
// First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if
// can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival);
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs);
if (replace_node == nullptr) {
replace_node = BuildReplacedNode(iconf);
MS_EXCEPTION_IF_NULL(replace_node);
......@@ -424,9 +424,10 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
}
auto attrs = std::make_shared<AttrValueMap>();
for (size_t i = 0; i < partial_closure->args().size(); i++) {
auto old_node = cnode->input(i + 2);
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs);
if (possibile_value_node != nullptr) {
partial_node_list.push_back(possibile_value_node);
} else {
......@@ -455,7 +456,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
const EvaluatorPtr &eval) {
MS_EXCEPTION_IF_NULL(eval);
std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices;
AbstractBasePtr ret = nullptr;
EvalResultPtr ret = nullptr;
AbstractBasePtrList broaded_argvals;
for (auto &argvals_map : *evalcaches_[eval]) {
auto argvals = argvals_map.first;
......@@ -478,7 +479,7 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB
(*real)[broaded_argvals] = ret;
evalcaches_[eval] = real;
return std::make_pair(broaded_argvals, ret);
return std::make_pair(broaded_argvals, ret->abstract());
} else {
MS_LOG(DEBUG) << "Choices.size: " << choices.size();
return std::make_pair(AbstractBasePtrList(), nullptr);
......@@ -491,7 +492,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
return;
}
specializer_->AddSeen(new_node);
auto new_inputs = new_node->inputs();
if (new_inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of CNode is empty";
......@@ -530,7 +530,13 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
}
if (CanSpecializeNode(func)) {
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
// for primitive node , we build the primitive node with infered attributes in the first pass
// so we do not build replaced node again here in second pass
if (IsValueNode<Primitive>(func)) {
new_inputs[0] = func;
} else {
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
}
}
for (size_t i = 0; i < argvals.size();) {
......@@ -540,7 +546,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
}
i = next;
}
new_node->set_inputs(new_inputs);
}
......@@ -582,7 +587,7 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
EvaluatorCacheMap evaluator_cache_map = *eval->cache();
if (evaluator_cache_map.find(argvals) != evaluator_cache_map.end()) {
*result = std::make_pair(argvals, evaluator_cache_map[argvals]);
*result = std::make_pair(argvals, evaluator_cache_map[argvals]->abstract());
return kSpecializeSuccess;
}
DumpEvaluatorCache(evaluator_cache_map, argvals);
......@@ -591,11 +596,11 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
MS_EXCEPTION_IF_NULL(choices);
if (choices->count(argvals)) {
*result = std::make_pair(argvals, (*choices)[argvals]);
*result = std::make_pair(argvals, (*choices)[argvals]->abstract());
return kSpecializeSuccess;
} else if (choices->size() == 1) {
MS_LOG(DEBUG) << "Evaluator cache has a single item, just use it.";
*result = std::make_pair(choices->begin()->first, choices->begin()->second);
*result = std::make_pair(choices->begin()->first, choices->begin()->second->abstract());
return kSpecializeSuccess;
} else if (choices->empty()) {
MS_LOG(DEBUG) << "Find DEAD code, it may be optimized in later phase.";
......@@ -614,8 +619,43 @@ SpecializeStatusCode FuncGraphSpecializer::FindUniqueArgvals(const AbstractFunct
return kSpecializeFindUniqueArgvalPoly;
}
}
static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, const AttrValueMapPtr &attrs) {
auto &prim_attrs = prim->attrs();
bool is_attr_same = true;
for (auto &item : *attrs) {
auto itr = prim_attrs.find(item.first);
if (itr != prim_attrs.end()) {
if (!(*(itr->second) == *(item.second))) {
is_attr_same = false;
break;
}
} else {
is_attr_same = false;
break;
}
}
if (!is_attr_same) {
if (prim->isa<PrimitivePy>()) {
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
auto clone_fn = prim_py->GetPyObj().attr("_clone");
py::object new_obj = clone_fn();
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
for (auto &item : *attrs) {
cloned_prim->AddAttr(item.first, item.second);
}
return cloned_prim;
}
auto cloned_prim = std::make_shared<Primitive>(*prim);
for (auto &item : *attrs) {
cloned_prim->AddAttr(item.first, item.second);
}
return cloned_prim;
}
return prim;
}
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival) {
AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs) {
MS_EXCEPTION_IF_NULL(origin_node);
MS_EXCEPTION_IF_NULL(ival);
......@@ -628,7 +668,12 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin
ValuePtr value = nullptr;
if (abs->isa<PrimitiveAbstractClosure>()) {
auto real_fn = dyn_cast<PrimitiveAbstractClosure>(abs);
value = real_fn->prim();
// for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one
if (attrs != nullptr) {
value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs);
} else {
value = real_fn->prim();
}
} else if (abs->isa<MetaFuncGraphAbstractClosure>()) {
auto real_fn = dyn_cast<MetaFuncGraphAbstractClosure>(abs);
value = real_fn->meta_func_graph();
......
......@@ -110,7 +110,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
// Build a value node if ival is constant and not any-value
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival,
const AttrValueMapPtr &attrs);
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
// replicated node.
AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf);
......
......@@ -55,29 +55,29 @@ AbstractBasePtr IntermediateJoin(const AbstractBasePtr &arg1, const AbstractBase
return nullptr;
}
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
void AnalysisCache::set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) {
MS_LOG(DEBUG) << "AnalysisCache set for NodeConfig: " << conf->node()->DebugString()
<< ", Context: " << conf->context()->ToString() << ", Value: " << arg->ToString()
<< ", Pointer: " << arg.get();
cache_[conf] = arg;
<< ", Context: " << conf->context()->ToString() << ", Value: " << result->abstract()->ToString()
<< ", Pointer: " << result->abstract().get();
cache_[conf] = result;
// Set intermediate abstract value.
if (IsIntermediateAbstract(arg)) {
if (IsIntermediateAbstract(result->abstract())) {
if (conf->node()->intermediate_abstract() == nullptr) {
conf->node()->set_intermediate_abstract(arg);
MS_LOG(DEBUG) << "Set intermediate abstract: " << arg->ToString();
conf->node()->set_intermediate_abstract(result->abstract());
MS_LOG(DEBUG) << "Set intermediate abstract: " << result->abstract()->ToString();
} else {
auto old_spec = conf->node()->intermediate_abstract();
auto joined_spec = IntermediateJoin(arg, old_spec);
auto joined_spec = IntermediateJoin(result->abstract(), old_spec);
conf->node()->set_intermediate_abstract(joined_spec);
MS_LOG(DEBUG) << "Set joined intermediate abstract:\nold_spec:\t\t" << old_spec->ToString() << "\nnew_spec:\t\t"
<< arg->ToString() << "\njoined_spec:\t"
<< result->abstract()->ToString() << "\njoined_spec:\t"
<< (joined_spec != nullptr ? joined_spec->ToString() : "nullptr");
}
}
}
AbstractBasePtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisCache::GetValue(const AnfNodeConfigPtr &conf) {
auto value = cache_.find(conf);
if (value == cache_.end()) {
return nullptr;
......@@ -142,12 +142,12 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana
return eval->graph_context();
}
AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
auto value = cache_.GetValue(conf);
if (value != nullptr) {
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value.get() << ", "
<< value->ToString();
MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get()
<< ", " << value->abstract()->ToString();
return value;
}
......@@ -160,10 +160,10 @@ AbstractBasePtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf)
return value;
}
AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
AnfNodePtr node = conf->node();
AbstractBasePtr ret_abstract = nullptr;
EvalResultPtr eval_result = nullptr;
#ifdef DEBUG
compute_conf_stack_.push_back(node);
std::ostringstream buffer;
......@@ -177,14 +177,14 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(node);
if (node->abstract() != nullptr) {
MS_LOG(DEBUG) << "Return old abstract: " << node->DebugString();
ret_abstract = node->abstract();
eval_result = std::make_shared<EvalResult>(node->abstract(), std::make_shared<AttrValueMap>());
} else if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
ret_abstract = EvalValueNode(value_node, conf);
eval_result = std::make_shared<EvalResult>(EvalValueNode(value_node, conf), nullptr);
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
trace::TraceEvalCNodeEnter(conf);
ret_abstract = EvalCNode(cnode, conf);
eval_result = EvalCNode(cnode, conf);
trace::TraceEvalCNodeLeave();
} else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
......@@ -193,13 +193,13 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
#ifdef DEBUG
compute_conf_stack_.pop_back();
if (ret_abstract == nullptr) {
if (eval_result == nullptr) {
MS_LOG(EXCEPTION) << "Compute Config failed, node: " << node->DebugString()
<< " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
}
#endif
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << ret_abstract->ToString();
return ret_abstract;
MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString();
return eval_result;
}
AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf) {
......@@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return ToAbstract(value_node->value(), conf->context(), conf);
}
AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
......@@ -223,7 +223,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
AnfNodeConfigPtr func_conf = MakeConfig(func_node, context);
MS_EXCEPTION_IF_NULL(func_conf);
// Keep it in a local variable, otherwise smart pointer will free it.
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue();
AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract();
if (maybe_func == nullptr) {
MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString()
<< " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info());
......@@ -253,7 +253,7 @@ AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeCo
return ExecuteEvaluators(infs, conf, args_conf_list);
}
AbstractBasePtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) {
ConfigPtrList args_conf_list;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(args_conf_list),
[](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
......@@ -454,9 +454,8 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
return tracked_eval;
}
AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) {
if (evaluators.size() == 1) {
EvaluatorPtr eval = evaluators[0];
MS_EXCEPTION_IF_NULL(eval);
......@@ -465,9 +464,9 @@ AbstractBasePtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr
return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list);
}
AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
AbstractBasePtrList out_specs;
if (!multi_poss_.count(evaluators[0])) {
multi_poss_[evaluators[0]] = evaluators[1];
......@@ -477,7 +476,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->GetEvaluatedValue();
return conf->GetEvaluatedValue()->abstract();
});
for (auto eval : evaluators) {
auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>();
......@@ -502,11 +501,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
eval_trace_.push_back(current_inf);
MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get();
MS_EXCEPTION_IF_NULL(eval);
auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(out_spec);
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString();
out_specs.push_back(out_spec);
MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString();
auto eval_result = eval->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << eval_result->abstract()->ToString();
out_specs.push_back(eval_result->abstract());
eval_trace_.pop_back();
if (eval_trace_.empty()) {
multi_poss_.clear();
......@@ -552,10 +550,11 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
// Try to travel the latest undetermined.
if (latest_entry != eval_trace_.rbegin()->first) {
MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString();
auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(out_spec);
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString();
return out_spec;
auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf);
MS_EXCEPTION_IF_NULL(eval_result->abstract());
MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString()
<< " return out_spec: " << eval_result->abstract()->ToString();
return eval_result;
}
}
}
......@@ -566,15 +565,15 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
if (out_specs.size() == 1) {
MS_EXCEPTION_IF_NULL(out_specs[0]);
// If only one result derived, then broaden it to avoid wrong constant propagation.
return out_specs[0]->Broaden();
return std::make_shared<EvalResult>(out_specs[0]->Broaden(), std::make_shared<AttrValueMap>());
}
auto joined_spec = AbstractJoin(out_specs);
MS_EXCEPTION_IF_NULL(joined_spec);
MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString();
return joined_spec;
return std::make_shared<EvalResult>(joined_spec, std::make_shared<AttrValueMap>());
}
AbstractBasePtr AnfNodeConfig::GetEvaluatedValue() {
EvalResultPtr AnfNodeConfig::GetEvaluatedValue() {
AnfNodeConfigPtr self = shared_from_base<AnfNodeConfig>();
return engine_.lock()->GetEvaluatedValue(self);
}
......@@ -607,7 +606,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
return a;
}
AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
auto evaluator = GetPrimEvaluator(primitive, nullptr);
MS_EXCEPTION_IF_NULL(evaluator);
if (!evaluator->isa<TrivialPrimEvaluator>()) {
......@@ -615,8 +614,8 @@ AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtr
<< evaluator->ToString();
}
auto trivial_evaluator = dyn_cast<TrivialPrimEvaluator>(evaluator);
auto res_spec = trivial_evaluator->EvalPrim(nullptr, arg_specs);
return res_spec;
auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
return eval_result;
}
} // namespace abstract
} // namespace mindspore
......@@ -40,13 +40,33 @@
namespace mindspore {
namespace abstract {
// define attribute value map
using AttrValueMap = std::unordered_map<std::string, ValuePtr>;
using AttrValueMapPtr = std::shared_ptr<AttrValueMap>;
// the class to save evaluated result: abstract value and modified attribute
class EvalResult : public Base {
public:
EvalResult(AbstractBasePtr abs, AttrValueMapPtr attr) : abstract_(abs), attribute_(attr) {}
~EvalResult() override = default;
MS_DECLARE_PARENT(EvalResult, Base);
AbstractBasePtr abstract() { return abstract_; }
AttrValueMapPtr attribute() { return attribute_; }
private:
AbstractBasePtr abstract_;
AttrValueMapPtr attribute_;
};
using EvalResultPtr = std::shared_ptr<EvalResult>;
// Superclass for AnfNodeConfig and VirtualConfig.
class Config : public Base {
public:
Config() = default;
~Config() override = default;
MS_DECLARE_PARENT(Config, Base);
virtual AbstractBasePtr GetEvaluatedValue() = 0;
virtual EvalResultPtr GetEvaluatedValue() = 0;
};
// Config will be stored in AnalysisCache
......@@ -74,7 +94,7 @@ class AnfNodeConfig : public Config {
~AnfNodeConfig() override = default;
MS_DECLARE_PARENT(AnfNodeConfig, Config);
AbstractBasePtr GetEvaluatedValue() override;
EvalResultPtr GetEvaluatedValue() override;
AnalysisContextPtr context() const { return context_; }
......@@ -123,7 +143,9 @@ class VirtualConfig : public Config {
~VirtualConfig() override = default;
MS_DECLARE_PARENT(VirtualConfig, Config);
AbstractBasePtr GetEvaluatedValue() override { return abstract_; }
EvalResultPtr GetEvaluatedValue() override {
return std::make_shared<EvalResult>(abstract_, std::make_shared<AttrValueMap>());
}
private:
AbstractBasePtr abstract_;
......@@ -135,11 +157,11 @@ class AnalysisCache {
AnalysisCache() = default;
~AnalysisCache() = default;
void Clear() { cache_.clear(); }
void set_value(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg);
AbstractBasePtr GetValue(const AnfNodeConfigPtr &conf);
void set_value(const AnfNodeConfigPtr &conf, const EvalResultPtr &arg);
EvalResultPtr GetValue(const AnfNodeConfigPtr &conf);
private:
std::unordered_map<AnfNodeConfigPtr, AbstractBasePtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
std::unordered_map<AnfNodeConfigPtr, EvalResultPtr, AnfNodeConfigHasher, AnfNodeConfigEqual> cache_;
};
using PrimEvaluatorMap = std::unordered_map<PrimitivePtr, EvaluatorPtr, PrimitiveHasher, PrimitiveEqual>;
......@@ -147,7 +169,7 @@ using AnfNodeConfigMap =
std::unordered_map<AnfNodeConfigPtr, AnfNodeConfigPtr, AnfNodeConfigHasher, AnfNodeConfigEqual>;
struct AnalysisResult {
AbstractBasePtr inferred;
EvalResultPtr inferred;
AnalysisContextPtr context;
};
......@@ -160,14 +182,14 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// func_graph: The func_graph to analyze.
// args_spec_list: The abstracted arguments for the func_graph. Must be a tuple of AbstractBase.
AnalysisResult Run(const FuncGraphPtr &func_graph, const AbstractBasePtrList &args_spec_list);
AbstractBasePtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
EvalResultPtr GetEvaluatedValue(const AnfNodeConfigPtr &conf);
// Return the Evaluator for the given function.
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
EvalResultPtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
// Infer the result of fn(args).
AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
EvalResultPtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear();
void ClearEvaluatorCache();
AnalysisCache &cache() { return cache_; }
......@@ -188,7 +210,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
// Set the analysis result for orig to the result for new.
// This sets an entry in anfnode_config_map from orig to new.
AbstractBasePtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) {
// Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor.
(void)anfnode_config_map_.emplace(orig_conf, new_conf);
MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString()
......@@ -211,12 +233,12 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context,
const ConfigPtrList &args_conf_list);
AbstractBasePtr Eval(const AnfNodeConfigPtr &conf);
EvalResultPtr Eval(const AnfNodeConfigPtr &conf);
EvaluatorPtr _GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
AbstractBasePtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list);
#ifdef DEBUG
std::vector<AnfNodePtr> compute_conf_stack_;
......@@ -244,7 +266,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return FromValueInside(MakeValue(value), broaden);
}
AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
} // namespace abstract
} // namespace mindspore
......
......@@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
}
}
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list);
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
op_exec_info->abstract = infer_res;
}
......
......@@ -26,6 +26,8 @@
#include <list>
#include <string>
#include <fstream>
#include <queue>
#include <set>
#include "ir/visitor.h"
#include "utils/log_adapter.h"
......@@ -223,6 +225,31 @@ std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c
return res;
}
// search the cnodes inside this graph only
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret) {
std::queue<CNodePtr> todo;
todo.push(ret);
std::vector<CNodePtr> sorted_nodes;
auto seen = NewSeenGeneration();
while (!todo.empty()) {
CNodePtr top = todo.front();
todo.pop();
sorted_nodes.push_back(top);
auto inputs = top->inputs();
for (auto &item : inputs) {
if (item->seen_ == seen) {
continue;
}
if (item->isa<CNode>()) {
todo.push(item->cast<CNodePtr>());
}
item->seen_ = seen;
}
}
return sorted_nodes;
}
std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
std::vector<AnfNodePtr> vecs;
if (node == nullptr) {
......
......@@ -57,6 +57,7 @@ std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const Incl
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
const IncludeFunc &include = AlwaysInclude);
std::vector<CNodePtr> BroadFirstSearchGraphCNodes(CNodePtr ret);
class FuncGraphIndex {
public:
explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
......
......@@ -71,7 +71,6 @@ class ExpandDims(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init ExpandDims"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['output'])
def __infer__(self, x, axis):
......@@ -182,7 +181,6 @@ class Cast(PrimitiveWithInfer):
# if primitive need setattr in __infer__ need add this flag
"""init Cast"""
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
self.__setattr_flag__ = True
def __infer__(self, x, t):
src_type = x['dtype']
......@@ -308,7 +306,6 @@ class Reshape(PrimitiveWithInfer):
def __init__(self):
"""init Reshape"""
self.init_prim_io_names(inputs=['tensor', 'shape'], outputs=['output'])
self.__setattr_flag__ = True
def __infer__(self, x, shape):
shape_v = shape['value']
......@@ -453,7 +450,6 @@ class Transpose(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init Transpose"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['x', 'perm'], outputs=['output'])
def __infer__(self, x, perm):
......@@ -508,7 +504,6 @@ class GatherV2(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init index_select"""
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
def __infer__(self, params, indices, axis):
......@@ -1402,7 +1397,6 @@ class Concat(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=0):
"""init Tile"""
self.__setattr_flag__ = True
validator.check_value_type("axis", axis, [int], self.name)
def __infer__(self, input_x):
......@@ -1476,7 +1470,6 @@ class Pack(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=0):
"""init Pack"""
self.__setattr_flag__ = True
validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis
......@@ -1526,7 +1519,6 @@ class Unpack(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=0):
"""init Unpack"""
self.__setattr_flag__ = True
validator.check_value_type("axis", axis, [int], self.name)
self.axis = axis
......@@ -1656,7 +1648,6 @@ class Select(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init"""
self.__setattr_flag__ = True
def infer_shape(self, cond_shape, x_shape, y_shape):
if cond_shape != x_shape or x_shape != y_shape:
......
......@@ -516,7 +516,6 @@ class MatMul(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
......@@ -596,7 +595,6 @@ class BatchMatMul(MatMul):
@prim_attr_register
def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
......@@ -682,7 +680,6 @@ class AddN(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
self.__setattr_flag__ = True
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def infer_shape(self, inputs):
......
......@@ -730,8 +730,8 @@ class Conv2D(PrimitiveWithInfer):
"""init Conv2D"""
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name)
self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1]))
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name)
......@@ -787,7 +787,6 @@ class Conv2D(PrimitiveWithInfer):
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel
out_shape = [x_shape[0], out_channel, h_out, w_out]
return out_shape
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.primitive import constexpr
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def test_cast_op_attr():
class CastNet(nn.Cell):
def __init__(self):
super(CastNet, self).__init__()
self.cast = P.Cast()
def construct(self, x, t):
return self.cast(x, t)
class CastTypeTest(nn.Cell):
def __init__(self, net):
super(CastTypeTest, self).__init__()
self.net = net
self.cast = P.Cast()
def construct(self, x, y, z):
cast_op = self.cast
t1 = cast_op(x, mstype.float32)
t2 = cast_op(y, mstype.int32)
cast_net = self.net
t3 = cast_net(x, mstype.float16)
t4 = cast_net(y, mstype.int32)
t5 = cast_net(z, mstype.float16)
return (t1, t2, t3, t4, t5)
net = CastTypeTest(CastNet())
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1918]).astype(np.int32))
out = net(t1, t2, t3)
assert out[0].asnumpy().dtype == np.float32
assert out[1].asnumpy().dtype == np.int32
assert out[2].asnumpy().dtype == np.float16
assert out[3].asnumpy().dtype == np.int32
assert out[4].asnumpy().dtype == np.float16
......@@ -153,7 +153,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......@@ -179,7 +179,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......@@ -205,7 +205,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......@@ -231,7 +231,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tuple_tensor, slice};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(tupleSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......@@ -253,7 +253,7 @@ TEST_F(TestComposite, test_TensorSliceBySlice) {
AbstractSlicePtr slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {tensor, slice};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred);
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSlicePtrGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
......@@ -288,7 +288,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTuple) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
......@@ -320,7 +320,7 @@ TEST_F(TestComposite, test_TensorSliceBySliceTupleToReduceDimension) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
......@@ -336,7 +336,7 @@ TEST_F(TestComposite, test_TensorSliceByScalar) {
AbstractScalarPtr start_index = std::make_shared<AbstractScalar>(2);
AbstractBasePtrList args_spec_list = {tensor, start_index};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
......@@ -358,7 +358,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTuple) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
......@@ -382,7 +382,7 @@ TEST_F(TestComposite, test_TensorSliceByScalarTupleToScalar) {
AbstractTuplePtr slice_tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tensor, slice_tuple};
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred);
AbstractTensorPtr ret = dyn_cast<AbstractTensor>(engine_->Run(tensorSliceGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract array failed.";
}
......@@ -408,7 +408,7 @@ TEST_F(TestComposite, test_UnpackCall_3args) {
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
AbstractBasePtrList args_spec_list = {fn_arg, tensor_tuple, tensor_dict};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......@@ -435,7 +435,7 @@ TEST_F(TestComposite, test_UnpackCall_5args) {
abstract::AbstractDictionaryPtr tensor_dict = std::make_shared<abstract::AbstractDictionary>(tensor_map);
AbstractBasePtrList args_spec_list = {fn_arg, tensor_dict, tensor_tuple, tensor_dict, tensor_tuple};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(unPackCallGraphPtr, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......@@ -457,7 +457,7 @@ TEST_F(TestComposite, test_ZipOperation) {
auto tuple = std::make_shared<AbstractTuple>(eles);
AbstractBasePtrList args_spec_list = {tuple};
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred);
AbstractTuplePtr ret = dyn_cast<AbstractTuple>(engine_->Run(zip_op_graph, args_spec_list).inferred->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract tuple failed.";
}
......
......@@ -41,11 +41,11 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
AbstractBasePtr abstract_v2 = FromValue(2, false);
AbstractBasePtrList args_spec_list = {abstract_v1, abstract_v2};
AbstractBasePtr abstract_val = FromValue(10, false);
cache[args_spec_list] = abstract_val;
cache[args_spec_list] = std::make_shared<EvalResult>(abstract_val, std::make_shared<AttrValueMap>());
auto iter = cache.find(args_spec_list);
ASSERT_TRUE(iter != cache.end());
ASSERT_TRUE(iter->second == abstract_val);
ASSERT_TRUE(iter->second->abstract() == abstract_val);
AbstractBasePtr abstract_v1_variant1 = FromValue(1, false);
AbstractBasePtr abstract_v2_variant1 = FromValue(2, false);
......@@ -53,7 +53,7 @@ TEST_F(TestEvaluatorCacheMap, test_evaluator_cache_map) {
iter = cache.find(args_spec_list_variant1);
ASSERT_TRUE(iter != cache.end());
ASSERT_TRUE(iter->second == abstract_val);
ASSERT_TRUE(iter->second->abstract() == abstract_val);
AbstractBasePtr abstract_v1_variant2 = FromValue(1, false);
AbstractBasePtr abstract_v2_variant2 = FromValue(3, false);
......@@ -111,7 +111,7 @@ TEST_F(TestStandardEvaluator, test_multiple_conv2d) {
std::vector<int> shape = {2, 2, 6, 6};
expected->set_shape(std::make_shared<Shape>(shape));
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(func_graph, args_spec_list).inferred->abstract();
MS_LOG(INFO) << "result: " << res->ToString();
MS_LOG(INFO) << "expected: " << expected->ToString();
......@@ -144,7 +144,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_resolved) {
AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}
......@@ -160,7 +160,7 @@ TEST_F(TestPartialEvaluator, test_infer_dataclass_unresolved) {
AbstractBasePtr abstract_x = FromValue(x, false);
args_spec_list.push_back(abstract_x);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat32);
}
......@@ -179,7 +179,7 @@ TEST_F(TestPartialEvaluator, test_infer_add_resolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
......@@ -198,7 +198,7 @@ TEST_F(TestPartialEvaluator, test_infer_sub_unresolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
......@@ -217,7 +217,7 @@ TEST_F(TestPartialEvaluator, test_infer_net_construct_add_resolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
......@@ -237,7 +237,7 @@ TEST_F(TestPartialEvaluator, test_infer_construct_sub_unresolved) {
args_spec_list.push_back(abstract_x);
args_spec_list.push_back(abstract_y);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_x->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeFloat64);
}
......
......@@ -163,7 +163,7 @@ TEST_F(TestInfer, test_inferred_scalar_add) {
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}
......@@ -261,7 +261,7 @@ TEST_F(TestInferGraph, test_inferred) {
MS_LOG(INFO) << "" << graph_f_->get_return()->ToString();
AbstractBasePtr abstract_v1 = FromValue(1, false);
args_spec_list.push_back(abstract_v1);
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(graph_f_, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
// now this test case failed randomly, have to debug.
......@@ -272,7 +272,7 @@ TEST_F(TestInferGraph, test_inferred) {
args_spec_list.clear();
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred;
abs_base_got = engine_->Run(graph_alpha_, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}
......@@ -358,7 +358,7 @@ TEST_F(TestInferMetaGraph, test_inferred) {
AbstractBasePtr abstract_v2 = FromValue(v1, false);
args_spec_list.push_back(abstract_v1);
args_spec_list.push_back(abstract_v2);
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph_, args_spec_list).inferred->abstract();
ASSERT_TRUE(abs_base_got.get() == abstract_v1.get());
}
......@@ -390,7 +390,7 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
auto prim_scalar_add = std::make_shared<Primitive>("scalar_add");
FuncGraphPtr func_graph = MakeFuncGraph(prim_scalar_add);
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred;
AbstractBasePtr abs_base_got = engine_->Run(func_graph, args_spec).inferred->abstract();
ASSERT_TRUE(*(abs_base_got->GetTypeTrack()) == *(abstract_v1->GetTypeTrack()));
ASSERT_TRUE(abs_base_got->GetTypeTrack()->type_id() == kNumberTypeInt32);
}
......@@ -418,7 +418,7 @@ TEST_F(TestEvalOnePrim, test_scalar_add) {
AbstractBasePtr base1 = FromValue(x1, false);
AbstractBasePtr base2 = FromValue(x2, false);
AbstractBasePtrList base_list = {base1, base2};
auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list);
auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list)->abstract();
MS_LOG(INFO) << "result spec: " << res->ToString();
AbstractBasePtr exp = FromValue(x3, false);
MS_LOG(INFO) << "result exp: " << exp->ToString();
......@@ -446,7 +446,7 @@ void TestGraphEval::TearDown() {
TEST_F(TestGraphInfer, test_graph_infer_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_defaults");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(50), false);
ASSERT_EQ(*res, *expect);
}
......@@ -454,7 +454,7 @@ TEST_F(TestGraphInfer, test_graph_infer_defaults) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_0");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(1), false);
ASSERT_EQ(*res, *expect);
}
......@@ -462,7 +462,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_0) {
TEST_F(TestGraphInfer, test_graph_infer_vararg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(9), false);
ASSERT_EQ(*res, *expect);
}
......@@ -470,7 +470,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(48), false);
ASSERT_EQ(*res, *expect);
}
......@@ -478,7 +478,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs) {
TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_kwarg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(7), false);
ASSERT_EQ(*res, *expect);
}
......@@ -486,7 +486,7 @@ TEST_F(TestGraphInfer, test_graph_infer_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(46), false);
ASSERT_EQ(*res, *expect);
}
......@@ -494,7 +494,7 @@ TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg) {
TEST_F(TestGraphInfer, test_graph_infer_vararg_kwonlyargs_kwarg_defaults) {
FuncGraphPtr graph = getPyFun.CallAndParseRet("test_graph_infer_vararg_kwonlyargs_kwarg_defaults");
AbstractBasePtrList args_spec_list = {};
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred;
AbstractBasePtr res = engine_->Run(graph, args_spec_list).inferred->abstract();
AbstractBasePtr expect = FromValue(MakeValue(57), false);
ASSERT_EQ(*res, *expect);
}
......
......@@ -31,7 +31,8 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
def conv3x3(in_channels, out_channels, stride=1, padding=1):
"""3x3 convolution """
......@@ -377,6 +378,21 @@ class StateNet(nn.Cell):
return x
def test_conv2d_same_primitive():
class Conv2DSameNet(nn.Cell):
def __init__(self):
super(Conv2DSameNet, self).__init__()
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
def construct(self, x, y):
r1 = self.conv1(x)
r2 = self.conv2(y)
return (r1, r2)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = Conv2DSameNet()
out = net(t1, t2)
class ComparisonNet(nn.Cell):
def __init__(self):
""" ComparisonNet definition """
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test nn ops """
import functools
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import Primitive
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.primitive import constexpr
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
import pipeline_for_verify_exception_for_case_by_case_config
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class FakeOp(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
""""""
def infer_shape(self, x, y):
self.second_shape = y
self.add_prim_attr("second_shape", y)
return x
def infer_dtype(self, x, y):
return x
# test the normal case that should generate independent primitive because of different
# generated attributes after inference
def test_conv2d_same_primitive():
class Conv2DSameNet(nn.Cell):
def __init__(self):
super(Conv2DSameNet, self).__init__()
self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
def construct(self, x, y):
r1 = self.conv1(x)
r2 = self.conv2(y)
return (r1, r2)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = Conv2DSameNet()
out = net(t1, t2)
# test cell as high order argument
# The graph with free variables used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_conv2d_op_with_arg():
class Conv2dNet(nn.Cell):
def __init__(self):
super(Conv2dNet, self).__init__()
def construct(self, op, x):
return op(x)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
def construct(self, x, y):
conv_op = self.conv2
a = self.opnet(conv_op, x)
b = self.opnet(conv_op, y)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = OpsNet(Conv2dNet())
out = net(t1, t2)
def test_conv2d_op_with_arg():
class FackOpNet(nn.Cell):
def __init__(self):
super(FackOpNet, self).__init__()
self.op = FakeOp()
def construct(self, x, y):
return self.op(x, y)
class OpNet(nn.Cell):
def __init__(self):
super(OpNet, self).__init__()
def construct(self, op, x, y):
return op(x, y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.op = FackOpNet()
def construct(self, x, y):
op = self.op
a = self.opnet(op, x, y)
b = self.opnet(op, y, x)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = OpsNet(OpNet())
out = net(t1, t2)
def test_conv2d_op_with_arg_same_input():
class FackOpNet(nn.Cell):
def __init__(self):
super(FackOpNet, self).__init__()
self.op = FakeOp()
def construct(self, x, y):
return self.op(x, y)
class OpNet(nn.Cell):
def __init__(self):
super(OpNet, self).__init__()
def construct(self, op, x, y):
return op(x, y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.op = FackOpNet()
def construct(self, x, y):
op = self.op
a = self.opnet(op, x, x)
b = self.opnet(op, y, x)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
net = OpsNet(OpNet())
out = net(t1, t2)
# test op with partial
def test_op_as_partial():
class OpAsPartial(nn.Cell):
def __init__(self):
super(OpAsPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op = F.partial(self.op, x)
a = partial_op(y)
b = partial_op(z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpAsPartial()
out = net(t1, t2, t3)
# test op with partial
def test_op_as_partial_inside():
class OpAsPartial(nn.Cell):
def __init__(self):
super(OpAsPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op = F.partial(self.op, x)
a = partial_op(y)
b = partial_op(z)
return a, b
class OuterNet(nn.Cell):
def __init__(self):
super(OuterNet, self).__init__()
self.net = OpAsPartial()
def construct(self, x, y, z):
a,b = self.net(x, y, z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OuterNet()
out = net(t1, t2, t3)
# test op with partial case 2
def test_op_as_partial_independent():
class OpAsPartial(nn.Cell):
def __init__(self):
super(OpAsPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op1 = F.partial(self.op, x)
a = partial_op1(y)
partial_op2 = F.partial(self.op, x)
b = partial_op2(z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpAsPartial()
out = net(t1, t2, t3)
def test_nest_partial():
class NestPartial(nn.Cell):
def __init__(self):
super(NestPartial, self).__init__()
self.op = FakeOp()
def construct(self, x, y, z):
partial_op1 = F.partial(self.op)
partial_op2 = F.partial(partial_op1, x)
a = partial_op2(y)
partial_op3 = F.partial(self.op)
partial_op4 = F.partial(partial_op3, x)
b = partial_op4(z)
return a, b
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = NestPartial()
out = net(t1, t2, t3)
# high order argument
# op and op args as network arguments
def test_op_with_arg_as_input():
class WithOpArgNet(nn.Cell):
def __init__(self):
super(WithOpArgNet, self).__init__()
def construct(self, op, x, y):
return op(x, y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.opnet = net
self.op = FakeOp()
def construct(self, x, y, z):
op = self.op
a = self.opnet(op, x, z)
b = self.opnet(op, x, y)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpsNet(WithOpArgNet())
out = net(t1, t2, t3)
# The partial application used as argument is not supported yet
# because of the limit of inference specialize system
def Xtest_partial_as_arg():
class PartialArgNet(nn.Cell):
def __init__(self):
super(PartialArgNet, self).__init__()
def construct(self, partial_op, y):
return partial_op(y)
class OpsNet(nn.Cell):
def __init__(self, net):
super(OpsNet, self).__init__()
self.partial_net = net
self.op = FakeOp()
def construct(self, x, y, z):
partial_op = F.partial(self.op, x)
a = self.partial_net(partial_op, z)
b = self.partial_net(partial_op, y)
return (a, b)
t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
net = OpsNet(PartialArgNet())
out = net(t1, t2, t3)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册