提交 d04a62b7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!121 Add a checking mechanism for the need of Renormalize pass in Parse pipeline

Merge pull request !121 from thlinh/dev_Apr02_add_watch_for_renormalize
......@@ -52,7 +52,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
// ops eliminate
item_tuple_eliminate_ =
......@@ -81,7 +82,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef);
get_make_ref_eliminate_ =
MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
replace_refkey_by_param_ = MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>);
replace_refkey_by_param_ =
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam);
// Gradient transforms
......
......@@ -31,14 +31,14 @@
namespace mindspore {
/* namespace to support opt */
namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const PrimitivePtr& prim) {
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim,
const RenormAction& renorm_action) {
auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn);
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const std::vector<PrimitivePtr>& prims) {
const std::vector<PrimitivePtr>& prims, const RenormAction& renorm_action) {
auto fn = [prims](const AnfNodePtr& node) -> bool {
if (!node->isa<CNode>()) {
return false;
......@@ -52,12 +52,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::
return false;
};
return std::make_shared<Substitution>(transform, name, fn);
return std::make_shared<Substitution>(transform, name, fn, renorm_action);
}
SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name,
const PredicateFuncType& predicate) {
return std::make_shared<Substitution>(transform, name, predicate);
const PredicateFuncType& predicate, const RenormAction& renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
}
AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const {
......@@ -74,6 +74,16 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode
}
}
#endif
if (optimizer != nullptr && optimizer->is_watch_renormalize() && result != nullptr) {
if (renorm_action_ == FORCE_RENORM) {
optimizer->add_node_to_renormalize(result);
} else {
// renorm_action_ is CHECK_RENORM
if (result->abstract() == nullptr) {
optimizer->add_node_to_renormalize(result);
}
}
}
return result;
}
......
......@@ -36,24 +36,34 @@ using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
// Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
// CHECK_RENORM: check if the new node is un-typed to decide if the next Renormalize will be executted
enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class Substitution {
public:
TransformFuncType transform_{nullptr};
std::string name_;
PredicateFuncType predicate_{nullptr};
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate)
: transform_(transform), name_(name), predicate_(predicate) {}
// an enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_;
explicit Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const;
};
using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims);
const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name,
const PredicateFuncType &predicate);
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
class SubstitutionList {
public:
......
......@@ -87,11 +87,12 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>;
class Optimizer : public std::enable_shared_from_this<Optimizer> {
public:
Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr)
: name_(name), resource_(resource_ptr), run_only_once_(false) {}
: name_(name), resource_(resource_ptr), run_only_once_(false), is_watch_renormalize_(false) {}
virtual ~Optimizer() = default;
void Init(const OptPassGroupMap &passes, bool run_only_once) {
run_only_once_ = run_only_once;
is_watch_renormalize_ = false;
for (auto &iter : passes) {
const std::string &name = iter.first;
......@@ -118,9 +119,13 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}
static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr,
const OptPassGroupMap &passes, bool run_only_once = false) {
const OptPassGroupMap &passes, bool run_only_once = false,
bool watch_renormalize = false) {
OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr);
optimizer->Init(passes, run_only_once);
if (watch_renormalize) {
optimizer->enable_watch_renormalize();
}
return optimizer;
}
......@@ -138,7 +143,16 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
if (opt.is_renormalize()) {
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
if (resource_ptr != nullptr) {
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
if (is_watch_renormalize_) {
if (untyped_nodes_.size() > 0) {
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
clear_untyped_nodes();
} else {
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
}
} else {
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
}
}
} else if (opt(func_graph, shared_from_this())) {
changes = true;
......@@ -180,12 +194,26 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
const std::string name() const { return name_; }
void add_node_to_renormalize(AnfNodePtr anode) {
if (std::find(untyped_nodes_.begin(), untyped_nodes_.end(), anode) == untyped_nodes_.end()) {
untyped_nodes_.push_back(anode);
}
}
void clear_untyped_nodes() { untyped_nodes_.clear(); }
void enable_watch_renormalize() { is_watch_renormalize_ = true; }
void disable_watch_renormalize() { is_watch_renormalize_ = false; }
bool is_watch_renormalize() { return is_watch_renormalize_; }
private:
const std::string name_;
pipeline::ResourceBasePtr resource_;
std::vector<OptPass> passes_;
std::vector<std::string> pass_names_;
bool run_only_once_;
std::vector<AnfNodePtr> untyped_nodes_;
bool is_watch_renormalize_;
};
} // namespace opt
} // namespace mindspore
......
......@@ -185,8 +185,8 @@ void InitOpt(const ResourcePtr& res) {
if (g_pass_opts.size() == 0) {
opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass));
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), false, true);
g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册