diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 462d08ad3c1d9f3ce19d3095e9f9193d7c2f166c..5e893cf1aa78ca2607218ef31726856a88120845 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -84,13 +84,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode } #endif if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { - if (renorm_action_ == FORCE_RENORM) { - optimizer->add_node_to_renormalize(result); - } else { - // renorm_action_ is CHECK_RENORM - if (result->abstract() == nullptr) { - optimizer->add_node_to_renormalize(result); - } + if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) { + optimizer->set_is_untyped_generated(); } } diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h index dc423ed31477dee14b1bdf0dcabf88bb799aa360..a98a59caf2b22816762b3b8b4dff84f46bfd8883 100644 --- a/mindspore/ccsrc/optimizer/optimizer.h +++ b/mindspore/ccsrc/optimizer/optimizer.h @@ -89,12 +89,18 @@ using OptPassGroupMap = std::vector>; class Optimizer : public std::enable_shared_from_this { public: Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) - : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {} + : name_(name), + resource_(resource_ptr), + run_only_once_(false), + is_watch_renormalize_(false), + is_enable_(true), + is_untyped_generated_(false) {} virtual ~Optimizer() = default; void Init(const OptPassGroupMap &passes, bool run_only_once) { run_only_once_ = run_only_once; is_watch_renormalize_ = false; + is_untyped_generated_ = false; is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); for (auto &iter : passes) { @@ -154,14 +160,14 @@ class Optimizer : public std::enable_shared_from_this { // So generate the args_spec from parameters. abstract::AbstractBasePtrList maybe_new_args_spec; if (is_watch_renormalize_) { - if (untyped_nodes_.size() > 0) { + if (is_untyped_generated_) { std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), std::back_inserter(maybe_new_args_spec), [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); - clear_untyped_nodes(); + clear_is_untyped_generated(); } else { - MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty."; + MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False."; } } else { std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), @@ -206,13 +212,8 @@ class Optimizer : public std::enable_shared_from_this { const std::string name() const { return name_; } - void add_node_to_renormalize(AnfNodePtr anode) { - if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) { - untyped_nodes_.push_back(anode); - } - } - - void clear_untyped_nodes() { untyped_nodes_.clear(); } + void set_is_untyped_generated() { is_untyped_generated_ = true; } + void clear_is_untyped_generated() { is_untyped_generated_ = false; } void enable_watch_renormalize() { is_watch_renormalize_ = true; } void disable_watch_renormalize() { is_watch_renormalize_ = false; } @@ -232,9 +233,9 @@ class Optimizer : public std::enable_shared_from_this { std::vector passes_; std::vector pass_names_; bool run_only_once_; - std::vector untyped_nodes_; bool is_watch_renormalize_; bool is_enable_; + bool is_untyped_generated_; }; } // namespace opt } // namespace mindspore