diff --git a/mindspore/ccsrc/ir/pattern_matcher.h b/mindspore/ccsrc/ir/pattern_matcher.h index 64703a22d07ee276c920e611f9d1cc700a51361e..6605b9ce4c81e2163cc35a0a4b03528ddeb64ef8 100644 --- a/mindspore/ccsrc/ir/pattern_matcher.h +++ b/mindspore/ccsrc/ir/pattern_matcher.h @@ -17,16 +17,14 @@ #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ #define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ -#include -#include #include #include #include "ir/anf.h" #include "operator/ops.h" -#include "optimizer/optimizer.h" namespace mindspore { + /// /// Base class for all recognizable patterns. /// We implement an Expression Template approach using static polymorphism based on @@ -62,7 +60,7 @@ class PIsEqual { bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } }; -template +template class PatternNode : public PBase > { public: T GetNode(const AnfNodePtr &node) const { @@ -92,13 +90,12 @@ class PatternNode : public PBase > { template class PBinOperation : public PBase > { public: - PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false) - : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {} + PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} AnfNodePtr GetNode(const AnfNodePtr &node) const { AnfNodePtr lhs = x_.GetNode(node->func_graph()); AnfNodePtr rhs = y_.GetNode(node->func_graph()); - AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; + AnfNodePtrList list = {prim_->cast(), lhs, rhs}; return NewCNode(list, node->func_graph()); } @@ -109,14 +106,6 @@ class PBinOperation : public PBase > { if (inputs.size() == 3) { // Binary Prim assumes only two inputs if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { - // If the operation is commutative, then check with inversed operands - if (is_commutative_) { - Reset(); - if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) { - return false; - } - return true; - } return false; } return true; @@ -124,6 +113,7 @@ class PBinOperation : public PBase > { } return false; } + void Reset() const { x_.Reset(); y_.Reset(); @@ -133,7 +123,6 @@ class PBinOperation : public PBase > { const PrimitivePtr prim_; typename T::Internal x_; typename T2::Internal y_; - bool is_commutative_{false}; }; /// @@ -225,6 +214,7 @@ class PCNode : public PBase > { return false; } + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); @@ -265,12 +255,6 @@ class PPrimitive : public PBase > { return false; } - // If set to true, TryCapture will try to capture the nodes in iversed nodes as well (only for two input case) - const PPrimitive &Commutative(const bool &is_commutative = true) const { - is_commutative_ = is_commutative; - return *this; - } - void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); @@ -279,435 +263,46 @@ class PPrimitive : public PBase > { private: const PrimitivePtr prim_; std::tuple args_; - mutable bool is_commutative_{false}; -}; - -/// -/// PConstant class can capture a value node of a specified value (check_value_) -/// or a non-specified one (any_value = true). -/// It can be configured to capture a scalar constant as well (is_scalar_ = true) -/// -template -class PConstant : public PBase > { - public: - explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int check_value = 0, - const bool is_scalar = false) - : as_node_(as_node), - captured_node_(as_node), - any_value_(any_value), - check_value_(check_value), - is_scalar_(is_scalar) {} - - // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode - const PConstant &WithShapeAs(const AnfNodePtr &node) const { - as_node_ = node; - changed_shape_ = true; - return *this; - } - - /// Sets captured_node_ as the node captured by the Pattern received as argument - /// to produce a new node with its contents when calling GetNode. - const PConstant &WithValueOf(const PatternNode &pnode) const { - if (!any_value_) { - MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node."; - } - captured_node_ = pnode.GetNode(captured_node_); - changed_shape_ = true; - return *this; - } - - /// Create a new Value Node filled up with check_value. - /// This function must be used immediately before GetNode to avoid replacing the expected result. - const PConstant &NewValue() const { - auto value_node_ = MakeValue(check_value_); - captured_node_ = NewValueNode(value_node_); - is_new_value_node_ = true; - return *this; - } - - AnfNodePtr GetNode(const AnfNodePtr &node) const { - // If a NewValueNode was requested (using NewValue function) then return that created node. - if (is_new_value_node_) { - return captured_node_; - } - /// Return a NewTensorFilledWithData if the node was initialized to have a specific value - /// even if it wasn't captured. Usually for zero constants (x - x => zero). - /// If the shape was changed, use the new shape. - if (changed_shape_ || !captured_) { - if (!any_value_) { - return NewTensorFilledWithData(as_node_, check_value_); - } - return NewTensorFilledWithData(as_node_, captured_node_); - } - return captured_node_; - } - - bool TryCapture_(const AnfNodePtr &node) const { - if (IsValueNode(node)) { - // If any_value_ is set don't check for the node's value. Just capture it. - if (any_value_) { - captured_node_ = node; - captured_ = true; - return true; - } - - auto value = node->cast()->value(); - if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) { - captured_node_ = node; - captured_ = true; - return true; - } - - auto value_node_ = MakeValue(check_value_); - if (*GetValueNode(node) == *value_node_) { - captured_node_ = node; - captured_ = true; - return true; - } - } - return false; - } - - void Reset() const { - captured_ = false; - changed_shape_ = false; - is_new_value_node_ = false; - } - - // Support function used for checking if all values of a Tensor are equal to `check_value_` - // Supported data types: double, float/float32, int/int32 - bool IsTensorConstant(const ValuePtr &value) const { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > FLT_EPSILON) { - return false; - } - } - return true; - } else if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > DBL_EPSILON) { - return false; - } - } - return true; - } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (data2[i] != check_value_) { - return false; - } - } - return true; - } - // Input Data Type is not supported - return false; - } - - bool IsTensorScalarConstant(const ValuePtr &value) const { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { - return false; - } - return IsTensorConstant(value); - } - - void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const { - if (!node->isa()) { - return nullptr; - } - - auto value = node->cast()->value(); - - if (!value->isa()) { - return nullptr; - } - - tensor::TensorPtr tensor_ptr = dyn_cast(value); - return tensor_ptr->data_c(); - } - - // Make a new tensor (when possible) with the same shape as of `node` - // If x is nullptr then fill new tensor will "0" - // If x is a tensor with empty shape then fill new tensor with the single value of x - // If x is a tensor with same shape as `node` then return x as result - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - if (x == nullptr) { - std::memset(data, 0, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - // x is not nullptr - if (x->isa()) { - if ((x->abstract() == nullptr) || !x->abstract()->isa()) { - return nullptr; - } - auto x_abstract = x->abstract()->cast(); - std::vector x_shape = x_abstract->shape()->shape(); - - if (x_shape != tensor_shape) { - return nullptr; - } - return x; - } - - if (!x->isa()) { - return nullptr; - } - auto x_value = x->cast()->value(); - if (!x_value->isa()) { - return nullptr; - } - - auto x_tensor_ptr = dyn_cast(x_value); - - if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { - return nullptr; - } - char *source_data = reinterpret_cast(GetPointerToTensorData(x)); - if (x_tensor_ptr->DataSize() == 1) { - for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { - memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); - } - } else { - memcpy(data, source_data, mem_size); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - std::memset(data, value, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - - // Support function to multiply two constant tensors: partially support broadcasting shapes - template - void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size) const { - TM *data_1 = reinterpret_cast(in_data_1); - TM *data_2 = reinterpret_cast(in_data_2); - TM *data_out = new TM[out_data_size]; - - if (in_data_1_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[i]; - } - } - if (in_data_2_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; - } - } else { - if (in_data_2_size < out_data_size) { - MS_EXCEPTION(ValueError) << "in_data_2_size is smaller than out_data_size."; - } - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; - } - } - *out_data = reinterpret_cast(data_out); - return; - } - - AnfNodePtr MulByPatternConst(const PConstant &vpnode_2, const AnfNodePtr &node_3) const { - AnfNodePtr vnode_1 = this->GetNode(captured_node_); - AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); - return MulConstantTensors(vnode_1, vnode_2, node_3); - } - - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const { - if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || - (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { - return nullptr; - } - - auto value_1 = GetValueNode(vnode_1); - auto value_2 = GetValueNode(vnode_2); - - if (!value_1->isa() || !value_2->isa()) { - return nullptr; - } - - auto tensor_ptr_1 = dyn_cast(value_1); - auto tensor_ptr_2 = dyn_cast(value_2); - - auto tensor_1_abstract = vnode_1->abstract()->cast(); - auto tensor_2_abstract = vnode_1->abstract()->cast(); - auto tensor_3_abstract = node_3->abstract()->cast(); - - TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); - TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); - TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); - - if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || - (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { - return nullptr; - } - - std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); - - int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); - - if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { - return nullptr; - } - if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { - return nullptr; - } - - auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); - size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - int ret = 0; - void *data_out = nullptr; - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - ret = memcpy_s(data, mem_size, data_out, mem_size); - delete[] reinterpret_cast(data_out); - } else { - if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - ret = memcpy_s(data, mem_size, data_out, mem_size); - delete[] reinterpret_cast(data_out); - } else { - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - ret = memcpy_s(data, mem_size, data_out, mem_size); - delete[] reinterpret_cast(data_out); - } else { - // Un-support data types - return nullptr; - } - } - } - if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno " << ret << ", source size " << mem_size << "dest size" - << new_tensor_ptr->DataSize(); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - - using Internal = const PConstant &; - - protected: - mutable AnfNodePtr as_node_; - mutable AnfNodePtr captured_node_; - bool any_value_{true}; - int check_value_{0}; - bool is_scalar_{false}; - mutable bool is_new_value_node_{false}; - mutable bool captured_{false}; - mutable bool changed_shape_{false}; }; // Macro for binary operation functions -#define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative) \ - template \ - inline PBinOperation Operator(const PBase &x, const PBase &y) { \ - return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \ +#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ + template \ + inline PBinOperation Operator(const PBase &x, const PBase &y) { \ + return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ } // Arithmetic operations -BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); -BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); +BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); +BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); // Macros for match and replace #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ if ((CaptureNode).TryCapture(OrigNode)) { \ - auto rep = (ReplaceWith).GetNode(OrigNode); \ - if (rep != nullptr) { \ - return rep; \ - } \ + return (ReplaceWith).GetNode(OrigNode); \ } #define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - auto rep = (ReplaceWith).GetNode(OrigNode); \ - if (rep != nullptr) { \ - return rep; \ - } \ + return (ReplaceWith).GetNode(OrigNode); \ } #define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ if ((CaptureNode).TryCapture(OrigNode)) { \ if ((Condition)) { \ - auto rep = (ReplaceWith).GetNode(OrigNode); \ - if (rep != nullptr) { \ - return (ReplaceWith).GetNode(OrigNode); \ - } \ - } else { \ - auto rep = (ElseNode).GetNode(OrigNode); \ - if (rep != nullptr) { \ - return (ElseNode).GetNode(OrigNode); \ - } \ + return (ReplaceWith).GetNode(OrigNode); \ } \ + return (ElseNode).GetNode(OrigNode); \ } #define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ if ((CaptureNode).TryCapture(OrigNode)) { \ - auto rep = (Lambda)(); \ - if (rep != nullptr) { \ - return rep; \ - } \ + return (Lambda)(); \ } #define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - auto rep = (Lambda)(); \ - if (rep != nullptr) { \ - return rep; \ - } \ + return (Lambda)(); \ } } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc index 03da2f0ea7a5105c31f6015176a8480b8949fce2..b111a6b67aae1df0bdbe36615c7f2701b0466f7e 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc @@ -14,67 +14,542 @@ * limitations under the License. */ +#include +#include +#include +#include + #include "optimizer/irpass/arithmetic_simplify.h" +#include "ir/optimizer_caller.h" +#include "ir/visitor.h" +#include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - PatternNode x, y, z, xs; - PConstant one_(node, false, 1); - PConstant one_scalar_(node, false, 1, true); - PConstant zero_(node, false, 0); - PConstant zero_scalar_(node, false, 0, true); - PConstant const_(node); - PConstant const_2(node); - PConstant any_const(node); - - MATCH_REPLACE(node, x + zero_, x); // Add by zero - MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, zero_scalar_, x), x); // Scalar Add by zero - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, x, zero_scalar_), x); // Scalar Add by zero - MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), x.CheckFunc(IsVNode, node)); // Multiply by one - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, one_scalar_, x), x); // Scalar Mul by one - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, one_scalar_), x); // Scalar Mul by one - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, zero_scalar_, x), zero_.NewValue()); // Scalar Mul by zero - MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, zero_scalar_), zero_.NewValue()); // Scalar Mul by zero - - // Prim Eliminate (identity) - MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); - - // ConstantDuplicateMul - auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr { - auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node)); - auto mul_node = node->cast()->inputs()[0]; - if (new_mul_tensor == nullptr) { - auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph()); - return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph()); - } - return NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph()); - }; - MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda); - - if (node->func_graph() == nullptr) { - return nullptr; - } - - // OptUpdateZeroTensor - MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z, xs), - PPrimitive(prim::kPrimMakeTuple, z, y)); - - // PowerOneEliminate - MATCH_REPLACE(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x); +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarMul)(node); + + if (is_zero_) { + return NewValueNode(zero_); + } + if (is_one_) { + return x_; + } + return nullptr; +} + +void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { + if (is_one_ || node->isa()) { + x_ = node; + return; + } + + AnfVisitor::Visit(node); + if (!is_one_) { + x_ = node; + } +} + +void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (*value == *zero_) { + is_zero_ = true; + } else if (*value == *one_) { + is_one_ = true; + } +} + +void MultiplyByZeroOrOne::Reset() { + x_ = nullptr; + is_one_ = false; + is_zero_ = false; +} + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > FLT_EPSILON) { + return false; + } + } + return true; + } else if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > DBL_EPSILON) { + return false; + } + } + return true; + } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] != check_value_) { + return false; + } + } + return true; + } + // input Data Types is not supported + return false; +} + +bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { + return false; + } + return IsTensorConstant(value); +} + +void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { + if (!node->isa()) { + return nullptr; + } + + auto value = node->cast()->value(); + + if (!value->isa()) { + return nullptr; + } + + tensor::TensorPtr tensor_ptr = dyn_cast(value); + return tensor_ptr->data_c(); +} + +// Make a new tensor (when possible) with the same shape as of `node` +// If x is nullptr then fill new tensor will "0" +// If x is a tensor with empty shape then fill new tensor with the single value of x +// If x is a tensor with same shape as `node` then return x as result +AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (x == nullptr) { + std::memset(data, 0, mem_size); + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + // x is not nullptr + if (x->isa()) { + if ((x->abstract() == nullptr) || !x->abstract()->isa()) { + return nullptr; + } + auto x_abstract = x->abstract()->cast(); + std::vector x_shape = x_abstract->shape()->shape(); + + if (x_shape != tensor_shape) { + return nullptr; + } + return x; + } + + if (!x->isa()) { + return nullptr; + } + auto x_value = x->cast()->value(); + if (!x_value->isa()) { + return nullptr; + } + + auto x_tensor_ptr = dyn_cast(x_value); + + if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { + return nullptr; + } + char *source_data = reinterpret_cast(GetPointerToTensorData(x)); + if (x_tensor_ptr->DataSize() == 1) { + for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { + memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); + } + } else { + memcpy(data, source_data, mem_size); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_zero_) { + if (x_->func_graph() != node->func_graph()) { + return nullptr; + } + return NewTensorFilledWithData(node); + } + return nullptr; +} + +void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { + if (is_zero_) { + x_ = node; + return; + } + + if (IsParam(node)) { + x_ = node; + return; + } + + if (IsCNode(node)) { + CNodePtr cnode = node->cast(); + if (IsPrimitive(cnode->input(0), prim::kPrimZerosLike)) { + is_zero_ = true; + return; + } + x_ = node; + return; + } + auto value = node->cast()->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimMul)(node); + + if (is_one_) { + return NewTensorFilledWithData(node, x_); + } + return nullptr; +} + +void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { + if (is_one_) { + x_ = node; + return; + } + + if (IsParam(node) || IsCNode(node)) { + x_ = node; + return; + } + + auto value = node->cast()->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = node; +} + +void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(1).IsTensorConstant(value)) { + is_one_ = true; + return; + } + x_ = vnode; +} +void TensorMultiplyByOne::Reset() { + x_ = nullptr; + is_one_ = false; +} + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimScalarAdd)(node); + + if (is_zero_) { + return x_; + } + return nullptr; +} + +void AddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && + ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void AddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + AnfVisitor::Match(prim::kPrimTensorAdd)(node); + if (is_zero_) { + return x_; + } return nullptr; } -AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - PatternNode x, y; - PConstant zero_(node, false, 0); +void TensorAddByZero::Visit(const AnfNodePtr &node) { + if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { + is_zero_ = true; + return; + } + + x_ = node; +} + +void TensorAddByZero::Visit(const ValueNodePtr &vnode) { + auto value = vnode->value(); + if (CheckTensorConstant(0).IsTensorConstant(value)) { + is_zero_ = true; + return; + } +} + +void TensorAddByZero::Reset() { + x_ = nullptr; + is_zero_ = false; +} + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { + return nullptr; + } + + // {PrimMomentum, {...}, Y, Z, Xs} + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { + return nullptr; + } + auto y = inputs[2]; + auto z = inputs[3]; + + // {kPrimZerosLike, X} + if (inputs[1]->cast()->size() != 2) { + return nullptr; + } + + // {prim::kPrimMakeTuple, Z, Y} + return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); +} + +// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +// Support function to multiply two constant tensors: partially support broadcasting shapes +template +void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, + void **out_data, int out_data_size) { + T *data_1 = reinterpret_cast(in_data_1); + T *data_2 = reinterpret_cast(in_data_2); + T *data_out = new T[out_data_size]; + + if (in_data_1_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[i]; + } + } + if (in_data_2_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[i]; + } + } + *out_data = reinterpret_cast(data_out); + return; +} + +AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, + const AnfNodePtr &node_3) { + if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || + (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { + return nullptr; + } + + auto value_1 = GetValueNode(vnode_1); + auto value_2 = GetValueNode(vnode_2); + + if (!value_1->isa() || !value_2->isa()) { + return nullptr; + } + + auto tensor_ptr_1 = dyn_cast(value_1); + auto tensor_ptr_2 = dyn_cast(value_2); + + auto tensor_1_abstract = vnode_1->abstract()->cast(); + auto tensor_2_abstract = vnode_1->abstract()->cast(); + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); + TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } - MATCH_REPLACE(node, x * zero_, zero_); // Multiply by zero - MATCH_REPLACE(node, x * PPrimitive(prim::kPrimZerosLike, y), zero_); // Multiply by zero + std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + + void *data_out; + + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), + &data_out, data_out_size); + } else { + if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + // Un-support data types + return nullptr; + } + } + } + + auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); + size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + memcpy(data, data_out, mem_size); + + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; +} + +AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + Reset(); + // {prim::kPrimMul, Tensor1, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + + if (!IsCNode(c_p_node_)) { + return nullptr; + } + + auto tensor1 = vnode_; + auto mul = c_p_node_->cast(); + + Reset(); + // {prim::kPrimMul, Tensor2, {...}} + AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); + if (vnode_ == nullptr || c_p_node_ == nullptr) { + return nullptr; + } + auto tensor2 = vnode_; + auto c_p_node = c_p_node_; + + auto PrimMul = GetValueNode(mul->input(0)); + auto fg = node->func_graph(); + + auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); + if (new_mul_tensor == nullptr) { + auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); + return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); + } + return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); +} + +void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { + if (IsValueNode(node)) { + vnode_ = node; + } + + if (IsCNode(node) || IsParam(node)) { + c_p_node_ = node; + } +} + +void ConstantDuplicateMul::Reset() { + vnode_ = nullptr; + c_p_node_ = nullptr; +} + +AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (!IsValueNode(inputs[2])) { + return nullptr; + } + auto scalar = GetValueNode(inputs[2]); + if (scalar->isa() && GetValue(scalar) == 1.0) { + return inputs[1]; + } else if (scalar->isa() && GetValue(scalar) == 1) { + return inputs[1]; + } return nullptr; } @@ -179,6 +654,27 @@ void AdjustAllReduceMulAdd::Reset() { all_reduce_fg_ = nullptr; } +AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} + +AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { + AnfNodePtr new_node; + for (auto &eliminater : eliminaters_) { + new_node = (*eliminater)(optimizer, node); + if (new_node != nullptr) { + return new_node; + } + } + return nullptr; +} } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 3ba85c4ed3311ba42f6ad9f21ee20cdc16ffff65..f4bdb0d655987765812bcef966fdb0deb7b3f4c1 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -22,14 +22,158 @@ #include #include "ir/optimizer_caller.h" -#include "ir/pattern_matcher.h" #include "ir/visitor.h" +#include "operator/ops.h" #include "optimizer/irpass.h" #include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { +// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} +// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} +class MultiplyByZeroOrOne : public AnfVisitor { + public: + MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} + ~MultiplyByZeroOrOne() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}, is_one_{false}; + ValuePtr zero_, one_; + AnfNodePtr x_{nullptr}; +}; + +// Support class used for checking if all values of a Tensor are equal `check_value_` +// Supported data types: double, float/float32, int/int32 +class CheckTensorConstant { + public: + explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} + ~CheckTensorConstant() = default; + + bool IsTensorConstant(const ValuePtr &value); + bool IsTensorScalarConstant(const ValuePtr &value); + + private: + int check_value_; +}; + +class TensorMultiplyBase : public AnfVisitor { + protected: + void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); + + // Make a new tensor (when possible) with the same shape as of `node` + // If x is nullptr then fill new tensor will "0" + // If x is a tensor with empty shape then fill new tensor with the single value of x + // If x is a tensor with same shape as `node` then return x as result + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); + + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} +class TensorMultiplyByZero : public TensorMultiplyBase { + public: + TensorMultiplyByZero() : zero_(MakeValue(0)) {} + ~TensorMultiplyByZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}; + ValuePtr zero_; +}; + +// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} +class TensorMultiplyByOne : public TensorMultiplyBase { + public: + TensorMultiplyByOne() {} + ~TensorMultiplyByOne() override = default; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_one_{false}; +}; + +// {prim::kPrimScalarAdd, X, 0} +// {prim::kPrimScalarAdd, 0, X} +class AddByZero : public AnfVisitor { + public: + AddByZero() : zero_(MakeValue(0)) {} + ~AddByZero() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + bool is_zero_{false}; + ValuePtr zero_; + AnfNodePtr x_{nullptr}; +}; + +// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, +// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} +class TensorAddByZero : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Visit(const ValueNodePtr &vnode) override; + void Reset(); + + private: + bool is_zero_{false}; + AnfNodePtr x_{nullptr}; +}; + +// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} +class OptUpdateZeroTensor : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> +// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} +class ConstantDuplicateMul : public AnfVisitor { + public: + // Support function to multiply two constant tensors: partially support broadcasting shapes + template + void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, + int out_data_size); + + AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + void Visit(const AnfNodePtr &node) override; + void Reset(); + + private: + AnfNodePtr vnode_; + AnfNodePtr c_p_node_; +}; + +class PowerOneEliminate : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + // grad = AllReduce(grad) / worker_number // grad = grad + weight * decy // -> @@ -56,7 +200,39 @@ class AdjustAllReduceMulAdd : public AnfVisitor { class ArithmeticSimplify : public OptimizerCaller { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + ArithmeticSimplify() + : multiply_by_zero_or_one_(std::make_shared()), + tensor_multiply_by_one_(std::make_shared()), + add_by_zero_(std::make_shared()), + tensor_add_by_zero_(std::make_shared()), + identity_(std::make_shared(prim::kPrimIdentity)), + opt_update_zero_tensor_(std::make_shared()), + constant_duplicate_mul_(std::make_shared()), + power_one_(std::make_shared()) { + eliminaters_.emplace_back(multiply_by_zero_or_one_); + eliminaters_.emplace_back(tensor_multiply_by_one_); + eliminaters_.emplace_back(add_by_zero_); + eliminaters_.emplace_back(tensor_add_by_zero_); + eliminaters_.emplace_back(identity_); + eliminaters_.emplace_back(opt_update_zero_tensor_); + eliminaters_.emplace_back(constant_duplicate_mul_); + eliminaters_.emplace_back(power_one_); + } + ~ArithmeticSimplify() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; + + private: + OptimizerCallerPtr multiply_by_zero_or_one_; + OptimizerCallerPtr tensor_multiply_by_one_; + OptimizerCallerPtr add_by_zero_; + OptimizerCallerPtr tensor_add_by_zero_; + OptimizerCallerPtr identity_; + OptimizerCallerPtr opt_update_zero_tensor_; + OptimizerCallerPtr constant_duplicate_mul_; + OptimizerCallerPtr power_one_; + + std::vector eliminaters_{}; }; // Arithmetic Simplifications should be done after step_parallel. @@ -66,9 +242,17 @@ class ArithmeticSimplify : public OptimizerCaller { // ArithmeticSimplify and deferred until step_parallel. class ArithmeticSimplify2 : public OptimizerCaller { public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; + ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { + eliminaters_.emplace_back(tensor_multiply_by_zero_); + } + ~ArithmeticSimplify2() = default; + + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; + private: + OptimizerCallerPtr tensor_multiply_by_zero_; + std::vector eliminaters_{}; +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index 6de982f999b0d93cafd29cce1a38b04008910311..b6a4e1c85238a09d13f14c2a63432ce64ddbdb1c 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -25,8 +25,10 @@ #include "ir/optimizer_caller.h" #include "ir/pattern_matcher.h" #include "ir/visitor.h" +#include "operator/ops.h" #include "optimizer/irpass.h" #include "optimizer/irpass/prim_eliminate.h" +#include "optimizer/optimizer.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 6c4aa8f56f9ac8fbcb91e407a792fe70228ca896..2428d0dddb36117e9a1e27fd17c3bbf31161f5b4 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common { }; void SetUp() { - elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); + elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); elim_R = MakeSubstitution(std::make_shared(R), "elim_R", R); idempotent_P = MakeSubstitution(std::make_shared(), "idempotent_P", P); Qct_to_P = MakeSubstitution(std::make_shared(), "Qct_to_P", Q);