提交 86a4abad 编写于 作者: H Hoai Linh Tran

Change node collections into flag for calling Renormalize

上级 9991df86
......@@ -84,13 +84,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNode
}
#endif
if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
if (renorm_action_ == FORCE_RENORM) {
optimizer->add_node_to_renormalize(result);
} else {
// renorm_action_ is CHECK_RENORM
if (result->abstract() == nullptr) {
optimizer->add_node_to_renormalize(result);
}
if ((renorm_action_ == FORCE_RENORM) || (result->abstract() == nullptr)) {
optimizer->set_is_untyped_generated();
}
}
......
......@@ -89,12 +89,18 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> {
public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false), is_enable_(true) {}
: name_(name),
resource_(resource_ptr),
run_only_once_(false),
is_watch_renormalize_(false),
is_enable_(true),
is_untyped_generated_(false) {}
virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) {
run_only_once_ = run_only_once;
is_watch_renormalize_ = false;
is_untyped_generated_ = false;
is_on_debug_ = IS_OUTPUT_ON(mindspore::DEBUG);
for (auto &iter : passes) {
......@@ -154,14 +160,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
// So generate the args_spec from parameters.
abstract::AbstractBasePtrList maybe_new_args_spec;
if (is_watch_renormalize_) {
if (untyped_nodes_.size() > 0) {
if (is_untyped_generated_) {
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
std::back_inserter(maybe_new_args_spec),
[](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
clear_untyped_nodes();
clear_is_untyped_generated();
} else {
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because is_untyped_generated_ is False.";
}
} else {
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
......@@ -206,13 +212,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
const std::string name() const { return name_; }
void add_node_to_renormalize(AnfNodePtr anode) {
if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) {
untyped_nodes_.push_back(anode);
}
}
void clear_untyped_nodes() { untyped_nodes_.clear(); }
void set_is_untyped_generated() { is_untyped_generated_ = true; }
void clear_is_untyped_generated() { is_untyped_generated_ = false; }
void enable_watch_renormalize() { is_watch_renormalize_ = true; }
void disable_watch_renormalize() { is_watch_renormalize_ = false; }
......@@ -232,9 +233,9 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
std::vector<OptPass> passes_;
std::vector<std::string> pass_names_;
bool run_only_once_;
std::vector<AnfNodePtr> untyped_nodes_;
bool is_watch_renormalize_;
bool is_enable_;
bool is_untyped_generated_;
};
} // namespace opt
} // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册