提交 be60bd3d 编写于 作者: K Kang

Code refactoring for the static_analysis : modified Infer to Eval.

上级 168dfb25
...@@ -182,7 +182,7 @@ void DumpInferStack(std::ostringstream &oss) { ...@@ -182,7 +182,7 @@ void DumpInferStack(std::ostringstream &oss) {
} }
} }
void TraceGraphInfer() { void TraceGraphEval() {
auto &infer_stack = GetCurrenGraphInferStack(); auto &infer_stack = GetCurrenGraphInferStack();
std::ostringstream oss; std::ostringstream oss;
if (infer_stack.empty()) { if (infer_stack.empty()) {
...@@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, ...@@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename,
ofs.close(); ofs.close();
} }
void GetInferStackInfo(std::ostringstream &oss) { void GetEvalStackInfo(std::ostringstream &oss) {
MS_LOG(INFO) << "Get graph analysis information begin"; MS_LOG(INFO) << "Get graph analysis information begin";
auto stack = GetCNodeDebugStack(); auto stack = GetCNodeDebugStack();
if (stack.empty()) { if (stack.empty()) {
...@@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream &oss) { ...@@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream &oss) {
static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack; static std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> graph_infer_stack;
// trace the cnode infer debug info // trace the cnode infer debug info
static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{}; static std::vector<abstract::AnfNodeConfigPtr> cnode_debug_stack{};
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) {
if (eval == nullptr) { if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
} }
...@@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::An ...@@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::An
} }
} }
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { void TraceGraphEvalLeave(const abstract::EvaluatorPtr &eval) {
if (eval == nullptr) { if (eval == nullptr) {
MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; MS_LOG(EXCEPTION) << "GraphInferEnter got null eval";
} }
...@@ -354,9 +354,9 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { ...@@ -354,9 +354,9 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) {
} }
} }
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); }
void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } void TraceEvalCNodeLeave() { cnode_debug_stack.pop_back(); }
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; } std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack() { return cnode_debug_stack; }
......
...@@ -35,12 +35,12 @@ std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLi ...@@ -35,12 +35,12 @@ std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLi
std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
SourceLineTip tip = kSourceLineTipNextLine); SourceLineTip tip = kSourceLineTipNextLine);
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
void TraceGraphInfer(); void TraceGraphEval();
void GetInferStackInfo(std::ostringstream &oss); void GetEvalStackInfo(std::ostringstream &oss);
void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); void TraceGraphEvalEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node);
void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); void TraceGraphEvalLeave(const abstract::EvaluatorPtr &eval);
void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); void TraceEvalCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg);
void TraceInferCNodeLeave(); void TraceEvalCNodeLeave();
std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack(); std::vector<abstract::AnfNodeConfigPtr> &GetCNodeDebugStack();
std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack(); std::stack<std::pair<abstract::EvaluatorPtr, abstract::AnfNodeConfigPtr>> &GetCurrenGraphInferStack();
std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); std::string GetAbstractStr(const abstract::AbstractBasePtr &abs);
......
...@@ -430,8 +430,8 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: ...@@ -430,8 +430,8 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
} catch (const py::error_already_set &ex) { } catch (const py::error_already_set &ex) {
// print function call stack info before release // print function call stack info before release
std::ostringstream oss; std::ostringstream oss;
trace::TraceGraphInfer(); trace::TraceGraphEval();
trace::GetInferStackInfo(oss); trace::GetEvalStackInfo(oss);
// call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see // call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
// these info from screen, no need to open log file to find these info // these info from screen, no need to open log file to find these info
py::print(oss.str()); py::print(oss.str());
......
...@@ -38,7 +38,7 @@ namespace abstract { ...@@ -38,7 +38,7 @@ namespace abstract {
class AbstractBase; class AbstractBase;
using AbstractBasePtrList = std::vector<AbstractBasePtr>; using AbstractBasePtrList = std::vector<AbstractBasePtr>;
// The base class for abstract value. The abstract value is used in inferring // The base class for abstract value. The abstract value is used in evaluating
// to express the type, shape, and value of the real value. // to express the type, shape, and value of the real value.
class AbstractBase : public Base { class AbstractBase : public Base {
public: public:
......
...@@ -153,7 +153,7 @@ bool AnalysisContext::operator==(const AnalysisContext &other) const { ...@@ -153,7 +153,7 @@ bool AnalysisContext::operator==(const AnalysisContext &other) const {
// free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what // free values. In order to decrease the number of cloned graphs, we add this `SpecializeKey` method to control what
// graph can be reused. // graph can be reused.
// The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined // The graph called with different SymbolicKey will be reused. The abstract of SymbolicKey parameter will be joined
// and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in infer, thus the reused // and stored in the intermediate_abstract. The joined SymbolicKey would cause Poly Code in eval, thus the reused
// graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize. // graph with SymbolicKey parameter should be inlined in `opt` pipeline before the next renormalize.
// The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies // The graph called with different shape should not be reused, because the combination of `shape` and `Fill` relies
// on correct shape to specialize a tensor constant. // on correct shape to specialize a tensor constant.
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
namespace { namespace {
void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list, void EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_spec_list,
const AnfNodeConfigPtr &out_conf) { const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
if (out_conf != nullptr) { if (out_conf != nullptr) {
MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name(); MS_LOG(DEBUG) << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
...@@ -37,7 +37,7 @@ void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList ...@@ -37,7 +37,7 @@ void InferEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList
} }
} }
void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) { void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
if (out_conf != nullptr) { if (out_conf != nullptr) {
auto node = out_conf->node(); auto node = out_conf->node();
...@@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) { ...@@ -89,7 +89,7 @@ static std::vector<AnfNodePtr> FastShadowSort(const AnfNodePtr &ret_node) {
return sorted_nodes; return sorted_nodes;
} }
AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { AbstractBasePtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
std::size_t nargs = fg->parameters().size(); std::size_t nargs = fg->parameters().size();
...@@ -124,7 +124,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab ...@@ -124,7 +124,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::Infer(AnalysisEnginePtr engine, const Ab
} }
MS_EXCEPTION_IF_NULL(ret_base); MS_EXCEPTION_IF_NULL(ret_base);
MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " infer end, inferred abstract: " << ret_base->ToString(); MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " Eval end, evaluated abstract: " << ret_base->ToString();
return ret_base; return ret_base;
} }
...@@ -155,7 +155,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa ...@@ -155,7 +155,7 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa
<< ", context: " << parent_context_->ToString(); << ", context: " << parent_context_->ToString();
auto last_context = parent_context_->Filter(func_graph_); auto last_context = parent_context_->Filter(func_graph_);
if (last_context && last_context->func_graph() == func_graph_) { if (last_context && last_context->func_graph() == func_graph_) {
MS_LOG(DEBUG) << "Find last infer context: " << last_context->ToString(); MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString();
MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list);
MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list());
// Join the last eval arguments and current arguments to check if there are loop variant. // Join the last eval arguments and current arguments to check if there are loop variant.
...@@ -248,26 +248,26 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar ...@@ -248,26 +248,26 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar
}); });
args_spec_list = NormalizeArgs(args_spec_list); args_spec_list = NormalizeArgs(args_spec_list);
args_spec_list = BroadenUndeterminedArgs(args_spec_list); args_spec_list = BroadenUndeterminedArgs(args_spec_list);
trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf); trace::TraceGraphEvalEnter(shared_from_base<Evaluator>(), out_conf);
InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); EvalEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_EXCEPTION_IF_NULL(cache_); MS_EXCEPTION_IF_NULL(cache_);
auto iter = cache_->find(args_spec_list); auto iter = cache_->find(args_spec_list);
if (iter == cache_->end()) { if (iter == cache_->end()) {
MS_LOG(DEBUG) << evaluator_name << " cache miss, call Infer()."; MS_LOG(DEBUG) << evaluator_name << " cache miss, call Eval().";
AbstractBasePtr ret = Infer(engine, args_spec_list); AbstractBasePtr ret = Eval(engine, args_spec_list);
if (ret == nullptr) { if (ret == nullptr) {
InferFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); EvalFailLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf);
MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr."; MS_LOG(EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
} }
MS_EXCEPTION_IF_NULL(ret); MS_EXCEPTION_IF_NULL(ret);
MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << "."; MS_LOG(DEBUG) << evaluator_name << " set cache. return: " << ret->ToString() << ".";
(*cache_)[args_spec_list] = ret; (*cache_)[args_spec_list] = ret;
trace::TraceGraphInferLeave(shared_from_base<Evaluator>()); trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return ret; return ret;
} else { } else {
MS_EXCEPTION_IF_NULL(iter->second); MS_EXCEPTION_IF_NULL(iter->second);
MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << "."; MS_LOG(DEBUG) << evaluator_name << " cache hit. return: " << iter->second->ToString() << ".";
trace::TraceGraphInferLeave(shared_from_base<Evaluator>()); trace::TraceGraphEvalLeave(shared_from_base<Evaluator>());
return iter->second; return iter->second;
} }
} }
...@@ -378,7 +378,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a ...@@ -378,7 +378,7 @@ AbstractBasePtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &a
return jtuple; return jtuple;
} }
AbstractBasePtr VirtualEvaluator::Infer(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) { AbstractBasePtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_spec_list) {
if (args_spec_list.size() != args_spec_list_.size()) { if (args_spec_list.size() != args_spec_list_.size()) {
MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size() MS_LOG(EXCEPTION) << "Arguments mismatch, parameters no: " << args_spec_list_.size()
<< ", arguments no: " << args_spec_list.size(); << ", arguments no: " << args_spec_list.size();
......
...@@ -38,12 +38,12 @@ class Evaluator : public Base { ...@@ -38,12 +38,12 @@ class Evaluator : public Base {
~Evaluator() override = default; ~Evaluator() override = default;
MS_DECLARE_PARENT(Evaluator, Base); MS_DECLARE_PARENT(Evaluator, Base);
// difference between Run() and Infer(): // difference between Run() and Eval():
// Run() will be called with ConfigPtrList, but Infer() will be called with AbstractBasePtr. // Run() will be called with ConfigPtrList, but Eval() will be called with AbstractBasePtr.
// Run() will modify cache_ member, so it cannot marked as const; // Run() will modify cache_ member, so it cannot marked as const;
virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf); virtual AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AnfNodeConfigPtr out_conf);
virtual AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; virtual AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; }
...@@ -71,8 +71,8 @@ class PrimEvaluator : public Evaluator { ...@@ -71,8 +71,8 @@ class PrimEvaluator : public Evaluator {
explicit PrimEvaluator(const std::string &id) : Evaluator(id) {} explicit PrimEvaluator(const std::string &id) : Evaluator(id) {}
~PrimEvaluator() override = default; ~PrimEvaluator() override = default;
MS_DECLARE_PARENT(PrimEvaluator, Evaluator); MS_DECLARE_PARENT(PrimEvaluator, Evaluator);
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) final { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) final {
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
}; };
...@@ -113,7 +113,7 @@ class DummyEvaluator : public Evaluator { ...@@ -113,7 +113,7 @@ class DummyEvaluator : public Evaluator {
DummyEvaluator() : Evaluator("dummy") {} DummyEvaluator() : Evaluator("dummy") {}
~DummyEvaluator() override = default; ~DummyEvaluator() override = default;
MS_DECLARE_PARENT(DummyEvaluator, Evaluator); MS_DECLARE_PARENT(DummyEvaluator, Evaluator);
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; } AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override { return nullptr; }
}; };
// Wrap another evaluator to track a subset of uses. // Wrap another evaluator to track a subset of uses.
...@@ -139,8 +139,8 @@ class TrackedEvaluator : public Evaluator { ...@@ -139,8 +139,8 @@ class TrackedEvaluator : public Evaluator {
bound_node_ = AnfNodeWeakPtr(node); bound_node_ = AnfNodeWeakPtr(node);
} }
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
AnfNodeConfigPtr out_conf) override; AnfNodeConfigPtr out_conf) override;
...@@ -158,7 +158,7 @@ class BaseFuncGraphEvaluator : public Evaluator { ...@@ -158,7 +158,7 @@ class BaseFuncGraphEvaluator : public Evaluator {
~BaseFuncGraphEvaluator() override = default; ~BaseFuncGraphEvaluator() override = default;
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator); MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0; virtual FuncGraphPtr GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) = 0;
...@@ -238,7 +238,7 @@ class PartialAppEvaluator : public Evaluator { ...@@ -238,7 +238,7 @@ class PartialAppEvaluator : public Evaluator {
} }
bound_node_ = AnfNodeWeakPtr(node); bound_node_ = AnfNodeWeakPtr(node);
} }
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
} }
...@@ -258,7 +258,7 @@ class VirtualEvaluator : public Evaluator { ...@@ -258,7 +258,7 @@ class VirtualEvaluator : public Evaluator {
~VirtualEvaluator() override = default; ~VirtualEvaluator() override = default;
MS_DECLARE_PARENT(VirtualEvaluator, Evaluator); MS_DECLARE_PARENT(VirtualEvaluator, Evaluator);
AbstractBasePtr Infer(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override; AbstractBasePtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) override;
std::string ToString() const override { return identifier_; } std::string ToString() const override { return identifier_; }
private: private:
...@@ -285,7 +285,7 @@ class JEvaluator : public Evaluator { ...@@ -285,7 +285,7 @@ class JEvaluator : public Evaluator {
} }
bound_node_ = AnfNodeWeakPtr(node); bound_node_ = AnfNodeWeakPtr(node);
} }
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Should not be called, Run() method should be called";
} }
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
......
...@@ -470,16 +470,16 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const ...@@ -470,16 +470,16 @@ AbstractBasePtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const
} }
} }
ValuePtr inferred_value = RunImpl(value_list); ValuePtr evaluated_value = RunImpl(value_list);
if (!(*inferred_value == *kAnyValue)) { if (!(*evaluated_value == *kAnyValue)) {
ret_value_type = inferred_value->type(); ret_value_type = evaluated_value->type();
} }
// for comparison primitives , return type shall have be specified to be bool. // for comparison primitives , return type shall have be specified to be bool.
if (specify_out_type_ != nullptr) { if (specify_out_type_ != nullptr) {
ret_value_type = specify_out_type_; ret_value_type = specify_out_type_;
} }
AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(inferred_value, ret_value_type); AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
return abs_base; return abs_base;
} }
...@@ -997,8 +997,8 @@ class PartialEvaluator : public Evaluator { ...@@ -997,8 +997,8 @@ class PartialEvaluator : public Evaluator {
return ret; return ret;
} }
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value, AbstractBasePtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
......
...@@ -79,8 +79,8 @@ class DoSignatureEvaluator : public Evaluator { ...@@ -79,8 +79,8 @@ class DoSignatureEvaluator : public Evaluator {
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override; AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
private: private:
...@@ -94,8 +94,8 @@ class UnpackGraphEvaluator : public Evaluator { ...@@ -94,8 +94,8 @@ class UnpackGraphEvaluator : public Evaluator {
AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
AnfNodeConfigPtr out_config = nullptr) override; AnfNodeConfigPtr out_config = nullptr) override;
AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { AbstractBasePtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
} }
private: private:
......
...@@ -183,11 +183,11 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { ...@@ -183,11 +183,11 @@ AbstractBasePtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) {
ret_abstract = EvalValueNode(value_node, conf); ret_abstract = EvalValueNode(value_node, conf);
} else if (node->isa<CNode>()) { } else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
trace::TraceInferCNodeEnter(conf); trace::TraceEvalCNodeEnter(conf);
ret_abstract = InferCNode(cnode, conf); ret_abstract = EvalCNode(cnode, conf);
trace::TraceInferCNodeLeave(); trace::TraceEvalCNodeLeave();
} else { } else {
MS_LOG(EXCEPTION) << "Illegal AnfNode for inferring, " << node->DebugString() MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString()
<< ". NodeInfo: " << trace::GetDebugInfo(node->debug_info()); << ". NodeInfo: " << trace::GetDebugInfo(node->debug_info());
} }
...@@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co ...@@ -208,7 +208,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
return ToAbstract(value_node->value(), conf->context(), conf); return ToAbstract(value_node->value(), conf->context(), conf);
} }
AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { AbstractBasePtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs(); auto &inputs = cnode->inputs();
...@@ -496,7 +496,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval ...@@ -496,7 +496,7 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval
auto current_inf = std::make_pair(eval, args_spec_list); auto current_inf = std::make_pair(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring. // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf);
if (it == eval_trace_.rend()) { if (it == eval_trace_.rend()) {
eval_trace_.push_back(current_inf); eval_trace_.push_back(current_inf);
...@@ -607,7 +607,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) { ...@@ -607,7 +607,7 @@ AbstractBasePtr FromValueInside(const ValuePtr &value, bool broaden) {
return a; return a;
} }
AbstractBasePtr InferOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) { AbstractBasePtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrList &arg_specs) {
auto evaluator = GetPrimEvaluator(primitive, nullptr); auto evaluator = GetPrimEvaluator(primitive, nullptr);
MS_EXCEPTION_IF_NULL(evaluator); MS_EXCEPTION_IF_NULL(evaluator);
if (!evaluator->isa<TrivialPrimEvaluator>()) { if (!evaluator->isa<TrivialPrimEvaluator>()) {
......
...@@ -165,7 +165,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { ...@@ -165,7 +165,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn); EvaluatorPtr GetEvaluatorFor(const AbstractFunctionPtr &fn);
AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf); AbstractBasePtr EvalValueNode(const ValueNodePtr &value_node, const AnfNodeConfigPtr &conf);
AbstractBasePtr InferCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf); AbstractBasePtr EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf);
// Infer the result of fn(args). // Infer the result of fn(args).
AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list); AbstractBasePtr Execute(const AbstractFunctionPtr &fn, const AbstractBasePtrList &args_spec_list);
void Clear(); void Clear();
...@@ -244,7 +244,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { ...@@ -244,7 +244,7 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
return FromValueInside(MakeValue(value), broaden); return FromValueInside(MakeValue(value), broaden);
} }
AbstractBasePtr InferOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); AbstractBasePtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore
......
...@@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI ...@@ -116,7 +116,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecI
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false)); args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
} }
} }
AbstractBasePtr infer_res = InferOnePrim(prim, args_spec_list); AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list);
op_exec_info->abstract = infer_res; op_exec_info->abstract = infer_res;
} }
......
...@@ -216,8 +216,8 @@ void LogWriter::operator^(const LogStream &stream) const { ...@@ -216,8 +216,8 @@ void LogWriter::operator^(const LogStream &stream) const {
} }
oss << msg.str(); oss << msg.str();
trace::TraceGraphInfer(); trace::TraceGraphEval();
trace::GetInferStackInfo(oss); trace::GetEvalStackInfo(oss);
if (exception_type_ == IndexError) { if (exception_type_ == IndexError) {
throw pybind11::index_error(oss.str()); throw pybind11::index_error(oss.str());
......
...@@ -396,9 +396,9 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) { ...@@ -396,9 +396,9 @@ TEST_F(TestInferUniform, test_inferred_scalar_add) {
} }
class TestInferOnePrim : public UT::Common { class TestEvalOnePrim : public UT::Common {
public: public:
TestInferOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {} TestEvalOnePrim() : getPyFun("gtest_input.pipeline.infer.infer_test", true), engine_(nullptr) {}
void SetUp(); void SetUp();
void TearDown(); void TearDown();
...@@ -406,37 +406,37 @@ class TestInferOnePrim : public UT::Common { ...@@ -406,37 +406,37 @@ class TestInferOnePrim : public UT::Common {
AnalysisEnginePtr engine_; AnalysisEnginePtr engine_;
}; };
void TestInferOnePrim::SetUp() { engine_ = SetupAnalysisEngineStub(); } void TestEvalOnePrim::SetUp() { engine_ = SetupAnalysisEngineStub(); }
void TestInferOnePrim::TearDown() { void TestEvalOnePrim::TearDown() {
// destroy resource // destroy resource
} }
TEST_F(TestInferOnePrim, test_scalar_add) { TEST_F(TestEvalOnePrim, test_scalar_add) {
double x1 = 1.1; double x1 = 1.1;
double x2 = 1.1; double x2 = 1.1;
double x3 = 2.2; double x3 = 2.2;
AbstractBasePtr base1 = FromValue(x1, false); AbstractBasePtr base1 = FromValue(x1, false);
AbstractBasePtr base2 = FromValue(x2, false); AbstractBasePtr base2 = FromValue(x2, false);
AbstractBasePtrList base_list = {base1, base2}; AbstractBasePtrList base_list = {base1, base2};
auto res = InferOnePrim(std::make_shared<Primitive>("scalar_add"), base_list); auto res = EvalOnePrim(std::make_shared<Primitive>("scalar_add"), base_list);
MS_LOG(INFO) << "result spec: " << res->ToString(); MS_LOG(INFO) << "result spec: " << res->ToString();
AbstractBasePtr exp = FromValue(x3, false); AbstractBasePtr exp = FromValue(x3, false);
MS_LOG(INFO) << "result exp: " << exp->ToString(); MS_LOG(INFO) << "result exp: " << exp->ToString();
ASSERT_EQ(*res, *exp); ASSERT_EQ(*res, *exp);
} }
class TestGraphInfer : public UT::Common { class TestGraphEval : public UT::Common {
public: public:
TestGraphInfer() : getPyFun("gtest_input.pipeline.infer.infer_test", true){}; TestGraphEval() : getPyFun("gtest_input.pipeline.infer.infer_test", true){};
void SetUp(); void SetUp();
void TearDown(); void TearDown();
AnalysisEnginePtr engine_; AnalysisEnginePtr engine_;
UT::PyFuncGraphFetcher getPyFun; UT::PyFuncGraphFetcher getPyFun;
}; };
void TestGraphInfer::SetUp() { engine_ = SetupAnalysisEngine(); } void TestGraphEval::SetUp() { engine_ = SetupAnalysisEngine(); }
void TestGraphInfer::TearDown() { void TestGraphEval::TearDown() {
// destroy resource // destroy resource
engine_->ClearEvaluatorCache(); engine_->ClearEvaluatorCache();
parse::data_converter::ClearObjectCache(); parse::data_converter::ClearObjectCache();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册