提交 db80643e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3177 Update Arithmetic Simplify to use Pattern Matcher (2)

Merge pull request !3177 from Giancarlo/pm_arithmetic_simplify
......@@ -21,159 +21,15 @@
#include <memory>
#include <vector>
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
namespace mindspore {
namespace opt {
namespace irpass {
// Simplifies scalar multiplication by the constants 0 and 1. Matched shapes:
// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0}
// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1}
// NOTE(review): rewrite targets (0 for the *0 case, X for the *1 case) are
// implemented in the .cc file — confirm there.
class MultiplyByZeroOrOne : public AnfVisitor {
public:
// Pre-build the scalar constants compared against during visitation.
MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {}
~MultiplyByZeroOrOne() override = default;
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
// Clears per-node match state so the visitor can be reused.
void Reset();
private:
bool is_zero_{false}, is_one_{false};  // which constant was seen while visiting
ValuePtr zero_, one_;                  // canonical 0 / 1 values for comparison
AnfNodePtr x_{nullptr};                // the non-constant operand X
};
// Support class used for checking if all values of a Tensor are equal to `check_value_`
// Supported data types: double, float/float32, int/int32
class CheckTensorConstant {
public:
// `_check_value` is the constant every tensor element is compared against (default 0).
explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {}
~CheckTensorConstant() = default;
// True when `value` is a tensor whose elements all equal check_value_.
bool IsTensorConstant(const ValuePtr &value);
// True for a scalar-shaped tensor equal to check_value_ — presumably a
// single-element tensor; confirm against the .cc implementation.
bool IsTensorScalarConstant(const ValuePtr &value);
private:
int check_value_;  // target value for the element-wise comparison
};
// Shared helpers for the tensor-multiplication simplifiers below.
class TensorMultiplyBase : public AnfVisitor {
protected:
// Returns a raw pointer into the tensor data held by `node`
// (writable access only when `writable` is true).
void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false);
// Make a new tensor (when possible) with the same shape as of `node`
// If x is nullptr then fill new tensor with "0"
// If x is a tensor with empty shape then fill new tensor with the single value of x
// If x is a tensor with same shape as `node` then return x as result
AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr);
AnfNodePtr x_{nullptr};  // the non-constant operand captured while visiting
};
// Simplifies element-wise tensor multiplication by a zero tensor. Matched shapes:
// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0}
class TensorMultiplyByZero : public TensorMultiplyBase {
public:
TensorMultiplyByZero() : zero_(MakeValue(0)) {}
~TensorMultiplyByZero() override = default;
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
// Clears per-node match state so the visitor can be reused.
void Reset();
private:
bool is_zero_{false};  // set when a zero operand was recognized
ValuePtr zero_;        // canonical 0 value for comparison
};
// Simplifies element-wise tensor multiplication by a ones tensor. Matched shapes:
// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1}
class TensorMultiplyByOne : public TensorMultiplyBase {
public:
TensorMultiplyByOne() {}
~TensorMultiplyByOne() override = default;
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
// Clears per-node match state so the visitor can be reused.
void Reset();
private:
bool is_one_{false};  // set when a ones operand was recognized
};
// Simplifies scalar addition with the constant 0. Matched shapes:
// {prim::kPrimScalarAdd, X, 0}
// {prim::kPrimScalarAdd, 0, X}
class AddByZero : public AnfVisitor {
public:
AddByZero() : zero_(MakeValue(0)) {}
~AddByZero() override = default;
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
// Clears per-node match state so the visitor can be reused.
void Reset();
private:
bool is_zero_{false};  // set when a zero operand was recognized
ValuePtr zero_;        // canonical 0 value for comparison
AnfNodePtr x_{nullptr};  // the non-constant operand X
};
// Simplifies tensor addition with a zeros-like tensor. Matched shapes:
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
class TensorAddByZero : public AnfVisitor {
public:
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
void Visit(const ValueNodePtr &vnode) override;
// Clears per-node match state so the visitor can be reused.
void Reset();
private:
bool is_zero_{false};    // set when a zeros-like operand was recognized
AnfNodePtr x_{nullptr};  // the non-zero operand X
};
// Simplifies a Momentum update whose gradient is a zeros-like tensor:
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
class OptUpdateZeroTensor : public AnfVisitor {
public:
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Reassociates nested multiplications so two constant tensors can be folded
// into a single constant at optimization time:
// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} ->
// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}}
class ConstantDuplicateMul : public AnfVisitor {
public:
// Support function to multiply two constant tensors: partially support broadcasting shapes
template <typename T>
void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data,
int out_data_size);
// Folds the two constant tensor value nodes into one; `node_3` presumably
// supplies shape/type context — confirm against the .cc implementation.
AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3);
// Entry point: attempts the reassociation on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
void Visit(const AnfNodePtr &node) override;
// Clears per-node match state so the visitor can be reused.
void Reset();
private:
AnfNodePtr vnode_;     // the constant tensor operand captured while visiting
AnfNodePtr c_p_node_;  // the non-constant ({...}) operand captured while visiting
};
// Eliminates exponentiation by one — pow(X, 1) simplifies to X
// (exact matched primitive lives in the .cc implementation).
class PowerOneEliminate : public AnfVisitor {
public:
// Entry point: attempts the simplification on `node`.
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// grad = AllReduce(grad) / worker_number
// grad = grad + weight * decy
// ->
......@@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor {
// Aggregate arithmetic-simplification pass. After the pattern-matcher rewrite
// (!3177) the individual eliminaters are no longer delegated to through stored
// OptimizerCaller members; all matching logic lives in the .cc implementation.
//
// NOTE(review): the original text declared operator() twice with the same
// signature (a leftover of the old delegation-based body merged with the new
// declaration), which is ill-formed C++ — a member function may not be
// redeclared in its class. Kept the single post-refactor declaration; the
// public call interface (class name, operator() signature) is unchanged.
class ArithmeticSimplify : public OptimizerCaller {
 public:
  // Applies the arithmetic simplifications to `node`; returns the simplified
  // node, or a null pointer when no pattern applies (see .cc for the rules).
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
// Arithmetic Simplifications should be done after step_parallel.
......@@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller {
// ArithmeticSimplify and deferred until step_parallel.
// Second arithmetic-simplification pass, run after step_parallel (tensor
// multiply-by-zero must not fire before parallel splitting — see the comment
// above this class in the original header).
//
// NOTE(review): the original text declared operator() twice with the same
// signature (old delegation body merged with the new declaration), which is
// ill-formed C++. Kept the single post-refactor declaration; the public call
// interface (class name, operator() signature) is unchanged.
class ArithmeticSimplify2 : public OptimizerCaller {
 public:
  // Applies the deferred simplifications to `node`; returns the simplified
  // node, or a null pointer when no pattern applies (see .cc for the rules).
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
此差异已折叠。
......@@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common {
};
void SetUp() {
elim_Z = MakeSubstitution(std::make_shared<irpass::AddByZero>(), "elim_Z", prim::kPrimScalarAdd);
elim_Z = MakeSubstitution(std::make_shared<irpass::ArithmeticSimplify>(), "elim_Z", prim::kPrimScalarAdd);
elim_R = MakeSubstitution(std::make_shared<irpass::PrimEliminater>(R), "elim_R", R);
idempotent_P = MakeSubstitution(std::make_shared<IdempotentEliminater>(), "idempotent_P", P);
Qct_to_P = MakeSubstitution(std::make_shared<QctToP>(), "Qct_to_P", Q);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部