Commit 53d2da5f authored by mindspore-ci-bot, committed by Gitee


!264 static_analysis: remove useless cache in TrivialPrimEvaluator and add cache for PythonPrimEvaluator
Merge pull request !264 from xychow/remove-unnecessary-cache-and-add-cache
......@@ -17,7 +17,9 @@
 #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
 #define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
+#include <algorithm>
 #include <functional>
+#include <iterator>
 #include <memory>
 #include <string>
 #include <vector>
......@@ -129,29 +131,38 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
     return optimizer;
   }
 
-  FuncGraphPtr step(FuncGraphPtr func_graph, const abstract::AbstractBasePtrList &args_spec, bool use_profile = true) {
+  FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
     // Optimizer step counter;
     int counter = 1;
     bool changes = true;
 
     while (changes) {
       changes = false;
-      auto run_runc = [&counter, &func_graph, &args_spec, &changes, use_profile, this]() {
+      auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() {
         for (size_t i = 0; i < passes_.size(); ++i) {
           const OptPass &opt = passes_[i];
-          auto opt_func = [&func_graph, &args_spec, &changes, &opt, this]() {
+          auto opt_func = [&func_graph, &changes, &opt, this]() {
             if (opt.is_renormalize()) {
               auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
               if (resource_ptr != nullptr) {
+                // StepParallel may replace the AbstractValue of the parameters of func_graph,
+                // so regenerate the args_spec from the parameters.
+                abstract::AbstractBasePtrList maybe_new_args_spec;
                 if (is_watch_renormalize_) {
                   if (untyped_nodes_.size() > 0) {
-                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                    std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                   std::back_inserter(maybe_new_args_spec),
+                                   [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                     clear_untyped_nodes();
                   } else {
                     MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
                   }
                 } else {
-                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                  std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                 std::back_inserter(maybe_new_args_spec),
+                                 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                 }
               }
             } else if (opt(func_graph, shared_from_this())) {
......
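The core idiom in the hunk above is rebuilding the argument spec from the graph's own parameters instead of trusting an args_spec captured before the optimizer ran. Below is a minimal standalone sketch of that std::transform / std::back_inserter pattern; Node and AbstractBase here are hypothetical stand-ins, not the real MindSpore classes.

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for AnfNode / AbstractBase.
struct AbstractBase {
  std::string repr;
};
using AbstractBasePtr = std::shared_ptr<AbstractBase>;

struct Node {
  AbstractBasePtr abstract;  // may have been replaced by an earlier pass
};
using NodePtr = std::shared_ptr<Node>;

// Collect the parameters' *current* abstracts, mirroring the
// std::transform calls added in the diff.
std::vector<AbstractBasePtr> CollectArgsSpec(const std::vector<NodePtr> &params) {
  std::vector<AbstractBasePtr> args_spec;
  std::transform(params.begin(), params.end(), std::back_inserter(args_spec),
                 [](const NodePtr &p) { return p->abstract; });
  return args_spec;
}

int main() {
  std::vector<NodePtr> params = {
      std::make_shared<Node>(Node{std::make_shared<AbstractBase>(AbstractBase{"Tensor[2,4]"})})};
  for (const auto &spec : CollectArgsSpec(params)) {
    std::cout << spec->repr << "\n";  // prints the refreshed spec
  }
  return 0;
}
```

Because the spec is regathered right before each Renormalize call, any pass that rewrote a parameter's abstract in the meantime is automatically picked up.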
......@@ -1230,7 +1230,11 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
                 << MakeValue(slice_shape)->ToString();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  // Don't modify it in place, as the pointer to this AbstractValue may be used as a cache key in StaticAnalysis.
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  parameter->set_abstract(cloned_abstract);
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
......@@ -1330,7 +1334,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
       cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
-      cloned_parameter_node->abstract()->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
+      MS_EXCEPTION_IF_NULL(cloned_abstract);
+      cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      cloned_parameter_node->set_abstract(cloned_abstract);
       MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                    << " is cloned, the cloned-from parameter is: " << cloned_from_parameter->name()
                    << ", clone index is: " << cloned_index;
......@@ -1742,7 +1749,10 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
     auto slice_shape = loss_grad_layout.slice_shape().array();
     std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
     MS_EXCEPTION_IF_NULL(parallel_shape);
-    abstract->set_shape(parallel_shape);
+    auto cloned_abstract = abstract->Clone();
+    MS_EXCEPTION_IF_NULL(cloned_abstract);
+    cloned_abstract->set_shape(parallel_shape);
+    sens_tensor_node->set_abstract(cloned_abstract);
     auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
     sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
     return;
......
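All three hunks in this file apply the same copy-on-write discipline: never mutate an AbstractValue in place, because the pointer to it may already be a key in a static-analysis cache. A reduced sketch of the idea follows, with hypothetical types (Abstract, a pointer-keyed cache) standing in for the real ones.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <vector>

// Hypothetical simplified abstract value; Clone() mirrors AbstractBase::Clone().
struct Abstract {
  std::vector<int> shape;
  std::shared_ptr<Abstract> Clone() const { return std::make_shared<Abstract>(*this); }
};
using AbstractPtr = std::shared_ptr<Abstract>;

// A cache keyed by pointer identity, like the analysis caches here.
std::map<AbstractPtr, int> g_analysis_cache;

// Wrong: abs->shape = new_shape; would silently change what an existing
// cache entry describes. Right: clone, modify the clone, re-attach it.
AbstractPtr WithShape(const AbstractPtr &abs, std::vector<int> new_shape) {
  auto cloned = abs->Clone();
  cloned->shape = std::move(new_shape);
  return cloned;  // caller sets this on the node via set_abstract
}

int main() {
  auto original = std::make_shared<Abstract>(Abstract{{8, 8}});
  g_analysis_cache[original] = 42;            // key now pins this exact object
  auto sliced = WithShape(original, {4, 8});  // original stays untouched
  assert(original->shape == std::vector<int>({8, 8}));
  assert(sliced->shape == std::vector<int>({4, 8}));
  return 0;
}
```

The diff repeats this clone / set_shape / set_abstract sequence at every site that used to call set_shape on a shared abstract.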
......@@ -276,9 +276,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
   (void)parse::python_adapter::set_python_scoped();
-  abstract::AbstractBasePtrList args_spec;
   MS_EXCEPTION_IF_NULL(opt_resolve);
-  (void)opt_resolve->step(func_graph, args_spec, use_profile);
+  (void)opt_resolve->step(func_graph, use_profile);
   return true;
 }
......
......@@ -205,14 +205,15 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
     return false;
   }
 
-  abstract::AbstractBasePtrList args = res->args_spec();
   FuncGraphPtr func_graph = res->func_graph();
   MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
                 << func_graph->get_return()->DebugString(true);
 
   InitOpt(res);
   if (g_pass_opts.find(name) != g_pass_opts.end()) {
-    res->set_func_graph(g_pass_opts[name]->step(func_graph, args));
+    res->set_func_graph(g_pass_opts[name]->step(func_graph));
   }
+  // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not yet updated
+  // in res->args_spec_. If any later pass or action wants to use that variable, it should be set here.
   return true;
 }
......@@ -255,10 +256,9 @@ bool ValidatePass(const ResourcePtr &res) {
 
 bool InferenceOptPreparePass(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
-  abstract::AbstractBasePtrList args_spec = res->args_spec();
   auto prepare_map = GetInferenceOptPreparePhases();
   auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
-  (void)infer_opt_prepare->step(func_graph, args_spec, false);
+  (void)infer_opt_prepare->step(func_graph, false);
   return true;
 }
......
......@@ -260,7 +260,6 @@ AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const Config
                          return conf->GetEvaluatedValue();
                        });
   AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
-  (*cache_)[args_spec_list] = ret;
   return ret;
 }
......
......@@ -405,6 +405,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 
 AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
+  const auto &iter = cache_->find(args);
+  if (iter != cache_->end()) {
+    return iter->second;
+  }
   auto py_args = PreparePyInputs(prim_py_, args);
   auto pyobj = prim_py_->GetPyObj();
......@@ -418,6 +422,7 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
   auto res_spec = PyInferRes2Abstract(prim_py_, output);
   MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
+  (*cache_)[args] = res_spec;
   return res_spec;
 }
......
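Taken together, the two hunks above turn EvalPrim into a standard check-then-compute-then-store memoizer, so the Python-side shape/type inference runs at most once per distinct argument list. Here is a minimal sketch of that control flow; string keys are a placeholder for the real AbstractBasePtrList, and the class is hypothetical.

```cpp
#include <iostream>
#include <map>
#include <string>

class MemoizedEvaluator {
 public:
  std::string EvalPrim(const std::string &args_key) {
    // Cache hit: skip the expensive Python round trip entirely.
    const auto iter = cache_.find(args_key);
    if (iter != cache_.end()) {
      return iter->second;
    }
    // Cache miss: run inference once, then record the result.
    const std::string result = InferViaPython(args_key);
    cache_[args_key] = result;
    return result;
  }

 private:
  std::string InferViaPython(const std::string &args_key) {
    std::cout << "expensive inference for " << args_key << "\n";
    return "abstract_of(" + args_key + ")";
  }
  std::map<std::string, std::string> cache_;
};

int main() {
  MemoizedEvaluator eval;
  eval.EvalPrim("Tensor[2,3]");  // computes and caches
  eval.EvalPrim("Tensor[2,3]");  // served from cache, no second message
  return 0;
}
```

This new cache is also why ClearEvaluatorCache in the next hunk must walk the Python-primitive evaluators: their cached abstracts would otherwise go stale across graph re-analysis.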
......@@ -271,6 +271,18 @@ void AnalysisEngine::ClearEvaluatorCache() {
     MS_EXCEPTION_IF_NULL(evaluator->cache());
     evaluator->cache()->clear();
   }
+  for (auto &element : prim_constructors_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
+  for (auto &element : prim_py_evaluators_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
 }
 
 void AnalysisEngine::Clear() {
......@@ -296,7 +308,17 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
-      return std::make_shared<PythonPrimEvaluator>(prim_py);
+      if (engine == nullptr) {
+        return std::make_shared<PythonPrimEvaluator>(prim_py);
+      }
+
+      const auto &iter = engine->prim_py_evaluators_.find(prim_py);
+      if (iter != engine->prim_py_evaluators_.end()) {
+        return iter->second;
+      }
+      evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
+      engine->prim_py_evaluators_[prim_py] = evaluator;
+      return evaluator;
     }
     MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
   }
......
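The memoized results above only pay off if the same evaluator object is returned each time a primitive is looked up; the new prim_py_evaluators_ map turns GetPrimEvaluator into a get-or-create registry (falling back to a fresh evaluator only when no engine is available). A sketch of that pattern with placeholder types:

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>

// Placeholder types; the real code maps PrimitivePyPtr to EvaluatorPtr.
struct Primitive {};
struct Evaluator {};
using PrimPtr = std::shared_ptr<Primitive>;
using EvalPtr = std::shared_ptr<Evaluator>;

// Return the registered evaluator for prim, creating it on first use.
EvalPtr GetOrCreate(std::unordered_map<PrimPtr, EvalPtr> *registry, const PrimPtr &prim) {
  const auto iter = registry->find(prim);
  if (iter != registry->end()) {
    return iter->second;  // reuse keeps the evaluator's result cache warm
  }
  auto evaluator = std::make_shared<Evaluator>();
  (*registry)[prim] = evaluator;
  return evaluator;
}

int main() {
  std::unordered_map<PrimPtr, EvalPtr> registry;
  auto prim = std::make_shared<Primitive>();
  assert(GetOrCreate(&registry, prim) == GetOrCreate(&registry, prim));  // same instance
  return 0;
}
```

Without the registry, each lookup constructed a fresh PythonPrimEvaluator, so the per-evaluator cache added above would never see a hit.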
......@@ -194,6 +194,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
   const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
 
   AnalysisCache cache_;
+  std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
 
  private:
   const PrimEvaluatorMap &prim_constructors_;
......
......@@ -57,8 +57,7 @@ TEST_F(TestOptOptimizer, test_step_opt) {
                                                    true);
   EXPECT_TRUE(optimizer.get() != nullptr);
-  abstract::AbstractBasePtrList args;
-  auto after = optimizer->step(before, args);
+  auto after = optimizer->step(before);
 
   draw::Draw("optimizer_test_expendJ_before.dot", before);
   draw::Draw("optimizer_test_expendJ_after.dot", after);
......