diff --git a/mindspore/ccsrc/ir/pattern_matcher.h b/mindspore/ccsrc/ir/pattern_matcher.h index 6605b9ce4c81e2163cc35a0a4b03528ddeb64ef8..97a546fad504c025224fd06cdbef1b8c76dbc637 100644 --- a/mindspore/ccsrc/ir/pattern_matcher.h +++ b/mindspore/ccsrc/ir/pattern_matcher.h @@ -17,14 +17,16 @@ #ifndef MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ #define MINDSPORE_CCSRC_IR_PATTERN_MATCHER_H_ +#include +#include #include #include #include "ir/anf.h" #include "operator/ops.h" +#include "optimizer/optimizer.h" namespace mindspore { - /// /// Base class for all recognizable patterns. /// We implement an Expression Template approach using static polymorphism based on @@ -60,7 +62,7 @@ class PIsEqual { bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; } }; -template +template class PatternNode : public PBase > { public: T GetNode(const AnfNodePtr &node) const { @@ -90,12 +92,13 @@ class PatternNode : public PBase > { template class PBinOperation : public PBase > { public: - PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y) : prim_(prim), x_(x), y_(y) {} + PBinOperation(const PrimitivePtr &prim, const T &x, const T2 &y, bool is_commutative = false) + : prim_(prim), x_(x), y_(y), is_commutative_(is_commutative) {} AnfNodePtr GetNode(const AnfNodePtr &node) const { AnfNodePtr lhs = x_.GetNode(node->func_graph()); AnfNodePtr rhs = y_.GetNode(node->func_graph()); - AnfNodePtrList list = {prim_->cast(), lhs, rhs}; + AnfNodePtrList list = {NewValueNode(prim_), lhs, rhs}; return NewCNode(list, node->func_graph()); } @@ -106,6 +109,14 @@ class PBinOperation : public PBase > { if (inputs.size() == 3) { // Binary Prim assumes only two inputs if (!x_.TryCapture_(inputs[1]) || !y_.TryCapture_(inputs[2])) { + // If the operation is commutative, then check with inversed operands + if (is_commutative_) { + Reset(); + if (!x_.TryCapture_(inputs[2]) || !y_.TryCapture_(inputs[1])) { + return false; + } + return true; + } return false; } return true; @@ -113,7 +124,6 
@@ class PBinOperation : public PBase > { } return false; } - void Reset() const { x_.Reset(); y_.Reset(); @@ -123,6 +133,7 @@ class PBinOperation : public PBase > { const PrimitivePtr prim_; typename T::Internal x_; typename T2::Internal y_; + bool is_commutative_{false}; }; /// @@ -214,7 +225,6 @@ class PCNode : public PBase > { return false; } - void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); @@ -255,6 +265,12 @@ class PPrimitive : public PBase > { return false; } + // If set to true, TryCapture will try to capture the nodes in inversed order as well (only for two input case) + const PPrimitive &Commutative(const bool &is_commutative = true) const { + is_commutative_ = is_commutative; + return *this; + } + void Reset() const { tuple_utils::PTupleResetCapture reset; tuple_utils::apply_func_tuple(&reset, args_); @@ -263,46 +279,424 @@ class PPrimitive : public PBase > { private: const PrimitivePtr prim_; std::tuple args_; + mutable bool is_commutative_{false}; +}; + +/// +/// PConstant class can capture a value node of a specified value (check_value_) +/// or a non-specified one (any_value = true). +/// It can be configured to capture a scalar constant as well (is_scalar_ = true) +/// +template +class PConstant : public PBase > { + public: + explicit PConstant(const AnfNodePtr &as_node, const bool any_value = true, const int check_value = 0, + const bool is_scalar = false) + : as_node_(as_node), + captured_node_(as_node), + any_value_(any_value), + check_value_(check_value), + is_scalar_(is_scalar) {} + + // Sets as_node_ as the node received as argument to produce a same-shape node with GetNode + const PConstant &WithShapeAs(const AnfNodePtr &node) const { + as_node_ = node; + changed_shape_ = true; + return *this; + } + + /// Sets captured_node_ as the node captured by the Pattern received as argument + /// to produce a new node with its contents when calling GetNode. 
+ const PConstant &WithValueOf(const PatternNode &pnode) const { + if (!any_value_) { + MS_EXCEPTION(ValueError) << "Must use a PConstant with `any_value = true` to use the value of another node."; + } + captured_node_ = pnode.GetNode(captured_node_); + changed_shape_ = true; + return *this; + } + + /// Create a new Value Node filled up with check_value. + /// This function must be used immediately before GetNode to avoid replacing the expected result. + const PConstant &NewValue() const { + auto value_node_ = MakeValue(check_value_); + captured_node_ = NewValueNode(value_node_); + is_new_value_node_ = true; + return *this; + } + + AnfNodePtr GetNode(const AnfNodePtr &node) const { + // If a NewValueNode was requested (using NewValue function) then return that created node. + if (is_new_value_node_) { + return captured_node_; + } + /// Return a NewTensorFilledWithData if the node was initialized to have a specific value + /// even if it wasn't captured. Usually for zero constants (x - x => zero). + /// If the shape was changed, use the new shape. + if (changed_shape_ || !captured_) { + if (!any_value_) { + return NewTensorFilledWithData(as_node_, check_value_); + } + return NewTensorFilledWithData(as_node_, captured_node_); + } + return captured_node_; + } + + bool TryCapture_(const AnfNodePtr &node) const { + if (IsValueNode(node)) { + // If any_value_ is set don't check for the node's value. Just capture it. 
+ if (any_value_) { + captured_node_ = node; + captured_ = true; + return true; + } + + auto value = node->cast()->value(); + if ((is_scalar_ && IsTensorScalarConstant(value)) || (!is_scalar_ && IsTensorConstant(value))) { + captured_node_ = node; + captured_ = true; + return true; + } + + auto value_node_ = MakeValue(check_value_); + if (*GetValueNode(node) == *value_node_) { + captured_node_ = node; + captured_ = true; + return true; + } + } + return false; + } + + void Reset() const { + captured_ = false; + changed_shape_ = false; + is_new_value_node_ = false; + } + + // Support function used for checking if all values of a Tensor are equal to `check_value_` + // Supported data types: double, float/float32, int/int32 + bool IsTensorConstant(const ValuePtr &value) const { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + TypeId tensor_type = tensor_ptr->Dtype()->type_id(); + if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { + float *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > FLT_EPSILON) { + return false; + } + } + return true; + } else if (tensor_type == TypeId::kNumberTypeFloat64) { + double *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (fabs(data2[i] - check_value_) > DBL_EPSILON) { + return false; + } + } + return true; + } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { + int *data2 = reinterpret_cast(tensor_ptr->data_c()); + for (int i = 0; i < tensor_ptr->DataSize(); i++) { + if (data2[i] != check_value_) { + return false; + } + } + return true; + } + // Input Data Type is not supported + return false; + } + + bool IsTensorScalarConstant(const ValuePtr &value) const { + if (!value->isa()) { + return false; + } + auto tensor_ptr = dyn_cast(value); + if ((tensor_ptr->DataSize() > 
1) || (tensor_ptr->DataDim() > 0)) { + return false; + } + return IsTensorConstant(value); + } + + void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false) const { + if (!node->isa()) { + return nullptr; + } + + auto value = node->cast()->value(); + + if (!value->isa()) { + return nullptr; + } + + tensor::TensorPtr tensor_ptr = dyn_cast(value); + return tensor_ptr->data_c(); + } + + // Make a new tensor (when possible) with the same shape as of `node` + // If x is nullptr then fill new tensor will "0" + // If x is a tensor with empty shape then fill new tensor with the single value of x + // If x is a tensor with same shape as `node` then return x as result + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr) const { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + if (x == nullptr) { + std::memset(data, 0, mem_size); + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + // x is not nullptr + if (x->isa()) { + if ((x->abstract() == nullptr) || !x->abstract()->isa()) { + return nullptr; + } + auto x_abstract = x->abstract()->cast(); + std::vector x_shape = x_abstract->shape()->shape(); + + if (x_shape != tensor_shape) { + return nullptr; + } + return x; + } + + if (!x->isa()) { + return nullptr; + } + auto x_value = x->cast()->value(); + if (!x_value->isa()) { + return nullptr; + } + + auto x_tensor_ptr = dyn_cast(x_value); + + if ((x_tensor_ptr->DataSize() > 1) && 
(x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { + return nullptr; + } + char *source_data = reinterpret_cast(GetPointerToTensorData(x)); + if (x_tensor_ptr->DataSize() == 1) { + for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { + memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); + } + } else { + memcpy(data, source_data, mem_size); + } + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + + AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const int &value) const { + if ((node->abstract() == nullptr) || !node->abstract()->isa()) { + return nullptr; + } + + auto tensor_abstract = node->abstract()->cast(); + TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); + std::vector tensor_shape = tensor_abstract->shape()->shape(); + + auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); + size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + + std::memset(data, value, mem_size); + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + + // Support function to multiply two constant tensors: partially support broadcasting shapes + template + void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, + int out_data_size) const { + TM *data_1 = reinterpret_cast(in_data_1); + TM *data_2 = reinterpret_cast(in_data_2); + TM *data_out = new TM[out_data_size]; + + if (in_data_1_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[0]; + } + } else { + for (int i = 0; i < out_data_size; i++) { + data_out[i] = data_1[i]; + } + } + if (in_data_2_size == 1) { + for (int i = 0; i < out_data_size; i++) { + data_out[i] *= data_2[0]; + } + } else { + for (int i = 
0; i < out_data_size; i++) { + data_out[i] *= data_2[i]; + } + } + *out_data = reinterpret_cast(data_out); + return; + } + + AnfNodePtr MulByPatternConst(const PConstant &vpnode_2, const AnfNodePtr &node_3) const { + AnfNodePtr vnode_1 = this->GetNode(captured_node_); + AnfNodePtr vnode_2 = vpnode_2.GetNode(captured_node_); + return MulConstantTensors(vnode_1, vnode_2, node_3); + } + + AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3) const { + if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || + (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { + return nullptr; + } + + auto value_1 = GetValueNode(vnode_1); + auto value_2 = GetValueNode(vnode_2); + + if (!value_1->isa() || !value_2->isa()) { + return nullptr; + } + + auto tensor_ptr_1 = dyn_cast(value_1); + auto tensor_ptr_2 = dyn_cast(value_2); + + auto tensor_1_abstract = vnode_1->abstract()->cast(); + auto tensor_2_abstract = vnode_1->abstract()->cast(); + auto tensor_3_abstract = node_3->abstract()->cast(); + + TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); + TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); + TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); + + if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || + (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { + return nullptr; + } + + std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + + int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); + + if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { + return nullptr; + } + if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { + return nullptr; + } + + void *data_out; + + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || + (tensor_3_type_ptr->type_id() 
== TypeId::kNumberTypeFloat)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || + (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { + Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), + tensor_ptr_2->DataSize(), &data_out, data_out_size); + } else { + // Un-support data types + return nullptr; + } + } + } + + auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); + size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); + char *data = reinterpret_cast(new_tensor_ptr->data_c()); + memcpy(data, data_out, mem_size); + + auto new_vnode = NewValueNode(new_tensor_ptr); + new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); + return new_vnode; + } + + using Internal = const PConstant &; + + protected: + mutable AnfNodePtr as_node_; + mutable AnfNodePtr captured_node_; + bool any_value_{true}; + int check_value_{0}; + bool is_scalar_{false}; + mutable bool is_new_value_node_{false}; + mutable bool captured_{false}; + mutable bool changed_shape_{false}; }; // Macro for binary operation functions -#define BIN_OPERATION_PATTERN(Operator, MSPrimitive) \ - template \ - inline PBinOperation Operator(const PBase &x, const PBase &y) { \ - return PBinOperation(MSPrimitive, x.get_object(), y.get_object()); \ +#define BIN_OPERATION_PATTERN(Operator, MSPrimitive, Commutative) \ + template \ + inline PBinOperation Operator(const PBase &x, const PBase &y) { \ + return PBinOperation(MSPrimitive, x.get_object(), y.get_object(), Commutative); \ } // Arithmetic operations 
-BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd); -BIN_OPERATION_PATTERN(operator*, prim::kPrimMul); +BIN_OPERATION_PATTERN(operator+, prim::kPrimTensorAdd, true); +BIN_OPERATION_PATTERN(operator*, prim::kPrimMul, true); // Macros for match and replace #define MATCH_REPLACE(OrigNode, CaptureNode, ReplaceWith) \ if ((CaptureNode).TryCapture(OrigNode)) { \ - return (ReplaceWith).GetNode(OrigNode); \ + auto rep = (ReplaceWith).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return rep; \ + } \ } #define MATCH_REPLACE_IF(OrigNode, CaptureNode, ReplaceWith, Condition) \ if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - return (ReplaceWith).GetNode(OrigNode); \ + auto rep = (ReplaceWith).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return rep; \ + } \ } #define MATCH_REPLACE_IF_ELSE(OrigNode, CaptureNode, ReplaceWith, Condition, ElseNode) \ if ((CaptureNode).TryCapture(OrigNode)) { \ if ((Condition)) { \ - return (ReplaceWith).GetNode(OrigNode); \ + auto rep = (ReplaceWith).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return (ReplaceWith).GetNode(OrigNode); \ + } \ + } else { \ + auto rep = (ElseNode).GetNode(OrigNode); \ + if (rep != nullptr) { \ + return (ElseNode).GetNode(OrigNode); \ + } \ } \ - return (ElseNode).GetNode(OrigNode); \ } #define MATCH_REPLACE_LAMBDA(OrigNode, CaptureNode, Lambda) \ if ((CaptureNode).TryCapture(OrigNode)) { \ - return (Lambda)(); \ + auto rep = (Lambda)(); \ + if (rep != nullptr) { \ + return rep; \ + } \ } #define MATCH_REPLACE_LAMBDA_IF(OrigNode, CaptureNode, Lambda, Condition) \ if ((CaptureNode).TryCapture(OrigNode) && (Condition)) { \ - return (Lambda)(); \ + auto rep = (Lambda)(); \ + if (rep != nullptr) { \ + return rep; \ + } \ } } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc index b111a6b67aae1df0bdbe36615c7f2701b0466f7e..03da2f0ea7a5105c31f6015176a8480b8949fce2 100644 --- 
a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.cc @@ -14,542 +14,67 @@ * limitations under the License. */ -#include -#include -#include -#include - #include "optimizer/irpass/arithmetic_simplify.h" -#include "ir/optimizer_caller.h" -#include "ir/visitor.h" -#include "operator/ops.h" -#include "optimizer/irpass.h" -#include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -AnfNodePtr MultiplyByZeroOrOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarMul)(node); - - if (is_zero_) { - return NewValueNode(zero_); - } - if (is_one_) { - return x_; - } - return nullptr; -} - -void MultiplyByZeroOrOne::Visit(const AnfNodePtr &node) { - if (is_one_ || node->isa()) { - x_ = node; - return; - } - - AnfVisitor::Visit(node); - if (!is_one_) { - x_ = node; - } -} - -void MultiplyByZeroOrOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (*value == *zero_) { - is_zero_ = true; - } else if (*value == *one_) { - is_one_ = true; - } -} - -void MultiplyByZeroOrOne::Reset() { - x_ = nullptr; - is_one_ = false; - is_zero_ = false; -} - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -bool CheckTensorConstant::IsTensorConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - TypeId tensor_type = tensor_ptr->Dtype()->type_id(); - if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { - float *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > 
FLT_EPSILON) { - return false; - } - } - return true; - } else if (tensor_type == TypeId::kNumberTypeFloat64) { - double *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (fabs(data2[i] - check_value_) > DBL_EPSILON) { - return false; - } - } - return true; - } else if ((tensor_type == TypeId::kNumberTypeInt32) || (tensor_type == TypeId::kNumberTypeInt)) { - int *data2 = reinterpret_cast(tensor_ptr->data_c()); - for (int i = 0; i < tensor_ptr->DataSize(); i++) { - if (data2[i] != check_value_) { - return false; - } - } - return true; - } - // input Data Types is not supported - return false; -} - -bool CheckTensorConstant::IsTensorScalarConstant(const ValuePtr &value) { - if (!value->isa()) { - return false; - } - auto tensor_ptr = dyn_cast(value); - if ((tensor_ptr->DataSize() > 1) || (tensor_ptr->DataDim() > 0)) { - return false; - } - return IsTensorConstant(value); -} - -void *TensorMultiplyBase::GetPointerToTensorData(const AnfNodePtr &node, bool writable) { - if (!node->isa()) { - return nullptr; - } - - auto value = node->cast()->value(); - - if (!value->isa()) { - return nullptr; - } - - tensor::TensorPtr tensor_ptr = dyn_cast(value); - return tensor_ptr->data_c(); -} - -// Make a new tensor (when possible) with the same shape as of `node` -// If x is nullptr then fill new tensor will "0" -// If x is a tensor with empty shape then fill new tensor with the single value of x -// If x is a tensor with same shape as `node` then return x as result -AnfNodePtr TensorMultiplyBase::NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x) { - if ((node->abstract() == nullptr) || !node->abstract()->isa()) { - return nullptr; - } - - auto tensor_abstract = node->abstract()->cast(); - TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType(); - std::vector tensor_shape = tensor_abstract->shape()->shape(); - - auto new_tensor_ptr = std::make_shared(tensor_type_ptr->type_id(), tensor_shape); - 
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - - if (x == nullptr) { - std::memset(data, 0, mem_size); - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; - } - // x is not nullptr - if (x->isa()) { - if ((x->abstract() == nullptr) || !x->abstract()->isa()) { - return nullptr; - } - auto x_abstract = x->abstract()->cast(); - std::vector x_shape = x_abstract->shape()->shape(); - - if (x_shape != tensor_shape) { - return nullptr; - } - return x; - } - - if (!x->isa()) { - return nullptr; - } - auto x_value = x->cast()->value(); - if (!x_value->isa()) { - return nullptr; - } - - auto x_tensor_ptr = dyn_cast(x_value); - - if ((x_tensor_ptr->DataSize() > 1) && (x_tensor_ptr->DataSize() != new_tensor_ptr->DataSize())) { - return nullptr; - } - char *source_data = reinterpret_cast(GetPointerToTensorData(x)); - if (x_tensor_ptr->DataSize() == 1) { - for (int i = 0; i < new_tensor_ptr->ElementsNum(); i++) { - memcpy(data + i * GetTypeByte(tensor_type_ptr), source_data, GetTypeByte(tensor_type_ptr)); - } - } else { - memcpy(data, source_data, mem_size); - } - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -AnfNodePtr TensorMultiplyByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_zero_) { - if (x_->func_graph() != node->func_graph()) { - return nullptr; - } - return NewTensorFilledWithData(node); - } - return nullptr; -} - -void TensorMultiplyByZero::Visit(const AnfNodePtr &node) { - if (is_zero_) { - x_ = node; - return; - } - - if (IsParam(node)) { - x_ = node; - return; - } - - if (IsCNode(node)) { - CNodePtr cnode = node->cast(); - if (IsPrimitive(cnode->input(0), 
prim::kPrimZerosLike)) { - is_zero_ = true; - return; - } - x_ = node; - return; - } - auto value = node->cast()->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -AnfNodePtr TensorMultiplyByOne::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimMul)(node); - - if (is_one_) { - return NewTensorFilledWithData(node, x_); - } - return nullptr; -} - -void TensorMultiplyByOne::Visit(const AnfNodePtr &node) { - if (is_one_) { - x_ = node; - return; - } - - if (IsParam(node) || IsCNode(node)) { - x_ = node; - return; - } - - auto value = node->cast()->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = node; -} - -void TensorMultiplyByOne::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(1).IsTensorConstant(value)) { - is_one_ = true; - return; - } - x_ = vnode; -} -void TensorMultiplyByOne::Reset() { - x_ = nullptr; - is_one_ = false; -} - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -AnfNodePtr AddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimScalarAdd)(node); - - if (is_zero_) { - return x_; - } - return nullptr; -} - -void AddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && - ((*GetValueNode(node) == *zero_) || CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node)))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void AddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {prim::kPrimTensorAdd, {kPrimZerosLike, 
Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -AnfNodePtr TensorAddByZero::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - AnfVisitor::Match(prim::kPrimTensorAdd)(node); +AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode x, y, z, xs; + PConstant one_(node, false, 1); + PConstant one_scalar_(node, false, 1, true); + PConstant zero_(node, false, 0); + PConstant zero_scalar_(node, false, 0, true); + PConstant const_(node); + PConstant const_2(node); + PConstant any_const(node); + + MATCH_REPLACE(node, x + zero_, x); // Add by zero + MATCH_REPLACE(node, x + zero_scalar_, x); // Add by zero + MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, zero_scalar_, x), x); // Scalar Add by zero + MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarAdd, x, zero_scalar_), x); // Scalar Add by zero + MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), x.CheckFunc(IsVNode, node)); // Multiply by one + MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, one_scalar_, x), x); // Scalar Mul by one + MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, one_scalar_), x); // Scalar Mul by one + MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, zero_scalar_, x), zero_.NewValue()); // Scalar Mul by zero + MATCH_REPLACE(node, PPrimitive(prim::kPrimScalarMul, x, zero_scalar_), zero_.NewValue()); // Scalar Mul by zero + + // Prim Eliminate (identity) + MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); + + // ConstantDuplicateMul + auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr { + auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node)); + auto mul_node = node->cast()->inputs()[0]; + if (new_mul_tensor == nullptr) { + auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph()); + return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph()); + } + return NewCNode({mul_node, x.GetNode(node), 
new_mul_tensor}, node->func_graph()); + }; + MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda); + + if (node->func_graph() == nullptr) { + return nullptr; + } + + // OptUpdateZeroTensor + MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z, xs), + PPrimitive(prim::kPrimMakeTuple, z, y)); + + // PowerOneEliminate + MATCH_REPLACE(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x); - if (is_zero_) { - return x_; - } return nullptr; } -void TensorAddByZero::Visit(const AnfNodePtr &node) { - if (node->isa() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { - is_zero_ = true; - return; - } - - x_ = node; -} - -void TensorAddByZero::Visit(const ValueNodePtr &vnode) { - auto value = vnode->value(); - if (CheckTensorConstant(0).IsTensorConstant(value)) { - is_zero_ = true; - return; - } -} - -void TensorAddByZero::Reset() { - x_ = nullptr; - is_zero_ = false; -} - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -AnfNodePtr OptUpdateZeroTensor::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimMomentum) || node->func_graph() == nullptr) { - return nullptr; - } - - // {PrimMomentum, {...}, Y, Z, Xs} - auto &inputs = node->cast()->inputs(); - if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) { - return nullptr; - } - auto y = inputs[2]; - auto z = inputs[3]; - - // {kPrimZerosLike, X} - if (inputs[1]->cast()->size() != 2) { - return nullptr; - } - - // {prim::kPrimMakeTuple, Z, Y} - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), z, y}); -} - -// {prim::kPrimMul, Tensor1, {prim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -// Support function to multiply two constant tensors: partially support broadcasting shapes -template -void ConstantDuplicateMul::Multiply(void *in_data_1, int in_data_1_size, void 
*in_data_2, int in_data_2_size, - void **out_data, int out_data_size) { - T *data_1 = reinterpret_cast(in_data_1); - T *data_2 = reinterpret_cast(in_data_2); - T *data_out = new T[out_data_size]; - - if (in_data_1_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] = data_1[i]; - } - } - if (in_data_2_size == 1) { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[0]; - } - } else { - for (int i = 0; i < out_data_size; i++) { - data_out[i] *= data_2[i]; - } - } - *out_data = reinterpret_cast(data_out); - return; -} - -AnfNodePtr ConstantDuplicateMul::MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, - const AnfNodePtr &node_3) { - if (!vnode_1->isa() || !vnode_2->isa() || (vnode_1->abstract() == nullptr) || - (vnode_2->abstract() == nullptr) || (node_3->abstract() == nullptr)) { - return nullptr; - } - - auto value_1 = GetValueNode(vnode_1); - auto value_2 = GetValueNode(vnode_2); - - if (!value_1->isa() || !value_2->isa()) { - return nullptr; - } - - auto tensor_ptr_1 = dyn_cast(value_1); - auto tensor_ptr_2 = dyn_cast(value_2); - - auto tensor_1_abstract = vnode_1->abstract()->cast(); - auto tensor_2_abstract = vnode_1->abstract()->cast(); - auto tensor_3_abstract = node_3->abstract()->cast(); - - TypePtr tensor_1_type_ptr = tensor_1_abstract->element()->BuildType(); - TypePtr tensor_2_type_ptr = tensor_2_abstract->element()->BuildType(); - TypePtr tensor_3_type_ptr = tensor_3_abstract->element()->BuildType(); - - if ((tensor_1_type_ptr->type_id() != tensor_3_type_ptr->type_id()) || - (tensor_2_type_ptr->type_id() != tensor_3_type_ptr->type_id())) { - return nullptr; - } +AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + PatternNode x, y; + PConstant zero_(node, false, 0); - std::vector tensor_out_shape = tensor_3_abstract->shape()->shape(); + MATCH_REPLACE(node, x * zero_, 
zero_); // Multiply by zero + MATCH_REPLACE(node, x * PPrimitive(prim::kPrimZerosLike, y), zero_); // Multiply by zero - int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies()); - - if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) { - return nullptr; - } - if ((tensor_ptr_2->DataSize() > 1) && (tensor_ptr_2->DataSize() != data_out_size)) { - return nullptr; - } - - void *data_out; - - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), tensor_ptr_2->DataSize(), - &data_out, data_out_size); - } else { - if (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeFloat64) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - if ((tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt32) || - (tensor_3_type_ptr->type_id() == TypeId::kNumberTypeInt)) { - Multiply(tensor_ptr_1->data_c(), tensor_ptr_1->DataSize(), tensor_ptr_2->data_c(), - tensor_ptr_2->DataSize(), &data_out, data_out_size); - } else { - // Un-support data types - return nullptr; - } - } - } - - auto new_tensor_ptr = std::make_shared(tensor_3_type_ptr->type_id(), tensor_out_shape); - size_t mem_size = GetTypeByte(tensor_3_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum()); - char *data = reinterpret_cast(new_tensor_ptr->data_c()); - memcpy(data, data_out, mem_size); - - auto new_vnode = NewValueNode(new_tensor_ptr); - new_vnode->set_abstract(new_tensor_ptr->ToAbstract()); - return new_vnode; -} - -AnfNodePtr ConstantDuplicateMul::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - Reset(); - // {prim::kPrimMul, Tensor1, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(node); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return 
nullptr; - } - - if (!IsCNode(c_p_node_)) { - return nullptr; - } - - auto tensor1 = vnode_; - auto mul = c_p_node_->cast(); - - Reset(); - // {prim::kPrimMul, Tensor2, {...}} - AnfVisitor::Match(prim::kPrimMul, {IsNode, IsNode})(mul); - if (vnode_ == nullptr || c_p_node_ == nullptr) { - return nullptr; - } - auto tensor2 = vnode_; - auto c_p_node = c_p_node_; - - auto PrimMul = GetValueNode(mul->input(0)); - auto fg = node->func_graph(); - - auto new_mul_tensor = MulConstantTensors(tensor1, tensor2, c_p_node); - if (new_mul_tensor == nullptr) { - auto ttmul = NewCNode({NewValueNode(PrimMul), tensor1, tensor2}, fg); - return NewCNode({NewValueNode(PrimMul), c_p_node, ttmul}, fg); - } - return NewCNode({NewValueNode(PrimMul), c_p_node, new_mul_tensor}, fg); -} - -void ConstantDuplicateMul::Visit(const AnfNodePtr &node) { - if (IsValueNode(node)) { - vnode_ = node; - } - - if (IsCNode(node) || IsParam(node)) { - c_p_node_ = node; - } -} - -void ConstantDuplicateMul::Reset() { - vnode_ = nullptr; - c_p_node_ = nullptr; -} - -AnfNodePtr PowerOneEliminate::operator()(const OptimizerPtr &, const AnfNodePtr &node) { - if (!IsPrimitiveCNode(node, prim::kPrimPow) || node->func_graph() == nullptr) { - return nullptr; - } - - auto &inputs = node->cast()->inputs(); - if (!IsValueNode(inputs[2])) { - return nullptr; - } - auto scalar = GetValueNode(inputs[2]); - if (scalar->isa() && GetValue(scalar) == 1.0) { - return inputs[1]; - } else if (scalar->isa() && GetValue(scalar) == 1) { - return inputs[1]; - } return nullptr; } @@ -654,27 +179,6 @@ void AdjustAllReduceMulAdd::Reset() { all_reduce_fg_ = nullptr; } -AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} - -AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &optimizer, const 
AnfNodePtr &node) { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; -} } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index f4bdb0d655987765812bcef966fdb0deb7b3f4c1..3ba85c4ed3311ba42f6ad9f21ee20cdc16ffff65 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -22,158 +22,14 @@ #include #include "ir/optimizer_caller.h" +#include "ir/pattern_matcher.h" #include "ir/visitor.h" -#include "operator/ops.h" #include "optimizer/irpass.h" #include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" namespace mindspore { namespace opt { namespace irpass { -// {prim::kPrimScalarMul, 0, X}, {prim::kPrimScalarMul, X, 0} -// {prim::kPrimScalarMul, 1, X}, {prim::kPrimScalarMul, X, 1} -class MultiplyByZeroOrOne : public AnfVisitor { - public: - MultiplyByZeroOrOne() : zero_(MakeValue(0)), one_(MakeValue(1)) {} - ~MultiplyByZeroOrOne() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}, is_one_{false}; - ValuePtr zero_, one_; - AnfNodePtr x_{nullptr}; -}; - -// Support class used for checking if all values of a Tensor are equal `check_value_` -// Supported data types: double, float/float32, int/int32 -class CheckTensorConstant { - public: - explicit CheckTensorConstant(int _check_value = 0) : check_value_(_check_value) {} - ~CheckTensorConstant() = default; - - bool IsTensorConstant(const ValuePtr &value); - bool IsTensorScalarConstant(const ValuePtr &value); - - private: - int check_value_; -}; - -class 
TensorMultiplyBase : public AnfVisitor { - protected: - void *GetPointerToTensorData(const AnfNodePtr &node, bool writable = false); - - // Make a new tensor (when possible) with the same shape as of `node` - // If x is nullptr then fill new tensor will "0" - // If x is a tensor with empty shape then fill new tensor with the single value of x - // If x is a tensor with same shape as `node` then return x as result - AnfNodePtr NewTensorFilledWithData(const AnfNodePtr &node, const AnfNodePtr &x = nullptr); - - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimMul, 0, X}, {prim::kPrimMul, X, 0} -class TensorMultiplyByZero : public TensorMultiplyBase { - public: - TensorMultiplyByZero() : zero_(MakeValue(0)) {} - ~TensorMultiplyByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; -}; - -// {prim::kPrimMul, 1, X}, {prim::kPrimMul, X, 1} -class TensorMultiplyByOne : public TensorMultiplyBase { - public: - TensorMultiplyByOne() {} - ~TensorMultiplyByOne() override = default; - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_one_{false}; -}; - -// {prim::kPrimScalarAdd, X, 0} -// {prim::kPrimScalarAdd, 0, X} -class AddByZero : public AnfVisitor { - public: - AddByZero() : zero_(MakeValue(0)) {} - ~AddByZero() override = default; - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - bool is_zero_{false}; - ValuePtr zero_; - AnfNodePtr x_{nullptr}; -}; - -// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X}, -// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}} -class TensorAddByZero : 
public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Visit(const ValueNodePtr &vnode) override; - void Reset(); - - private: - bool is_zero_{false}; - AnfNodePtr x_{nullptr}; -}; - -// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y} -class OptUpdateZeroTensor : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - -// {prim::kPrimMul, Tensor1, {orim::kPrimMul, Tensor2, {...}}} -> -// {prim::kPrimMul, {...}, {prim::kPrimMul, Tensor1, Tensor2}} -class ConstantDuplicateMul : public AnfVisitor { - public: - // Support function to multiply two constant tensors: partially support broadcasting shapes - template - void Multiply(void *in_data_1, int in_data_1_size, void *in_data_2, int in_data_2_size, void **out_data, - int out_data_size); - - AnfNodePtr MulConstantTensors(const AnfNodePtr &vnode_1, const AnfNodePtr &vnode_2, const AnfNodePtr &node_3); - - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; - - void Visit(const AnfNodePtr &node) override; - void Reset(); - - private: - AnfNodePtr vnode_; - AnfNodePtr c_p_node_; -}; - -class PowerOneEliminate : public AnfVisitor { - public: - AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; -}; - // grad = AllReduce(grad) / worker_number // grad = grad + weight * decy // -> @@ -200,39 +56,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { class ArithmeticSimplify : public OptimizerCaller { public: - ArithmeticSimplify() - : multiply_by_zero_or_one_(std::make_shared()), - tensor_multiply_by_one_(std::make_shared()), - add_by_zero_(std::make_shared()), - tensor_add_by_zero_(std::make_shared()), - identity_(std::make_shared(prim::kPrimIdentity)), - opt_update_zero_tensor_(std::make_shared()), - constant_duplicate_mul_(std::make_shared()), - 
power_one_(std::make_shared()) { - eliminaters_.emplace_back(multiply_by_zero_or_one_); - eliminaters_.emplace_back(tensor_multiply_by_one_); - eliminaters_.emplace_back(add_by_zero_); - eliminaters_.emplace_back(tensor_add_by_zero_); - eliminaters_.emplace_back(identity_); - eliminaters_.emplace_back(opt_update_zero_tensor_); - eliminaters_.emplace_back(constant_duplicate_mul_); - eliminaters_.emplace_back(power_one_); - } - ~ArithmeticSimplify() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr multiply_by_zero_or_one_; - OptimizerCallerPtr tensor_multiply_by_one_; - OptimizerCallerPtr add_by_zero_; - OptimizerCallerPtr tensor_add_by_zero_; - OptimizerCallerPtr identity_; - OptimizerCallerPtr opt_update_zero_tensor_; - OptimizerCallerPtr constant_duplicate_mul_; - OptimizerCallerPtr power_one_; - - std::vector eliminaters_{}; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; }; // Arithmetic Simplifications should be done after step_parallel. @@ -242,17 +66,9 @@ class ArithmeticSimplify : public OptimizerCaller { // ArithmeticSimplify and deferred until step_parallel. 
class ArithmeticSimplify2 : public OptimizerCaller { public: - ArithmeticSimplify2() : tensor_multiply_by_zero_(std::make_shared()) { - eliminaters_.emplace_back(tensor_multiply_by_zero_); - } - ~ArithmeticSimplify2() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override; - - private: - OptimizerCallerPtr tensor_multiply_by_zero_; - std::vector eliminaters_{}; + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; }; + } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h index b6a4e1c85238a09d13f14c2a63432ce64ddbdb1c..6de982f999b0d93cafd29cce1a38b04008910311 100644 --- a/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/optimizer/irpass/special_op_eliminate.h @@ -25,10 +25,8 @@ #include "ir/optimizer_caller.h" #include "ir/pattern_matcher.h" #include "ir/visitor.h" -#include "operator/ops.h" #include "optimizer/irpass.h" #include "optimizer/irpass/prim_eliminate.h" -#include "optimizer/optimizer.h" namespace mindspore { namespace opt { diff --git a/tests/ut/cpp/optimizer/opt_test.cc b/tests/ut/cpp/optimizer/opt_test.cc index 2428d0dddb36117e9a1e27fd17c3bbf31161f5b4..6c4aa8f56f9ac8fbcb91e407a792fe70228ca896 100644 --- a/tests/ut/cpp/optimizer/opt_test.cc +++ b/tests/ut/cpp/optimizer/opt_test.cc @@ -77,7 +77,7 @@ class TestOptOpt : public UT::Common { }; void SetUp() { - elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); + elim_Z = MakeSubstitution(std::make_shared(), "elim_Z", prim::kPrimScalarAdd); elim_R = MakeSubstitution(std::make_shared(R), "elim_R", R); idempotent_P = MakeSubstitution(std::make_shared(), "idempotent_P", P); Qct_to_P = MakeSubstitution(std::make_shared(), "Qct_to_P", Q);