提交 d04a62b7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!121 Add a checking mechanism for the need of Renormalize pass in Parse pipeline

Merge pull request !121 from thlinh/dev_Apr02_add_watch_for_renormalize
...@@ -52,7 +52,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -52,7 +52,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType, {prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor); zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
// ops eliminate // ops eliminate
item_tuple_eliminate_ = item_tuple_eliminate_ =
...@@ -81,7 +82,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { ...@@ -81,7 +82,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_make_ref_eliminate_ = get_make_ref_eliminate_ =
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue}); MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>);
replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
// Gradient transforms // Gradient transforms
......
...@@ -31,14 +31,14 @@ ...@@ -31,14 +31,14 @@
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim,
const PrimitivePtr& prim) { const RenormAction& renorm_action) {
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const std::vector<PrimitivePtr>& prims) { const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) {
auto fn = [prims](const AnfNodePtr& node) -> bool { auto fn = [prims](const AnfNodePtr& node) -> bool {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
return false; return false;
...@@ -52,12 +52,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: ...@@ -52,12 +52,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::
return false; return false;
}; };
return std::make_shared<Substitution>(transform, name, fn); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const PredicateFuncType& predicate) { const PredicateFuncType& predicate, const RenormAction& renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate); return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
} }
AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const {
...@@ -74,6 +74,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode ...@@ -74,6 +74,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode
} }
} }
#endif #endif
if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
if (renorm_action_ == FORCE_RENORM) {
optimizer->add_node_to_renormalize(result);
} else {
// renorm_action_ is CHECK_RENORM
if (result->abstract() == nullptr) {
optimizer->add_node_to_renormalize(result);
}
}
}
return result; return result;
} }
......
...@@ -36,24 +36,34 @@ using OptimizerWeakPtr = std::weak_ptr<Optimizer>; ...@@ -36,24 +36,34 @@ using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>; using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>; using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executed
enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class Substitution { class Substitution {
public: public:
TransformFuncType transform_{nullptr}; TransformFuncType transform_{nullptr};
std::string name_; std::string name_;
PredicateFuncType predicate_{nullptr}; PredicateFuncType predicate_{nullptr};
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate) // an enum to mark this Substitution's relation to the renormalize pass
: transform_(transform), name_(name), predicate_(predicate) {} RenormAction renorm_action_;
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default; ~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
}; };
using SubstitutionPtr = std::shared_ptr<Substitution>; using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim); SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims); const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const PredicateFuncType &predicate); const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
class SubstitutionList { class SubstitutionList {
public: public:
......
...@@ -87,11 +87,12 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; ...@@ -87,11 +87,12 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> { class Optimizer : public std::enable_shared_from_this<Optimizer> {
public: public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
: name_(name), resource_(resource_ptr), run_only_once_(false) {} : name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {}
virtual ~Optimizer() = default; virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) { void Init(const OptPassGroupMap &passes, bool run_only_once) {
run_only_once_ = run_only_once; run_only_once_ = run_only_once;
is_watch_renormalize_ = false;
for (auto &iter : passes) { for (auto &iter : passes) {
const std::string &name = iter.first; const std::string &name = iter.first;
...@@ -118,9 +119,13 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -118,9 +119,13 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
} }
static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
const OptPassGroupMap &passes, bool run_only_once = false) { const OptPassGroupMap &passes, bool run_only_once = false,
bool watch_renormalize = false) {
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr); OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
optimizer->Init(passes, run_only_once); optimizer->Init(passes, run_only_once);
if (watch_renormalize) {
optimizer->enable_watch_renormalize();
}
return optimizer; return optimizer;
} }
...@@ -138,7 +143,16 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -138,7 +143,16 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
if (opt.is_renormalize()) { if (opt.is_renormalize()) {
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_); auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
if (resource_ptr != nullptr) { if (resource_ptr != nullptr) {
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec); if (is_watch_renormalize_) {
if (untyped_nodes_.size() > 0) {
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
clear_untyped_nodes();
} else {
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
}
} else {
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
}
} }
} else if (opt(func_graph, shared_from_this())) { } else if (opt(func_graph, shared_from_this())) {
changes = true; changes = true;
...@@ -180,12 +194,26 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { ...@@ -180,12 +194,26 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
const std::string name() const { return name_; } const std::string name() const { return name_; }
void add_node_to_renormalize(AnfNodePtr anode) {
if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) {
untyped_nodes_.push_back(anode);
}
}
void clear_untyped_nodes() { untyped_nodes_.clear(); }
void enable_watch_renormalize() { is_watch_renormalize_ = true; }
void disable_watch_renormalize() { is_watch_renormalize_ = false; }
bool is_watch_renormalize() { return is_watch_renormalize_; }
private: private:
const std::string name_; const std::string name_;
pipeline::ResourceBasePtr resource_; pipeline::ResourceBasePtr resource_;
std::vector<OptPass> passes_; std::vector<OptPass> passes_;
std::vector<std::string> pass_names_; std::vector<std::string> pass_names_;
bool run_only_once_; bool run_only_once_;
std::vector<AnfNodePtr> untyped_nodes_;
bool is_watch_renormalize_;
}; };
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
......
...@@ -185,8 +185,8 @@ void InitOpt(const ResourcePtr& res) { ...@@ -185,8 +185,8 @@ void InitOpt(const ResourcePtr& res) {
if (g_pass_opts.size() == 0) { if (g_pass_opts.size() == 0) {
opt::irpass::OptimizeIRPassLib irpass; opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass)); g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass)); g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass)); g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册