提交 86a4abad 编写于 作者: H Hoai Linh Tran

Change node collections into flag for calling Renormalize

上级 9991df86
...@@ -84,13 +84,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode ...@@ -84,13 +84,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
} }
#endif #endif
if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) { if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
if (renorm_action_ == FORCE_RENORM) { if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) {
optimizer->add_node_to_renormalize(result); optimizer->set_is_untyped_generated();
} else {
// renorm_action_ is CHECK_RENORM
if (result->abstract() == nullptr) {
optimizer->add_node_to_renormalize(result);
}
} }
} }
......
...@@ -89,12 +89,18 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; ...@@ -89,12 +89,18 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> { class Optimizer : public std::enable_shared_from_this<Optimizer> {
public: public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {} : name_(name),
resource_(resource_ptr),
run_only_once_(false),
is_watch_renormalize_(false),
is_enable_(true),
is_untyped_generated_(false) {}
virtual ~Optimizer() = default; virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) { void Init(const OptPassGroupMap &passes, bool run_only_once) {
run_only_once_ = run_only_once; run_only_once_ = run_only_once;
is_watch_renormalize_ = false; is_watch_renormalize_ = false;
is_untyped_generated_ = false;
is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG); is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG);
for (auto &iter : passes) { for (auto &iter : passes) {
...@@ -154,14 +160,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -154,14 +160,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
// So generate the args_spec from parameters. // So generate the args_spec from parameters.
abstract::AbstractBasePtrList maybe_new_args_spec; abstract::AbstractBasePtrList maybe_new_args_spec;
if (is_watch_renormalize_) { if (is_watch_renormalize_) {
if (untyped_nodes_.size() > 0) { if (is_untyped_generated_) {
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
std::back_inserter(maybe_new_args_spec), std::back_inserter(maybe_new_args_spec),
[](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); }); [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec); func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
clear_untyped_nodes(); clear_is_untyped_generated();
} else { } else {
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty."; MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False.";
} }
} else { } else {
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(), std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
...@@ -206,13 +212,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -206,13 +212,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
const std::string name() const { return name_; } const std::string name() const { return name_; }
void add_node_to_renormalize(AnfNodePtr anode) { void set_is_untyped_generated() { is_untyped_generated_ = true; }
if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) { void clear_is_untyped_generated() { is_untyped_generated_ = false; }
untyped_nodes_.push_back(anode);
}
}
void clear_untyped_nodes() { untyped_nodes_.clear(); }
void enable_watch_renormalize() { is_watch_renormalize_ = true; } void enable_watch_renormalize() { is_watch_renormalize_ = true; }
void disable_watch_renormalize() { is_watch_renormalize_ = false; } void disable_watch_renormalize() { is_watch_renormalize_ = false; }
...@@ -232,9 +233,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -232,9 +233,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
std::vector<OptPass> passes_; std::vector<OptPass> passes_;
std::vector<std::string> pass_names_; std::vector<std::string> pass_names_;
bool run_only_once_; bool run_only_once_;
std::vector<AnfNodePtr> untyped_nodes_;
bool is_watch_renormalize_; bool is_watch_renormalize_;
bool is_enable_; bool is_enable_;
bool is_untyped_generated_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册