提交 aabec55c 编写于 作者: G Giancarlo Colmenares

Removing TransformFuncType

上级 b4a66d47
...@@ -17,13 +17,23 @@ ...@@ -17,13 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #ifndef MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #define MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h" #include "ir/anf.h"
#include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
} // namespace opt
class OptimizerCaller { class OptimizerCaller {
public: public:
virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; } virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &) { return nullptr; }
}; };
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_ #endif // MINDSPORE_CCSRC_IR_OPTIMIZER_CALLER_H_
...@@ -14,140 +14,154 @@ ...@@ -14,140 +14,154 @@
* limitations under the License. * limitations under the License.
*/ */
#include "optimizer/irpass.h"
#include <string> #include <string>
#include "optimizer/irpass/symbol_resolver.h" #include "optimizer/irpass.h"
#include "optimizer/irpass/arithmetic_simplify.h" #include "optimizer/irpass/arithmetic_simplify.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/branch_culling.h" #include "optimizer/irpass/branch_culling.h"
#include "optimizer/irpass/cast_eliminate.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/env_item_eliminate.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/gradient_eliminate.h" #include "optimizer/irpass/gradient_eliminate.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/inline.h" #include "optimizer/irpass/inline.h"
#include "optimizer/irpass/convert.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h" #include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h" #include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/param_replace.h" #include "optimizer/irpass/item_tuple_eliminate.h"
#include "optimizer/irpass/mark_interface_fusion.h" #include "optimizer/irpass/mark_interface_fusion.h"
#include "optimizer/irpass/merge_addn.h"
#include "optimizer/irpass/minmax_grad.h"
#include "optimizer/irpass/param_replace.h"
#include "optimizer/irpass/partial_eliminate.h"
#include "optimizer/irpass/reduce_eliminate.h"
#include "optimizer/irpass/ref_eliminate.h"
#include "optimizer/irpass/reshape_eliminate.h"
#include "optimizer/irpass/special_op_eliminate.h"
#include "optimizer/irpass/specialize_transform.h"
#include "optimizer/irpass/symbol_resolver.h"
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h" #include "optimizer/opt.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
OptimizeIRPassLib::OptimizeIRPassLib() { OptimizeIRPassLib::OptimizeIRPassLib() {
arithmetic_simplify_ = MakeSubstitution(ArithmeticSimplify(), "arithmetic_simplify", arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
{prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd, {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimTensorAdd,
prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow}); prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
arithmetic_simplify2_ = MakeSubstitution(ArithmeticSimplify2(), "arithmetic_simplify2", {prim::kPrimMul}); arithmetic_simplify2_ =
MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
special_op_eliminate_ = special_op_eliminate_ =
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate", MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward, {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv}); prim::kPrimPrintShapeType, prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike); zero_like_fill_zero_ =
adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN); MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
adjust_all_reduce_mul_add_ =
MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
// ops eliminate // ops eliminate
item_tuple_eliminate_ = item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
MakeSubstitution(ItemTupleEliminater(), "item_tuple_eliminate", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem}); {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem});
tile_eliminate_ = MakeSubstitution(TileMultiplyByOne(), "tile_eliminate", prim::kPrimTile); tile_eliminate_ = MakeSubstitution(std::make_shared<TileMultiplyByOne>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(CastEliminater(), "cast_eliminate", prim::kPrimCast); cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(ReshapeEliminater(), "reshape_eliminate", prim::kPrimReshape); reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
transpose_eliminate_ = MakeSubstitution(TransposeSameIOEliminater(), "transpose_eliminate", prim::kPrimTranspose); transpose_eliminate_ =
MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
reduce_eliminate_ = MakeSubstitution( reduce_eliminate_ = MakeSubstitution(
ReduceOneEliminater(), "reduce_eliminate", std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin}); {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(PartialEliminater(), "partial_eliminate", IsCNodeDup); partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape); same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop); check_bprop_eliminate_ =
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>); MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend); reset_defer_inline_ =
MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
// Env Item Eliminate // Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem); env_get_item_eliminate_ =
new_env_get_item_ = MakeSubstitution(NewEnvGetItem(), "new_env_get_item", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_ = incorporate_env_getitem_ =
MakeSubstitution(IncorporateEnvGetitem(), "incorporate_env_get_item", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
MakeSubstitution(IncorporateEnvGetitemSwitch(), "incorporate_env_getitem_switch", prim::kPrimEnvGetItem); "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
// Ref eliminate // Ref eliminate
make_ref_eliminate_ = MakeSubstitution(MakeRefEliminater(), "make_ref_eliminate", prim::kPrimMakeRef); make_ref_eliminate_ =
get_ref_param_eliminate_ = MakeSubstitution(GetRefParamEliminater(), "get_ref_param_eliminate", MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
get_ref_param_eliminate_ = MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate",
{prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); {prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
get_make_ref_eliminate_ = MakeSubstitution(GetMakeRefEliminater(), "get_make_ref_eliminate", get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
{prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin}); {prim::kPrimGetRefKey, prim::kPrimGetRefValue, prim::kPrimGetRefOrigin});
replace_refkey_by_param_ = replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
MakeSubstitution(ReplaceRefkeyByParam(), "replace_refkey_by_param", IsValueNode<RefKey>, opt::FORCE_RENORM); IsValueNode<RefKey>, opt::FORCE_RENORM);
replace_old_param_ = MakeSubstitution(ReplaceOldParam(), "replace_old_param", IsParam); replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
// Gradient transforms // Gradient transforms
expand_jprim_ = MakeSubstitution(ExpandJPrim(), "expand_jprim", prim::kPrimJ); expand_jprim_ = MakeSubstitution(std::make_shared<ExpandJPrim>(), "expand_jprim", prim::kPrimJ);
minmaximum_grad_ = MakeSubstitution(MinMaximumGrad(), "minmaximum_grad", prim::kPrimTupleGetItem); minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
// branch culling // branch culling
switch_simplify_ = MakeSubstitution(SwitchSimplify(), "switch_simplify", prim::kPrimSwitch); switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
float_tuple_getitem_switch_ = float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
MakeSubstitution(FloatTupleGetItemSwitch(), "float_tuple_getitem_switch", prim::kPrimTupleGetItem); "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
float_env_getitem_switch_ = float_env_getitem_switch_ =
MakeSubstitution(FloatEnvGetItemSwitch(), "float_env_getitem_switch", prim::kPrimEnvGetItem); MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
convert_switch_replacement_ = MakeSubstitution(ConvertSwitchReplacement(), "convert_switch_replacement", IsCNodeDup); convert_switch_replacement_ =
MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
// Addn // Addn
merge_addn_ = MakeSubstitution(MergeAddN(), "merge_addn", prim::kPrimAddN); merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
addn_zero_filter_ = MakeSubstitution(AddNZeroFilter(), "addn_zero_filter", prim::kPrimAddN); addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
// inline // inline
inline_ = MakeSubstitution(Inliner(), "inline", IsCNodeGraph); inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
replace_applicator_ = MakeSubstitution(ReplaceApplicator(), "replace_applicator", IsValueNode<FuncGraph>); replace_applicator_ =
specialize_transform_ = MakeSubstitution(SpecializeOnGraphArguments(), "specialize_transform", IsCNodeGraph); MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
specialize_transform_ =
MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
// Incorporation // Incorporation
incorporate_getitem_set_ = incorporate_getitem_set_ =
MakeSubstitution(IncorporateGetitemSet(), "incorporate_getitem_set", prim::kPrimTupleGetItem); MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
incorporate_getitem_from_param_ = incorporate_getitem_from_param_ = MakeSubstitution(std::make_shared<IncorporateGetitemFromParam>(),
MakeSubstitution(IncorporateGetitemFromParam(), "incorporate_getitem_from_param", IsCNodeGraphKernel); "incorporate_getitem_from_param", IsCNodeGraphKernel);
incorporate_call_ = MakeSubstitution(IncorporateCall(), "incorporate_call", IsCNodeDup); incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
incorporate_call_switch_ = MakeSubstitution(IncorporateCallSwitch(), "incorporate_call_switch", IsCNodeDup); incorporate_call_switch_ =
MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
// Virtual Dataset // Virtual Dataset
virtual_dataset_eliminate_ = virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
MakeSubstitution(VirtualDatasetEliminater(), "virtual_dataset_eliminate", prim::kPrimVirtualDataset); "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Convert // Convert
print_tuple_wrapper_ = MakeSubstitution(PrintTupleWrapper(), "print_tuple_wrapper", prim::kPrimPrint); print_tuple_wrapper_ =
MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
// Unused parameter eliminate // Unused parameter eliminate
unused_parameter_eliminate_ = unused_parameter_eliminate_ =
MakeSubstitution(UnusedParasEliminater(), "unused_parameter_eliminate", IsCNodeGraphKernel); MakeSubstitution(std::make_shared<UnusedParasEliminater>(), "unused_parameter_eliminate", IsCNodeGraphKernel);
unused_output_eliminate_ = MakeSubstitution(UnusedOutputEliminater(), "unused_output_eliminate", IsCNodeGraphKernel); unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// AddN eliminate // AddN eliminate
addn_eliminate_ = MakeSubstitution(AddNEliminater(), "addn_eliminate", IsCNodeGraphKernel); addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
// Mark interface fusion // Mark interface fusion
mark_interface_fusion_ = MakeSubstitution(MarkInterfaceFusion(), "mark_interface_fusion", prim::kPrimSelect); mark_interface_fusion_ =
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {
resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve);
resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetattr>(), "resolver_getattr", prim::kPrimGetAttr);
} }
InferenceOptPrepareLib::InferenceOptPrepareLib() { InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
} }
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ARITHMETIC_SIMPLIFY_H_
#include <vector>
#include <memory>
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h" #include "ir/optimizer_caller.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor { ...@@ -739,17 +740,17 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
FuncGraphPtr all_reduce_fg_{nullptr}; FuncGraphPtr all_reduce_fg_{nullptr};
}; };
class ArithmeticSimplify { class ArithmeticSimplify : public OptimizerCaller {
public: public:
ArithmeticSimplify() ArithmeticSimplify()
: multiply_by_zero_or_one_(), : multiply_by_zero_or_one_(std::make_shared<MultiplyByZeroOrOne>()),
tensor_multiply_by_one_(), tensor_multiply_by_one_(std::make_shared<TensorMultiplyByOne>()),
add_by_zero_(), add_by_zero_(std::make_shared<AddByZero>()),
tensor_add_by_zero_(), tensor_add_by_zero_(std::make_shared<TensorAddByZero>()),
identity_(prim::kPrimIdentity), identity_(std::make_shared<PrimEliminater>(prim::kPrimIdentity)),
opt_update_zero_tensor_(), opt_update_zero_tensor_(std::make_shared<OptUpdateZeroTensor>()),
constant_duplicate_mul_(), constant_duplicate_mul_(std::make_shared<ConstantDuplicateMul>()),
power_one_() { power_one_(std::make_shared<PowerOneEliminate>()) {
eliminaters_.emplace_back(multiply_by_zero_or_one_); eliminaters_.emplace_back(multiply_by_zero_or_one_);
eliminaters_.emplace_back(tensor_multiply_by_one_); eliminaters_.emplace_back(tensor_multiply_by_one_);
eliminaters_.emplace_back(add_by_zero_); eliminaters_.emplace_back(add_by_zero_);
...@@ -761,10 +762,10 @@ class ArithmeticSimplify { ...@@ -761,10 +762,10 @@ class ArithmeticSimplify {
} }
~ArithmeticSimplify() = default; ~ArithmeticSimplify() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -773,15 +774,9 @@ class ArithmeticSimplify { ...@@ -773,15 +774,9 @@ class ArithmeticSimplify {
} }
private: private:
MultiplyByZeroOrOne multiply_by_zero_or_one_; OptimizerCallerPtr multiply_by_zero_or_one_, tensor_multiply_by_one_, add_by_zero_, tensor_add_by_zero_, identity_,
TensorMultiplyByOne tensor_multiply_by_one_; opt_update_zero_tensor_, constant_duplicate_mul_, power_one_;
AddByZero add_by_zero_; std::vector<OptimizerCallerPtr> eliminaters_{};
TensorAddByZero tensor_add_by_zero_;
PrimEliminater identity_;
OptUpdateZeroTensor opt_update_zero_tensor_;
ConstantDuplicateMul constant_duplicate_mul_;
PowerOneEliminate power_one_;
std::vector<TransformFuncType> eliminaters_{};
}; };
// Arithmetic Simplifications should be done after step_parallel. // Arithmetic Simplifications should be done after step_parallel.
...@@ -789,15 +784,17 @@ class ArithmeticSimplify { ...@@ -789,15 +784,17 @@ class ArithmeticSimplify {
// with shape(weight), but after step_parallel, shape of weight may be changed, so the // with shape(weight), but after step_parallel, shape of weight may be changed, so the
// shape of the constant tensor should also be changed. So this pass is seperated from // shape of the constant tensor should also be changed. So this pass is seperated from
// ArithmeticSimplify and deferred until step_parallel. // ArithmeticSimplify and deferred until step_parallel.
class ArithmeticSimplify2 { class ArithmeticSimplify2 : public OptimizerCaller {
public: public:
ArithmeticSimplify2() : tensor_multiply_by_zero_() { eliminaters_.emplace_back(tensor_multiply_by_zero_); } ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared<TensorMultiplyByZero>()) {
eliminaters_.emplace_back(tensor_multiply_by_zero_);
}
~ArithmeticSimplify2() = default; ~ArithmeticSimplify2() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -806,8 +803,8 @@ class ArithmeticSimplify2 { ...@@ -806,8 +803,8 @@ class ArithmeticSimplify2 {
} }
private: private:
TensorMultiplyByZero tensor_multiply_by_zero_; OptimizerCallerPtr tensor_multiply_by_zero_;
std::vector<TransformFuncType> eliminaters_{}; std::vector<OptimizerCallerPtr> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "optimizer/irpass.h" #include "optimizer/irpass.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "ir/visitor.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor { ...@@ -52,12 +52,12 @@ class TwoCastEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, t_{nullptr}; AnfNodePtr x_{nullptr}, t_{nullptr};
}; };
class CastEliminater { class CastEliminater : public OptimizerCaller {
public: public:
CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {} CastEliminater() : cast_same_type_eliminater_(), two_cast_eliminater_() {}
~CastEliminater() = default; ~CastEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = cast_same_type_eliminater_(optimizer, node); auto new_node = cast_same_type_eliminater_(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
......
...@@ -17,18 +17,19 @@ ...@@ -17,18 +17,19 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
#include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <memory> #include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
namespace mindspore { namespace mindspore {
...@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor { ...@@ -225,19 +226,22 @@ class EnvGetSetItem : public AnfVisitor {
bool is_match_{false}; bool is_match_{false};
}; };
class EnvGetItemEliminater { class EnvGetItemEliminater : public OptimizerCaller {
public: public:
EnvGetItemEliminater() : new_env_get_item_(), add_env_get_item_(), env_get_set_item_() { EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()) {
eliminaters_.emplace_back(new_env_get_item_); eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_); eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_); eliminaters_.emplace_back(env_get_set_item_);
} }
~EnvGetItemEliminater() = default; ~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -246,10 +250,8 @@ class EnvGetItemEliminater { ...@@ -246,10 +250,8 @@ class EnvGetItemEliminater {
} }
private: private:
NewEnvGetItem new_env_get_item_; OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
AddEnvGetItem add_env_get_item_; std::vector<OptimizerCallerPtr> eliminaters_{};
EnvGetSetItem env_get_set_item_;
std::vector<TransformFuncType> eliminaters_{};
}; };
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y} // {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
......
...@@ -17,18 +17,20 @@ ...@@ -17,18 +17,20 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INCORPORATE_GETITEM_H_
#include <vector>
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <memory> #include <memory>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
...@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor { ...@@ -383,18 +385,20 @@ class IncorporateGetitemSwitch : public AnfVisitor {
internal::GetitemTransform getitem_transform_; internal::GetitemTransform getitem_transform_;
}; };
class IncorporateGetitemSet { class IncorporateGetitemSet : public OptimizerCaller {
public: public:
IncorporateGetitemSet() : incorporate_getitem_(), incorporate_getitem_switch_() { IncorporateGetitemSet()
: incorporate_getitem_(std::make_shared<IncorporateGetitem>()),
incorporate_getitem_switch_(std::make_shared<IncorporateGetitemSwitch>()) {
eliminaters_.emplace_back(incorporate_getitem_); eliminaters_.emplace_back(incorporate_getitem_);
eliminaters_.emplace_back(incorporate_getitem_switch_); eliminaters_.emplace_back(incorporate_getitem_switch_);
} }
~IncorporateGetitemSet() = default; ~IncorporateGetitemSet() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -403,9 +407,8 @@ class IncorporateGetitemSet { ...@@ -403,9 +407,8 @@ class IncorporateGetitemSet {
} }
private: private:
IncorporateGetitem incorporate_getitem_; OptimizerCallerPtr incorporate_getitem_, incorporate_getitem_switch_;
IncorporateGetitemSwitch incorporate_getitem_switch_; std::vector<OptimizerCallerPtr> eliminaters_{};
std::vector<TransformFuncType> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -17,13 +17,15 @@ ...@@ -17,13 +17,15 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
#include <vector>
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/irpass.h" #include "ir/optimizer_caller.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
...@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor { ...@@ -261,14 +263,14 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
}; };
class ItemTupleEliminater { class ItemTupleEliminater : public OptimizerCaller {
public: public:
ItemTupleEliminater() ItemTupleEliminater()
: get_item_eliminater_(), : get_item_eliminater_(std::make_shared<GetitemEliminater>()),
get_item_const_eliminater_(), get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
set_item_eliminater_(), set_item_eliminater_(std::make_shared<SetitemEliminater>()),
get_set_item_eliminater_(), get_set_item_eliminater_(std::make_shared<GetSetitemEliminater>()),
get_item_depend_reorder_() { get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()) {
eliminaters_.emplace_back(get_item_eliminater_); eliminaters_.emplace_back(get_item_eliminater_);
eliminaters_.emplace_back(get_item_const_eliminater_); eliminaters_.emplace_back(get_item_const_eliminater_);
eliminaters_.emplace_back(set_item_eliminater_); eliminaters_.emplace_back(set_item_eliminater_);
...@@ -277,10 +279,10 @@ class ItemTupleEliminater { ...@@ -277,10 +279,10 @@ class ItemTupleEliminater {
} }
~ItemTupleEliminater() = default; ~ItemTupleEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -289,12 +291,9 @@ class ItemTupleEliminater { ...@@ -289,12 +291,9 @@ class ItemTupleEliminater {
} }
private: private:
GetitemEliminater get_item_eliminater_; OptimizerCallerPtr get_item_eliminater_, get_item_const_eliminater_, set_item_eliminater_, get_set_item_eliminater_,
GetitemConstEliminater get_item_const_eliminater_; get_item_depend_reorder_;
SetitemEliminater set_item_eliminater_; std::vector<OptimizerCallerPtr> eliminaters_{};
GetSetitemEliminater get_set_item_eliminater_;
GetitemDependReorder get_item_depend_reorder_;
std::vector<TransformFuncType> eliminaters_{};
}; };
} // namespace irpass } // namespace irpass
} // namespace opt } // namespace opt
......
...@@ -19,9 +19,9 @@ ...@@ -19,9 +19,9 @@
#include <memory> #include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/pattern_matcher.h" #include "ir/pattern_matcher.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
......
...@@ -19,11 +19,12 @@ ...@@ -19,11 +19,12 @@
#include <vector> #include <vector>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "pipeline/static_analysis/dshape.h" #include "pipeline/static_analysis/dshape.h"
namespace mindspore { namespace mindspore {
...@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor { ...@@ -124,12 +125,12 @@ class TwoReshapeEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr}, shape_{nullptr}; AnfNodePtr x_{nullptr}, shape_{nullptr};
}; };
class ReshapeEliminater { class ReshapeEliminater : public OptimizerCaller {
public: public:
ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {} ReshapeEliminater() : reshape_same_shape_eliminater_(), two_reshape_eliminater_() {}
~ReshapeEliminater() = default; ~ReshapeEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
auto new_node = reshape_same_shape_eliminater_(optimizer, node); auto new_node = reshape_same_shape_eliminater_(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
......
...@@ -18,31 +18,31 @@ ...@@ -18,31 +18,31 @@
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#include <securec.h> #include <securec.h>
#include <vector>
#include <memory>
#include <algorithm> #include <algorithm>
#include <memory>
#include <vector>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h" #include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h" #include "ir/pattern_matcher.h"
#include "ir/visitor.h" #include "ir/visitor.h"
#include "operator/ops.h" #include "operator/ops.h"
#include "ir/pattern_matcher.h" #include "optimizer/irpass.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "optimizer/optimizer.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace irpass { namespace irpass {
class SpecialOpEliminater { class SpecialOpEliminater : public OptimizerCaller {
public: public:
SpecialOpEliminater() SpecialOpEliminater()
: insert_gradient_of_(prim::kPrimInsertGradientOf), : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
stop_gradient_(prim::kPrimStopGradient), stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
hook_backward_(prim::kPrimHookBackward), hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
print_shape_type_(prim::kPrimPrintShapeType), print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
get_ref_value_(prim::kPrimGetRefValue), get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
mirror_(prim::kPrimMirror), mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
virtual_div_(prim::kPrimVirtualDiv) { virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
eliminaters_.emplace_back(insert_gradient_of_); eliminaters_.emplace_back(insert_gradient_of_);
eliminaters_.emplace_back(stop_gradient_); eliminaters_.emplace_back(stop_gradient_);
eliminaters_.emplace_back(hook_backward_); eliminaters_.emplace_back(hook_backward_);
...@@ -53,10 +53,10 @@ class SpecialOpEliminater { ...@@ -53,10 +53,10 @@ class SpecialOpEliminater {
} }
~SpecialOpEliminater() = default; ~SpecialOpEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node; AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) { for (auto &eliminater : eliminaters_) {
new_node = eliminater(optimizer, node); new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) { if (new_node != nullptr) {
return new_node; return new_node;
} }
...@@ -65,9 +65,9 @@ class SpecialOpEliminater { ...@@ -65,9 +65,9 @@ class SpecialOpEliminater {
} }
private: private:
PrimEliminater insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_, OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
virtual_div_; virtual_div_;
std::vector<TransformFuncType> eliminaters_{}; std::vector<OptimizerCallerPtr> eliminaters_{};
}; };
// {PrimVirtualDataset, X} -> X // {PrimVirtualDataset, X} -> X
......
...@@ -16,28 +16,27 @@ ...@@ -16,28 +16,27 @@
#include "optimizer/opt.h" #include "optimizer/opt.h"
#include <algorithm>
#include <deque>
#include <memory> #include <memory>
#include <unordered_set> #include <unordered_set>
#include <deque>
#include <algorithm>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/manager.h" #include "ir/manager.h"
#include "utils/ordered_set.h"
#include "utils/log_adapter.h"
#include "optimizer/optimizer.h" #include "optimizer/optimizer.h"
#include "utils/log_adapter.h"
#include "utils/ordered_set.h"
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &renorm_action) { const RenormAction &renorm_action) {
auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); };
return std::make_shared<Substitution>(transform, name, fn, renorm_action); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) { const std::vector<PrimitivePtr> &prims, const RenormAction &renorm_action) {
auto fn = [prims](const AnfNodePtr &node) -> bool { auto fn = [prims](const AnfNodePtr &node) -> bool {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
...@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std:: ...@@ -64,16 +63,16 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::
return std::make_shared<Substitution>(transform, name, fn, renorm_action); return std::make_shared<Substitution>(transform, name, fn, renorm_action);
} }
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &renorm_action) { const PredicateFuncType &predicate, const RenormAction &renorm_action) {
return std::make_shared<Substitution>(transform, name, predicate, renorm_action); return std::make_shared<Substitution>(transform, name, predicate, renorm_action);
} }
AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
double t = GetTime(); double t = GetTime();
#endif #endif
AnfNodePtr result = transform_(optimizer, node); AnfNodePtr result = (*transform_)(optimizer, node);
#ifdef ENABLE_PROFILE #ifdef ENABLE_PROFILE
if (optimizer != nullptr) { if (optimizer != nullptr) {
auto time = GetTime(); auto time = GetTime();
......
...@@ -17,24 +17,18 @@ ...@@ -17,24 +17,18 @@
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_ #define MINDSPORE_CCSRC_OPTIMIZER_OPT_H_
#include <vector>
#include <string>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "operator/ops.h" #include "operator/ops.h"
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using TransformFuncType = std::function<AnfNodePtr(const OptimizerPtr &, const AnfNodePtr &)>;
// Define the interaction mode between an Optimize pass and Renormalize pass // Define the interaction mode between an Optimize pass and Renormalize pass
// FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed // FORCE_RENORM: if the pass modified the graph then the next Renormalize will be executed
...@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM }; ...@@ -43,26 +37,26 @@ enum RenormAction : int { FORCE_RENORM = 0, CHECK_RENORM };
class Substitution { class Substitution {
public: public:
TransformFuncType transform_{nullptr}; OptimizerCallerPtr transform_;
std::string name_; std::string name_;
PredicateFuncType predicate_{nullptr}; PredicateFuncType predicate_{nullptr};
// an enum to mark this Substitution relation to renormalize pass // an enum to mark this Substitution relation to renormalize pass
RenormAction renorm_action_; RenormAction renorm_action_;
Substitution(const TransformFuncType &transform, const std::string &name, const PredicateFuncType &predicate, Substitution(const OptimizerCallerPtr &transform, const std::string &name, const PredicateFuncType &predicate,
const RenormAction &renorm_action) const RenormAction &renorm_action)
: transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {} : transform_(transform), name_(name), predicate_(predicate), renorm_action_(renorm_action) {}
~Substitution() = default; ~Substitution() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node);
}; };
using SubstitutionPtr = std::shared_ptr<Substitution>; using SubstitutionPtr = std::shared_ptr<Substitution>;
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, const PrimitivePtr &prim,
const RenormAction &action_renorm = CHECK_RENORM); const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const std::vector<PrimitivePtr> &prims, const std::vector<PrimitivePtr> &prims,
const RenormAction &action_renorm = CHECK_RENORM); const RenormAction &action_renorm = CHECK_RENORM);
SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name,
const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM);
class SubstitutionList { class SubstitutionList {
......
...@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common { ...@@ -77,10 +77,10 @@ class TestOptOpt : public UT::Common {
}; };
void SetUp() { void SetUp() {
elim_Z = MakeSubstitution(irpass::AddByZero(), "elim_Z", prim::kPrimScalarAdd); elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(irpass::PrimEliminater(R), "elim_R", R); elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(IdempotentEliminater(), "idempotent_P", P); idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(QctToP(), "Qct_to_P", Q); Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
} }
bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) { bool CheckTransform(FuncGraphPtr gbefore, FuncGraphPtr gafter, const SubstitutionList &transform) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册