diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py
index eb89c965df3954650702a48fe50af288a4b8ef45..3af574caf9432b681d64608a99d21a776f8f309f 100644
--- a/mindspore/_extends/parse/resources.py
+++ b/mindspore/_extends/parse/resources.py
@@ -17,6 +17,7 @@
 """Resources for ast tree parse."""
 import ast
 import math
+from mindspore import IndexedSlices
 from mindspore.ops.composite import multitype_ops
 from mindspore.ops import functional as F, composite as C
 from . import standard_method as M
@@ -135,4 +136,7 @@ convert_object_map = {
     math.sin: NO_IMPLEMENT,
     math.cos: NO_IMPLEMENT,
     math.tan: NO_IMPLEMENT,
+
+    # user defined
+    IndexedSlices: F.make_indexed_slices,
 }
diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc
index ab2ce1322a6d6ba365901cc4aa990a50f8ad78ef..99440537c7676825d0aed80a3089ae809a4c54c1 100644
--- a/mindspore/ccsrc/debug/dump_proto.cc
+++ b/mindspore/ccsrc/debug/dump_proto.cc
@@ -120,6 +120,10 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
       type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
     }
   }
+  } else if (type->isa<IndexedSlicesType>()) {
+    // Do Nothing
+  } else if (type->isa<UndeterminedType>()) {
+    // Do Nothing
   } else if (type->isa<Tuple>()) {
     TuplePtr tuple_type = dyn_cast<Tuple>(type);
     type_proto->set_data_type(irpb::DT_TUPLE);
diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc
index 5e049c06232ffa1b3087fba9cf5312ec17cfda79..71a78bdcf679a79705fe51a930f0514dcb1b033b 100644
--- a/mindspore/ccsrc/ir/dtype.cc
+++ b/mindspore/ccsrc/ir/dtype.cc
@@ -94,6 +94,48 @@ bool Slice::operator==(const Type &other) const {
 
 std::string Slice::DumpText() const { return ToString(); }
 
+TypePtr UndeterminedType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<UndeterminedType>();
+  }
+  return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
+}
+
+std::string UndeterminedType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->ToReprString() + "]";
+}
+
+std::string UndeterminedType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->ToString() + "]";
+}
+
+std::string UndeterminedType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "Undetermined";
+  }
+  return "Undetermined[" + element_type_->DumpText() + "]";
+}
+
+bool UndeterminedType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const UndeterminedType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 TypePtr TensorType::DeepCopy() const {
   MS_EXCEPTION_IF_NULL(element_type_);
   if (IsGeneric()) {
@@ -137,6 +179,48 @@ bool TensorType::operator==(const Type &other) const {
   return *element_type_ == *other_elem_type;
 }
 
+TypePtr IndexedSlicesType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<IndexedSlicesType>();
+  }
+  return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
+}
+
+std::string IndexedSlicesType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->ToReprString() + "]";
+}
+
+std::string IndexedSlicesType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->ToString() + "]";
+}
+
+std::string IndexedSlicesType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "IndexedSlices";
+  }
+  return "IndexedSlices[" + element_type_->DumpText() + "]";
+}
+
+bool IndexedSlicesType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 Function::Function() : Object(kObjectTypeFunction) {
   args_ = std::vector<TypePtr>();
   retval_ = nullptr;
diff --git a/mindspore/ccsrc/ir/dtype.h b/mindspore/ccsrc/ir/dtype.h
index 9659a27e3640dd9cb2f2fa084989f8ab4b440708..f10c56e659459b1c039c255353ad3975b8bcca69 100644
--- a/mindspore/ccsrc/ir/dtype.h
+++ b/mindspore/ccsrc/ir/dtype.h
@@ -108,10 +108,34 @@ class Slice : public Object {
 };
 using SlicePtr = std::shared_ptr<Slice>;
 
+class UndeterminedType : public Object {
+ public:
+  UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
+  explicit UndeterminedType(const TypePtr &ele)
+    : Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
+  ~UndeterminedType() override = default;
+  MS_DECLARE_PARENT(UndeterminedType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ protected:
+  TypePtr element_type_;
+};
+using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
+
 class TensorType : public Object {
  public:
-  TensorType() : Object(kObjectTypeTensorType) {}
-  explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
+  TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
+  explicit TensorType(const TypePtr &ele)
+    : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
   ~TensorType() override = default;
   MS_DECLARE_PARENT(TensorType, Object)
 
@@ -130,6 +154,29 @@ class TensorType : public Object {
 };
 using TensorTypePtr = std::shared_ptr<TensorType>;
 
+class IndexedSlicesType : public Object {
+ public:
+  IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
+  explicit IndexedSlicesType(const TypePtr &ele)
+    : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
+  ~IndexedSlicesType() override = default;
+  MS_DECLARE_PARENT(IndexedSlicesType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ private:
+  TypePtr element_type_;
+};
+using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
+
 class Function : public Object {
  public:
   Function();
@@ -255,6 +302,8 @@ TypePtr StringToType(const std::string &type_name);
 
 // Judge whether x is predicate or is a subclass of predicate.
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
+bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type);
+
 // Whether t1 is identity or a subclass of t2.
 bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);
diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/ccsrc/ir/dtype/type.cc
index 5395b596176a79cf31e7af7372dbbefd317eeba0..754876a366a114b2abd92d17bc81eb4b8d7a8c41 100644
--- a/mindspore/ccsrc/ir/dtype/type.cc
+++ b/mindspore/ccsrc/ir/dtype/type.cc
@@ -115,6 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) {
       return "kObjectTypeKeyword";
     case kObjectTypeTensorType:
       return "kObjectTypeTensorType";
+    case kObjectTypeIndexedSlicesType:
+      return "kObjectTypeIndexedSlicesType";
+    case kObjectTypeUndeterminedType:
+      return "kObjectTypeUndeterminedType";
     case kObjectTypeDictionary:
       return "kObjectTypeDictionary";
     case kObjectTypeClass:
diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/ccsrc/ir/dtype/type.h
index bfe39af43c0dedda01fbccd83ebb3282f3056e2a..cba0d17fce15628a51e8b67e557cd02a08ea25d4 100644
--- a/mindspore/ccsrc/ir/dtype/type.h
+++ b/mindspore/ccsrc/ir/dtype/type.h
@@ -67,6 +67,7 @@ class Type : public Value {
   virtual bool equal(const TypePtr other) const { return *this == *other; }
   virtual TypeId object_type() const { return kTypeUnknown; }
+  virtual TypeId parent_type() const { return kTypeUnknown; }
   virtual TypeId number_type() const { return kTypeUnknown; }
   virtual TypePtr DeepCopy() const = 0;
   virtual TypePtr Clone() const { return DeepCopy(); }
@@ -97,13 +98,16 @@ using TypePtrList = std::vector<TypePtr>;
 //
 class Object : public Type {
  public:
-  Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject) {}
+  Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {}
   explicit Object(const TypeId object_type, bool is_generic = true)
-    : Type(kMetaTypeObject, is_generic), object_type_(object_type) {}
+    : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {}
+  explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true)
+    : Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {}
   ~Object() override = default;
   MS_DECLARE_PARENT(Object, Type)
 
   TypeId object_type() const override { return object_type_; }
+  TypeId parent_type() const override { return parent_type_; }
   TypeId type_id() const override { return object_type_; }
   TypeId generic_type_id() const override { return kMetaTypeObject; }
   bool equal(const TypePtr other) const override;
@@ -114,6 +118,7 @@ class Object : public Type {
 
  private:
   const TypeId object_type_;
+  const TypeId parent_type_;
 };
 
 std::ostream &operator<<(std::ostream &os, const TypePtrList &types);
diff --git a/mindspore/ccsrc/ir/dtype/type_id.h b/mindspore/ccsrc/ir/dtype/type_id.h
index 17862ad798137923285241a1c78c1b39a4b077c0..a711779e919d5814bc7453ab933a2cfd83a6c432 100644
--- a/mindspore/ccsrc/ir/dtype/type_id.h
+++ b/mindspore/ccsrc/ir/dtype/type_id.h
@@ -50,6 +50,8 @@ enum TypeId : int {
   kObjectTypeSlice,
   kObjectTypeKeyword,
   kObjectTypeTensorType,
+  kObjectTypeIndexedSlicesType,
+  kObjectTypeUndeterminedType,
   kObjectTypeClass,
   kObjectTypeDictionary,
   kObjectTypeFunction,
diff --git a/mindspore/ccsrc/ir/dtype_extends.cc b/mindspore/ccsrc/ir/dtype_extends.cc
index e7af81292229aeadbc899838fa038d144edbba16..732872cb4f59a70e68e81e17e87bdf92e6078293 100644
--- a/mindspore/ccsrc/ir/dtype_extends.cc
+++ b/mindspore/ccsrc/ir/dtype_extends.cc
@@ -192,6 +192,40 @@ TypePtr TensorStrToType(const std::string &type_name) {
   return type;
 }
 
+TypePtr IndexedSlicesStrToType(const std::string &type_name) {
+  if (type_name == "IndexedSlices") {
+    return std::make_shared<IndexedSlicesType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<IndexedSlicesType>(element_type);
+}
+
+TypePtr UndeterminedStrToType(const std::string &type_name) {
+  if (type_name == "Undetermined") {
+    return std::make_shared<UndeterminedType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<UndeterminedType>(element_type);
+}
+
 TypePtr ListStrToType(const std::string &type_name) {
   TypePtr type = nullptr;
   if (type_name == "List") {
@@ -313,6 +347,10 @@ TypePtr StringToType(const std::string &type_name) {
     type = StringToNumberType(type_name, "Float");
   } else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
     type = TensorStrToType(type_name);
+  } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
+    type = UndeterminedStrToType(type_name);
+  } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
+    type = IndexedSlicesStrToType(type_name);
   } else if (type_name.compare(0, strlen("List"), "List") == 0) {
     type = ListStrToType(type_name);
   } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
@@ -340,6 +378,20 @@ TypePtr StringToType(const std::string &type_name) {
   return type;
 }
 
+bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) {
+  if (x == nullptr || base_type == nullptr) {
+    MS_LOG(ERROR) << "Type is nullptr.";
+    return false;
+  }
+  if (base_type->type_id() == kTypeUnknown || x->type_id() == kTypeUnknown) {
+    return false;
+  }
+  if (base_type->type_id() == x->parent_type() || x->type_id() == base_type->parent_type()) {
+    return true;
+  }
+  return false;
+}
+
 bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
   if (x == nullptr || base_type == nullptr) {
     MS_LOG(ERROR) << "Type is nullptr.";
@@ -481,6 +533,10 @@ REGISTER_PYBIND_DEFINE(
                            TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
                            return data;
                          }));
+  (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
+    .def(py::init());
+  (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
+    .def(py::init());
   (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
     .def(py::init())
     .def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
@@ -501,6 +557,8 @@ const TypePtr kTypeExternal = std::make_shared<External>();
 const TypePtr kTypeEnv = std::make_shared<EnvType>();
 const TypePtr kTypeType = std::make_shared<TypeType>();
 const TypePtr kTensorType = std::make_shared<TensorType>();
+const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
+const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
 const TypePtr kString = std::make_shared<String>();
 const TypePtr kList = std::make_shared<List>();
 const TypePtr kTuple = std::make_shared<Tuple>();
diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc
index 88b313450806c680ae438cb34ebbd3cd60bca390..de6526f642313528527ea94cf4af1cce8131404f 100644
--- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc
+++ b/mindspore/ccsrc/operator/composite/multitype_funcgraph.cc
@@ -93,15 +93,17 @@ static TypePtr UnwrapRef(const TypePtr &type) {
   }
   return type;
 }
-FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
-  bool find_fn = false;
-  py::function py_fn;
+
+// Return Exact match if exists, else return non ambiguous sub class match
+// Return py::none() if matching is ambiguous
+const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
+  // Exact match
   for (auto &item : fn_cache_py_) {
     TypePtrList sign = item.first;
     if (sign.size() != types.size()) {
       continue;
     }
-    bool match = true;
+    auto match = true;
     for (size_t i = 0; i < sign.size(); ++i) {
       if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
         match = false;
@@ -111,13 +113,45 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
     if (!match) {
       continue;
     }
-    find_fn = true;
-    py_fn = item.second;
-    break;
+    return item.second;
   }
+  // Try best match
+  py::function py_fn_subclass;
+  size_t subclass_match_cnt = 0;
+  for (auto &item : fn_cache_py_) {
+    TypePtrList sign = item.first;
+    if (sign.size() != types.size()) {
+      continue;
+    }
+    auto match = true;
+    for (size_t i = 0; i < sign.size(); ++i) {
+      if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
+          !IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
+        match = false;
+        break;
+      }
+    }
+    if (!match) {
+      continue;
+    }
+    py_fn_subclass = item.second;
+    subclass_match_cnt++;
+  }
+  if (subclass_match_cnt > 1) {
+    MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
+  }
+  if (subclass_match_cnt == 1) {
+    MS_LOG(DEBUG) << "Found one subclass match";
+    return py_fn_subclass;
+  }
+  return py::none();
+}
+
+FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
+  auto py_fn = SignMatch(types);
   std::ostringstream buffer;
   buffer << types;
-  if (find_fn) {
+  if (py_fn != py::none()) {
     FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
     if (func_graph == nullptr) {
       MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
diff --git a/mindspore/ccsrc/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/operator/composite/multitype_funcgraph.h
index feb38f17ba09265ca25913113be83efdd39a3318..ababf218831faba1c73d993fa044a1f1206a0ad5 100644
--- a/mindspore/ccsrc/operator/composite/multitype_funcgraph.h
+++ b/mindspore/ccsrc/operator/composite/multitype_funcgraph.h
@@ -54,6 +54,7 @@ class MultitypeFuncGraph : public MetaFuncGraph {
   }
 
  private:
+  const py::function SignMatch(const TypePtrList &types);
   std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
   std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
 };
diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index 88001bf63f3ff4b38057ce3da1763fbb1c69801b..b682847ed72f08a6353d5ba5afd070c4506c6c51 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -277,5 +277,12 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
 const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
 const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
 const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
+
+// IndexedSlices
+const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
+const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
+const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
+const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
+const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
 }  // namespace prim
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index efa66834689ee764fbe953dc1a211c42bee40b3d..f7780138961d1fe9f0949ab434bff1ffffd3de67 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -287,6 +287,13 @@ extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
 
+// IndexedSlices
+extern const PrimitivePtr kPrimMakeIndexedSlices;
+extern const PrimitivePtr kPrimIndexedSlicesGetValues;
+extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
+extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
+extern const PrimitivePtr kPrimIsIndexedSlices;
+
 class DoSignaturePrimitive : public Primitive {
  public:
   explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
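[Review note] For orientation, a minimal sketch of how the five primitives declared above are expected to surface in Python graph code. This assumes the IndexedSlices -> F.make_indexed_slices lowering registered in resources.py at the top of this patch, and the accessor names registered for kObjectTypeIndexedSlicesType in resource.cc further below; the snippet is illustrative, not part of the patch:

    # hypothetical construct() body under graph mode
    x = IndexedSlices(indices, values, dense_shape)  # parsed into MakeIndexedSlices(indices, values, dense_shape)
    v = x.values()             # IndexedSlicesGetValues(x)
    i = x.indices()            # IndexedSlicesGetIndices(x)
    s = x.dense_shape()        # IndexedSlicesGetDenseShape(x)
    b = x.is_indexed_slices()  # IsIndexedSlices(x) -> True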
diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc
index 9350e9aa3b1f85b0ae159c8cfba2b8dbe515cc93..ff9ec712bbea0d83a1e9fe0a027753bce90814c9 100644
--- a/mindspore/ccsrc/operator/prim_others.cc
+++ b/mindspore/ccsrc/operator/prim_others.cc
@@ -24,6 +24,7 @@
 #include "pipeline/static_analysis/prim.h"
 #include "pipeline/static_analysis/utils.h"
 #include "utils/symbolic.h"
+#include "utils/context/ms_context.h"
 
 namespace mindspore {
 namespace abstract {
@@ -173,6 +174,13 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
     return std::make_shared<AbstractTuple>(sparse_list);
   }
 
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
+    auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
+    return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
+  }
   if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
     return dflt;
   }
@@ -236,6 +244,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
   }
   auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
   ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
+  ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
   return ret;
 }
@@ -437,5 +446,72 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
   }
   return std::make_shared<AbstractScalar>(kAnyValue, kBool);
 }
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  std::vector<int> dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int {
+                         auto elem = GetValue<int>(e);
+                         return elem;
+                       });
+  auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
+  ret->set_indices(indices);
+  ret->set_values(values);
+  ret->set_dense_shape(dense_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->values());
+  return indexed_slices->values();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->indices());
+  return indexed_slices->indices();
+}
+
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one IndexedSlices.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
+  return indexed_slices->dense_shape();
+}
+
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  bool ret = false;
+  if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
+    ret = true;
+  }
+  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
+  return std::make_shared<AbstractScalar>(ret);
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc
index 6a5459728290ad479095893edc7ae9e6613c0d3b..bb52273568110ae528710a2923ab14b65ea41b8c 100644
--- a/mindspore/ccsrc/optimizer/clean.cc
+++ b/mindspore/ccsrc/optimizer/clean.cc
@@ -36,6 +36,7 @@ using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
 using mindspore::abstract::AbstractTuple;
+using mindspore::abstract::AbstractUndetermined;
 
 static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
   if (t == nullptr) {
@@ -78,7 +79,7 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(cons);
 
   auto dt = data->abstract();
-  if (dt == nullptr) {
+  if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
     return nullptr;
   }
diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
index 3e8cfea37f5d31f2a32f4da938ed15873450fb54..166151751ff852e5fcc2cba26f7a87adba99d253 100644
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ b/mindspore/ccsrc/optimizer/irpass.cc
@@ -42,6 +42,7 @@
 #include "optimizer/irpass/tile_eliminate.h"
 #include "optimizer/irpass/transpose_eliminate.h"
 #include "optimizer/opt.h"
+#include "optimizer/irpass/indexed_slices_eliminate.h"
 
 namespace mindspore {
 namespace opt {
@@ -153,6 +154,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Mark interface fusion
   mark_interface_fusion_ =
     MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
+
+  // IndexedSlices Eliminate
+  indexed_slices_eliminate_ = MakeSubstitution(
+    std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
+    {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h
index fa4d1e4cae49cba0b202518298b5213aa73f37bd..782eae61240ba087f3686b01a286d25fe5320db4 100644
--- a/mindspore/ccsrc/optimizer/irpass.h
+++ b/mindspore/ccsrc/optimizer/irpass.h
@@ -104,6 +104,9 @@ class OptimizeIRPassLib {
 
   // Fusion
   SubstitutionPtr mark_interface_fusion_;
+
+  // IndexedSlices Eliminate
+  SubstitutionPtr indexed_slices_eliminate_;
 };
 
 // the collection of irpass for resolve action
diff --git a/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h b/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h
new file mode 100644
index 0000000000000000000000000000000000000000..630d567549fd093c4a8e0ddbd3fc9a42e6ebc617
--- /dev/null
+++ b/mindspore/ccsrc/optimizer/irpass/indexed_slices_eliminate.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "optimizer/irpass.h"
+#include "optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
+// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
+// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
+class IndexedSlicesEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(1);
+    }
+
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(2);
+    }
+
+    AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(3);
+    }
+    return nullptr;
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
+      tuple_ = cnode;
+      is_match_ = true;
+    }
+  }
+
+  void Reset() {
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  CNodePtr tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc
index 7d56551ff05872f0593463412dfb1e9c99f9680f..c76053d2418577383bf7483937e66956d1693eae 100644
--- a/mindspore/ccsrc/pipeline/action.cc
+++ b/mindspore/ccsrc/pipeline/action.cc
@@ -232,6 +232,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
         auto sparse_grad =
           py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
         ptr->set_sparse_grad(sparse_grad);
+        auto has_indexed_slices_grad =
+          py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad"));
+        ptr->set_has_indexed_slices_grad(has_indexed_slices_grad);
 
         parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
         args_spec.push_back(ptr);
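[Review note] The IndexedSlicesEliminater above folds accessor-of-constructor chains. Because input(0) of a CNode is the primitive itself, input(1)/input(2)/input(3) of a MakeIndexedSlices node are the indices, values and dense_shape operands, giving three rewrites:

    IndexedSlicesGetIndices(MakeIndexedSlices(i, v, s))    => i
    IndexedSlicesGetValues(MakeIndexedSlices(i, v, s))     => v
    IndexedSlicesGetDenseShape(MakeIndexedSlices(i, v, s)) => s

The substitution is registered in opt pass group b_1 (pass.cc below), so these getters are eliminated before later optimization phases see them.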
diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc
index dc309808d9ac3f69d0630c0fc25002380011550d..f28be181dddb3081a88903ce336e6e61d3de1021 100644
--- a/mindspore/ccsrc/pipeline/init.cc
+++ b/mindspore/ccsrc/pipeline/init.cc
@@ -154,7 +154,9 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
     .def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
          "Set the GraphKernel switch to on or off.")
-    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
+    .def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
+    .def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
+    .def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
 
   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index 9876c0280ad17f32971484c176df3131d618de2a..f6cfd6362c465840e3d3398fb1f7dbdc11090458 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -156,6 +156,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
+    irpass.indexed_slices_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},
diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc
index 50ccef2f44c73dd51c79c5737ec116a8b566b750..faf1f2015d4c3bd1eac1f89fa4d7a57cef29e21f 100644
--- a/mindspore/ccsrc/pipeline/resource.cc
+++ b/mindspore/ccsrc/pipeline/resource.cc
@@ -33,148 +33,157 @@ namespace mindspore {
 namespace pipeline {
 
 MethodMap &GetMethodMap() {
-  static MethodMap method_map = {{kObjectTypeString,
-                                  {
-                                    {"__bool__", std::string("str_bool")}  // C.str_bool
-                                  }},
-                                 {kMetaTypeNone,
-                                  {
-                                    {"__bool__", std::string("none_bool")}  // C.none_bool
-                                  }},
-                                 {kNumberTypeBool,
-                                  {
-                                    {"__and__", prim::kPrimBoolAnd},     // P.bool_and
-                                    {"__or__", prim::kPrimBoolOr},       // P.bool_or
-                                    {"__eq__", prim::kPrimBoolEq},       // P.bool_eq
-                                    {"__ne__", std::string("bool_ne")},  // C.bool_ne
-                                    {"__bool__", prim::kPrimIdentity}    // P.identity
-                                  }},
-                                 {kNumberTypeInt,
-                                  {
-                                    {"__add__", prim::kPrimScalarAdd},              // P.scalar_add
-                                    {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub
-                                    {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul
-                                    {"__floordiv__", std::string("int_floordiv")},  // C.int_floordiv
-                                    {"__truediv__", std::string("int_truediv")},    // C.int_truediv
-                                    {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod
-                                    {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow
-                                    {"__floor__", prim::kPrimIdentity},             // P.identity
-                                    {"__trunc__", prim::kPrimIdentity},             // P.identity
-                                    {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd
-                                    {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub
-                                    {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq
-                                    {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne
-                                    {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt
-                                    {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt
-                                    {"__le__", prim::kPrimScalarLe},                // P.scalar_le
-                                    {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge
-                                    {"__bool__", std::string("int_bool")},          // C.int_bool
-                                    {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array
-                                  }},
-                                 {kNumberTypeUInt,
-                                  {
-                                    {"__add__", prim::kPrimScalarAdd},              // P.scalar_add,
-                                    {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub,
-                                    {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul,
-                                    {"__floordiv__", prim::kPrimScalarDiv},         // P.scalar_div,
-                                    {"__truediv__", std::string("int_truediv")},    // C.int_truediv
-                                    {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod,
-                                    {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow,
-                                    {"__floor__", prim::kPrimIdentity},             // P.identity,
-                                    {"__trunc__", prim::kPrimIdentity},             // P.identity,
-                                    {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd,
-                                    {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub,
-                                    {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq,
-                                    {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne,
-                                    {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt,
-                                    {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt,
-                                    {"__le__", prim::kPrimScalarLe},                // P.scalar_le,
-                                    {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge,
-                                    {"__bool__", std::string("int_bool")},          // C.int_bool
-                                    {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array,
-                                  }},
-                                 {kNumberTypeFloat,
-                                  {
-                                    {"__add__", prim::kPrimScalarAdd},                 // P.scalar_add,
-                                    {"__sub__", prim::kPrimScalarSub},                 // P.scalar_sub,
-                                    {"__mul__", prim::kPrimScalarMul},                 // P.scalar_mul,
-                                    {"__floordiv__", std::string("float_floordiv")},  // C.float_floordiv
-                                    {"__truediv__", prim::kPrimScalarDiv},            // P.scalar_div,
-                                    {"__mod__", prim::kPrimScalarMod},                // P.scalar_mod,
-                                    {"__pow__", prim::kPrimScalarPow},                // P.scalar_pow,
-                                    {"__floor__", prim::kPrimScalarFloor},            // P.scalar_floor,
-                                    {"__trunc__", prim::kPrimScalarTrunc},            // P.scalar_trunc,
-                                    {"__pos__", prim::kPrimScalarUadd},               // P.scalar_uadd,
-                                    {"__neg__", prim::kPrimScalarUsub},               // P.scalar_usub,
-                                    {"__eq__", prim::kPrimScalarEq},                  // P.scalar_eq,
-                                    {"__ne__", prim::kPrimScalarNe},                  // P.scalar_ne,
-                                    {"__lt__", prim::kPrimScalarLt},                  // P.scalar_lt,
-                                    {"__gt__", prim::kPrimScalarGt},                  // P.scalar_gt,
-                                    {"__le__", prim::kPrimScalarLe},                  // P.scalar_le,
-                                    {"__ge__", prim::kPrimScalarGe},                  // P.scalar_ge,
-                                    {"__bool__", std::string("float_bool")},          // C.float_bool
-                                    {"__ms_to_array__", prim::kPrimScalarToArray},    // P.scalar_to_array,
-                                  }},
-                                 {kObjectTypeTuple,
-                                  {
-                                    {"__len__", prim::kPrimTupleLen},                  // P.tuple_len,
-                                    {"__getitem__", prim::kPrimTupleGetItem},          // P.tuple_getitem,
-                                    {"__setitem__", prim::kPrimTupleSetItem},          // P.tuple_setitem,
-                                    {"__ms_iter__", prim::kPrimIdentity},              // P.identity,
-                                    {"__ms_next__", std::string("tuple_next")},        // C.tuple_next,
-                                    {"__ms_hasnext__", std::string("tuple_hasnext")},  // C.tuple_hasnext
-                                    {"__bool__", std::string("tuple_bool")}            // C.tuple_bool
-                                  }},
-                                 {kObjectTypeList,
-                                  {
-                                    {"__len__", prim::kPrimListLen},            // P.list_len,
-                                    {"__getitem__", prim::kPrimListGetItem},    // P.list_getitem,
-                                    {"__setitem__", prim::kPrimListSetItem},    // P.list_setitem,
-                                    {"__ms_iter__", prim::kPrimIdentity},       // P.identity
-                                    {"__ms_next__", std::string("list_next")},  // C.list_next
-                                    {"append", std::string("list_append")},     // C.list_next
-                                    {"__bool__", std::string("list_bool")},     // C.list_bool
-                                    {"__ms_hasnext__", std::string("list_hasnext")},
-                                  }},
-                                 {kObjectTypeDictionary,
-                                  {
-                                    {"__len__", prim::kPrimDictLen},          // P.dict_len
-                                    {"__getitem__", prim::kPrimDictGetItem},  // P.dict_getitem
-                                    {"__setitem__", prim::kPrimDictSetItem},  // P.dict_setitem,
-                                    {"__bool__", std::string("dict_bool")}    // C.dict_bool
-                                  }},
-                                 {kObjectTypeTensorType,
-                                  {
-                                    {"__add__", std::string("add")},             // C.add
-                                    {"__sub__", std::string("sub")},             // C.sub
-                                    {"__mul__", std::string("mul")},             // C.mul
-                                    {"__truediv__", std::string("truediv")},     // C.truediv
-                                    {"__floordiv__", std::string("floordiv")},   // C.floordiv
-                                    {"__mod__", std::string("mod")},             // C.mod
-                                    {"__pow__", std::string("pow_")},            // C.pow
-                                    {"__floor__", std::string("array_floor")},   // C.array_floor
-                                    {"__trunc__", std::string("array_trunc")},   // C.array_trunc
-                                    {"__pos__", std::string("array_uadd")},      // C.array_uadd
-                                    {"__neg__", std::string("array_usub")},      // C.array_usub
-                                    {"__eq__", std::string("eq")},               // C.eq
-                                    {"__ne__", std::string("ne")},               // C.ne
-                                    {"__lt__", std::string("lt")},               // C.lt
-                                    {"__gt__", std::string("gt")},               // C.gt
-                                    {"__le__", std::string("le")},               // C.le
-                                    {"__ge__", std::string("ge")},               // C.ge
-                                    {"__matmul__", prim::kPrimDot},              // P.dot,
-                                    {"__len__", prim::kPrimArrayLen},            // P.array_len,
-                                    {"__getitem__", prim::kPrimArrayGetItem},    // P.array_getitem,
-                                    {"__setitem__", prim::kPrimArraySetItem},    // P.array_setitem,
-                                    {"__ms_iter__", std::string("array_iter")},  // C.array_iter
-                                    {"__ms_to_array__", prim::kPrimIdentity},    // P.identity,
-                                    {"item", prim::kPrimArrayToScalar},          // P.array_to_scalar,
-                                    {"transpose", std::string("transpose")},     // P.transpose
-                                    {"__bool__", std::string("tensor_bool")},    // C.tensor_bool
-                                  }},
-                                 {kObjectTypeJTagged, {}},
-                                 {kObjectTypeSymbolicKeyType, {}},
-                                 {kObjectTypeEnvType, {}}};
+  static MethodMap method_map = {
+    {kObjectTypeString,
+     {
+       {"__bool__", std::string("str_bool")}  // C.str_bool
+     }},
+    {kMetaTypeNone,
+     {
+       {"__bool__", std::string("none_bool")}  // C.none_bool
+     }},
+    {kNumberTypeBool,
+     {
+       {"__and__", prim::kPrimBoolAnd},     // P.bool_and
+       {"__or__", prim::kPrimBoolOr},       // P.bool_or
+       {"__eq__", prim::kPrimBoolEq},       // P.bool_eq
+       {"__ne__", std::string("bool_ne")},  // C.bool_ne
+       {"__bool__", prim::kPrimIdentity}    // P.identity
+     }},
+    {kNumberTypeInt,
+     {
+       {"__add__", prim::kPrimScalarAdd},              // P.scalar_add
+       {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub
+       {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul
+       {"__floordiv__", std::string("int_floordiv")},  // C.int_floordiv
+       {"__truediv__", std::string("int_truediv")},    // C.int_truediv
+       {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod
+       {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow
+       {"__floor__", prim::kPrimIdentity},             // P.identity
+       {"__trunc__", prim::kPrimIdentity},             // P.identity
+       {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd
+       {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub
+       {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq
+       {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne
+       {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt
+       {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt
+       {"__le__", prim::kPrimScalarLe},                // P.scalar_le
+       {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge
+       {"__bool__", std::string("int_bool")},          // C.int_bool
+       {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array
+     }},
+    {kNumberTypeUInt,
+     {
+       {"__add__", prim::kPrimScalarAdd},              // P.scalar_add,
+       {"__sub__", prim::kPrimScalarSub},              // P.scalar_sub,
+       {"__mul__", prim::kPrimScalarMul},              // P.scalar_mul,
+       {"__floordiv__", prim::kPrimScalarDiv},         // P.scalar_div,
+       {"__truediv__", std::string("int_truediv")},    // C.int_truediv
+       {"__mod__", prim::kPrimScalarMod},              // P.scalar_mod,
+       {"__pow__", prim::kPrimScalarPow},              // P.scalar_pow,
+       {"__floor__", prim::kPrimIdentity},             // P.identity,
+       {"__trunc__", prim::kPrimIdentity},             // P.identity,
+       {"__pos__", prim::kPrimScalarUadd},             // P.scalar_uadd,
+       {"__neg__", prim::kPrimScalarUsub},             // P.scalar_usub,
+       {"__eq__", prim::kPrimScalarEq},                // P.scalar_eq,
+       {"__ne__", prim::kPrimScalarNe},                // P.scalar_ne,
+       {"__lt__", prim::kPrimScalarLt},                // P.scalar_lt,
+       {"__gt__", prim::kPrimScalarGt},                // P.scalar_gt,
+       {"__le__", prim::kPrimScalarLe},                // P.scalar_le,
+       {"__ge__", prim::kPrimScalarGe},                // P.scalar_ge,
+       {"__bool__", std::string("int_bool")},          // C.int_bool
+       {"__ms_to_array__", prim::kPrimScalarToArray},  // P.scalar_to_array,
+     }},
+    {kNumberTypeFloat,
+     {
+       {"__add__", prim::kPrimScalarAdd},                // P.scalar_add,
+       {"__sub__", prim::kPrimScalarSub},                // P.scalar_sub,
+       {"__mul__", prim::kPrimScalarMul},                // P.scalar_mul,
+       {"__floordiv__", std::string("float_floordiv")},  // C.float_floordiv
+       {"__truediv__", prim::kPrimScalarDiv},            // P.scalar_div,
+       {"__mod__", prim::kPrimScalarMod},                // P.scalar_mod,
+       {"__pow__", prim::kPrimScalarPow},                // P.scalar_pow,
+       {"__floor__", prim::kPrimScalarFloor},            // P.scalar_floor,
+       {"__trunc__", prim::kPrimScalarTrunc},            // P.scalar_trunc,
+       {"__pos__", prim::kPrimScalarUadd},               // P.scalar_uadd,
+       {"__neg__", prim::kPrimScalarUsub},               // P.scalar_usub,
+       {"__eq__", prim::kPrimScalarEq},                  // P.scalar_eq,
+       {"__ne__", prim::kPrimScalarNe},                  // P.scalar_ne,
+       {"__lt__", prim::kPrimScalarLt},                  // P.scalar_lt,
+       {"__gt__", prim::kPrimScalarGt},                  // P.scalar_gt,
+       {"__le__", prim::kPrimScalarLe},                  // P.scalar_le,
+       {"__ge__", prim::kPrimScalarGe},                  // P.scalar_ge,
+       {"__bool__", std::string("float_bool")},          // C.float_bool
+       {"__ms_to_array__", prim::kPrimScalarToArray},    // P.scalar_to_array,
+     }},
+    {kObjectTypeTuple,
+     {
+       {"__len__", prim::kPrimTupleLen},                  // P.tuple_len,
+       {"__getitem__", prim::kPrimTupleGetItem},          // P.tuple_getitem,
+       {"__setitem__", prim::kPrimTupleSetItem},          // P.tuple_setitem,
+       {"__ms_iter__", prim::kPrimIdentity},              // P.identity,
+       {"__ms_next__", std::string("tuple_next")},        // C.tuple_next,
+       {"__ms_hasnext__", std::string("tuple_hasnext")},  // C.tuple_hasnext
+       {"__bool__", std::string("tuple_bool")}            // C.tuple_bool
+     }},
+    {kObjectTypeList,
+     {
+       {"__len__", prim::kPrimListLen},            // P.list_len,
+       {"__getitem__", prim::kPrimListGetItem},    // P.list_getitem,
+       {"__setitem__", prim::kPrimListSetItem},    // P.list_setitem,
+       {"__ms_iter__", prim::kPrimIdentity},       // P.identity
+       {"__ms_next__", std::string("list_next")},  // C.list_next
+       {"append", std::string("list_append")},     // C.list_next
+       {"__bool__", std::string("list_bool")},     // C.list_bool
+       {"__ms_hasnext__", std::string("list_hasnext")},
+     }},
+    {kObjectTypeDictionary,
+     {
+       {"__len__", prim::kPrimDictLen},          // P.dict_len
+       {"__getitem__", prim::kPrimDictGetItem},  // P.dict_getitem
+       {"__setitem__", prim::kPrimDictSetItem},  // P.dict_setitem,
+       {"__bool__", std::string("dict_bool")}    // C.dict_bool
+     }},
+    {kObjectTypeTensorType,
+     {
+       {"__add__", std::string("add")},                    // C.add
+       {"__sub__", std::string("sub")},                    // C.sub
+       {"__mul__", std::string("mul")},                    // C.mul
+       {"__truediv__", std::string("truediv")},            // C.truediv
+       {"__floordiv__", std::string("floordiv")},          // C.floordiv
+       {"__mod__", std::string("mod")},                    // C.mod
+       {"__pow__", std::string("pow_")},                   // C.pow
+       {"__floor__", std::string("array_floor")},          // C.array_floor
+       {"__trunc__", std::string("array_trunc")},          // C.array_trunc
+       {"__pos__", std::string("array_uadd")},             // C.array_uadd
+       {"__neg__", std::string("array_usub")},             // C.array_usub
+       {"__eq__", std::string("eq")},                      // C.eq
+       {"__ne__", std::string("ne")},                      // C.ne
+       {"__lt__", std::string("lt")},                      // C.lt
+       {"__gt__", std::string("gt")},                      // C.gt
+       {"__le__", std::string("le")},                      // C.le
+       {"__ge__", std::string("ge")},                      // C.ge
+       {"__matmul__", prim::kPrimDot},                     // P.dot,
+       {"__len__", prim::kPrimArrayLen},                   // P.array_len,
+       {"__getitem__", prim::kPrimArrayGetItem},           // P.array_getitem,
+       {"__setitem__", prim::kPrimArraySetItem},           // P.array_setitem,
+       {"__ms_iter__", std::string("array_iter")},         // C.array_iter
+       {"__ms_to_array__", prim::kPrimIdentity},           // P.identity,
+       {"item", prim::kPrimArrayToScalar},                 // P.array_to_scalar,
+       {"transpose", std::string("transpose")},            // P.transpose
+       {"__bool__", std::string("tensor_bool")},           // C.tensor_bool
+       {"is_indexed_slices", prim::kPrimIsIndexedSlices},  // F.is_indexed_slices
+     }},
+    {kObjectTypeIndexedSlicesType,
+     {
+       {"is_indexed_slices", prim::kPrimIsIndexedSlices},  // F.is_indexed_slices
+       {"values", prim::kPrimIndexedSlicesGetValues},      // F.indexed_slices_get_values
+       {"indices", prim::kPrimIndexedSlicesGetIndices},          // F.indexed_slices_get_indices
+       {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape},   // F.indexed_slices_get_dense_shape
+     }},
+    {kObjectTypeJTagged, {}},
+    {kObjectTypeSymbolicKeyType, {}},
+    {kObjectTypeEnvType, {}}};
   return method_map;
 }
diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
index f23c6e31c4b7f88ed6b69af3c84df9d0ad7257dd..86bfecf14bc7f61c3b3b6c8b2ee0b259ae432ea2 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
@@ -30,6 +30,10 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
   if (tid() != other.tid()) {
     return false;
   }
+  if (BuildType()->type_id() == kObjectTypeUndeterminedType &&
+      other.BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    return true;
+  }
   if (value_ == nullptr || other.value_ == nullptr) {
     MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
                       << this->ToString() << ", other: " << other.ToString();
@@ -65,7 +69,7 @@ std::string AbstractBase::ToString() const {
   MS_EXCEPTION_IF_NULL(shape_);
   buffer << type_name() << "("
          << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
-         << " sparse_grad: " << sparse_grad_ << ")";
+         << " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
   return buffer.str();
 }
 
@@ -76,6 +80,7 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   if (*this == *other) {
     auto ret = shared_from_base<AbstractScalar>();
     ret->set_sparse_grad(sparse_grad());
+    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
     return ret;
   }
   auto value_self = GetValueTrack();
@@ -85,10 +90,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
   if (res_value == value_self) {
     auto ret = shared_from_base<AbstractScalar>();
     ret->set_sparse_grad(sparse_grad());
+    ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
     return ret;
   }
   auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
   ret->set_sparse_grad(sparse_grad());
+  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return ret;
 }
 
@@ -409,6 +416,14 @@ std::size_t AbstractSlice::hash() const {
   return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
 }
 
+ShapePtr AbstractUndetermined::shape() const {
+  auto shp = dyn_cast<Shape>(GetShapeTrack());
+  if (shp == nullptr) {
+    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
+  }
+  return shp;
+}
+
 TypePtr AbstractTensor::BuildType() const {
   MS_EXCEPTION_IF_NULL(element_);
   TypePtr element_type = element_->BuildType();
@@ -425,6 +440,13 @@ BaseShapePtr AbstractTensor::BuildShape() const {
 }
 
 AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
+  if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) {
+    auto other_tensor = dyn_cast<AbstractUndetermined>(other);
+    auto element = element_->Join(other_tensor->element());
+    auto shape = ShapeJoin(this->shape(), other_tensor->shape());
+    auto ret = std::make_shared<AbstractUndetermined>(element, shape);
+    return ret;
+  }
   auto other_tensor = dyn_cast<AbstractTensor>(other);
   if (other_tensor == nullptr) {
     MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
@@ -433,6 +455,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
   auto shape = ShapeJoin(this->shape(), other_tensor->shape());
   auto ret = std::make_shared<AbstractTensor>(element, shape);
   ret->set_sparse_grad(sparse_grad());
+  ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return ret;
 }
 
@@ -474,6 +497,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
   clone->set_sparse_grad(sparse_grad());
+  clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return clone;
 }
 
@@ -484,6 +508,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
   broaden->set_sparse_grad(sparse_grad());
+  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
 
@@ -495,17 +520,10 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
   broaden->set_shape(shp);
   broaden->set_value(kAnyValue);
   broaden->set_sparse_grad(sparse_grad());
+  broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
   return broaden;
 }
 
-ShapePtr AbstractTensor::shape() const {
-  auto shp = dyn_cast<Shape>(GetShapeTrack());
-  if (shp == nullptr) {
-    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
-  }
-  return shp;
-}
-
 std::string AbstractTensor::ToString() const {
   std::ostringstream buffer;
   BaseShapePtr shape_track = GetShapeTrack();
@@ -516,7 +534,7 @@ std::string AbstractTensor::ToString() const {
   buffer << type_name() << "("
          << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
          << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
-         << ")";
+         << " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
   return buffer.str();
 }
 
@@ -1019,5 +1037,64 @@ std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &arg
 bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
   return AbstractBasePtrListDeepEqual(lhs, rhs);
 }
+
+// IndexedSlices
+TypePtr AbstractIndexedSlices::BuildType() const {
+  MS_EXCEPTION_IF_NULL(element());
+  TypePtr element_type = element()->BuildType();
+  return std::make_shared<IndexedSlicesType>(element_type);
+}
+
+AbstractBasePtr AbstractIndexedSlices::Clone() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
+  ShapePtr shp = shape();
+  clone->set_shape(shp->Clone());
+  clone->set_value(GetValueTrack());
+  clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return clone;
+}
+
+AbstractBasePtr AbstractIndexedSlices::Broaden() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto shp = shape();
+  broaden->set_shape(shp->Clone());
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto shp = shape()->Clone();
+  shp->Broaden();
+  broaden->set_shape(shp);
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+std::string AbstractIndexedSlices::ToString() const {
+  std::ostringstream buffer;
+  BaseShapePtr shape_track = GetShapeTrack();
+  MS_EXCEPTION_IF_NULL(shape_track);
+  MS_EXCEPTION_IF_NULL(element());
value_track = GetValueTrack(); + MS_EXCEPTION_IF_NULL(value_track); + buffer << type_name() << "(" + << "shape: " << shape_track->ToString() << ", element: " << element()->ToString() + << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")" + << ", indices: " << indices_->ToString() << ", values" << values_->ToString() + << ", dense_shape: " << dense_shape_->ToString(); + return buffer.str(); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h index 5b54c749b67df3b45348ef56bf58722a2648bbd3..a5b4acff4520d36f0c380bc1a0dee3cc8e0743f5 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h @@ -44,7 +44,7 @@ class AbstractBase : public Base { public: explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, const BaseShapePtr &shape = kNoShape) - : value_(value), type_(type), shape_(shape), sparse_grad_("") {} + : value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {} ~AbstractBase() override = default; MS_DECLARE_PARENT(AbstractBase, Base) @@ -54,12 +54,16 @@ class AbstractBase : public Base { virtual bool operator==(const AbstractBase &other) const; void set_value(const ValuePtr &value) { value_ = value; } void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; } + void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) { + has_indexed_slices_grad_ = has_indexed_slices_grad; + } void set_type(const TypePtr &type) { type_ = type; } void set_shape(const BaseShapePtr &shape) { shape_ = shape; } void set_value_desc(const std::string &desc) { value_desc_ = desc; } const std::string &value_desc() const { return value_desc_; } ValuePtr GetValueTrack() const { return value_; } const std::string &sparse_grad() const { return sparse_grad_; } + const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; } TypePtr GetTypeTrack() const { return type_; } BaseShapePtr GetShapeTrack() const { return shape_; } @@ -88,6 +92,7 @@ class AbstractBase : public Base { BaseShapePtr shape_; std::string value_desc_; // store initial value description for error report std::string sparse_grad_; + bool has_indexed_slices_grad_; }; class AbstractScalar : public AbstractBase { @@ -231,35 +236,49 @@ class AbstractKeywordArg : public AbstractBase { }; using AbstractKeywordArgPtr = std::shared_ptr; -class AbstractTensor : public AbstractBase { +class AbstractUndetermined : public AbstractBase { public: + // shape and type are all unknown + AbstractUndetermined() : AbstractBase(kAnyValue) {} // only element_ and value, shape track are valid member, type track are unknown. 
- explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) : AbstractBase(kAnyValue), element_(element) { if (element == nullptr) { MS_LOG(EXCEPTION) << "element is nullptr"; } - if (element->isa()) { + if (element->isa()) { MS_LOG(EXCEPTION) << "element type error"; } set_shape(shape); } - AbstractTensor(const TypePtr &element_type, const std::vector &shape) + AbstractUndetermined(const TypePtr &element_type, const std::vector &shape) : AbstractBase(kAnyValue), element_(std::make_shared(kAnyValue, element_type)) { if (element_type == nullptr) { MS_LOG(EXCEPTION) << "element_type is nullptr"; } set_shape(std::make_shared(shape)); } - explicit AbstractTensor(const tensor::TensorPtr &tensor) - : AbstractBase(tensor), element_(std::make_shared(kAnyValue, tensor->Dtype())) { - if (tensor == nullptr) { - MS_LOG(EXCEPTION) << "tensor is nullptr"; - } - set_shape(std::make_shared(tensor->shape())); - } + ~AbstractUndetermined() override = default; + MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase) + TypePtr BuildType() const override { return std::make_shared(); } + AbstractBasePtr Clone() const override { return std::make_shared(); } + const AbstractBasePtr element() const { return element_; } + ShapePtr shape() const; + + protected: + AbstractBasePtr element_; +}; + +class AbstractTensor : public AbstractUndetermined { + public: + // only element_ and value, shape track are valid member, type track are unknown. + explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + : AbstractUndetermined(element, shape) {} + AbstractTensor(const TypePtr &element_type, const std::vector &shape) + : AbstractUndetermined(element_type, shape) {} + explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {} ~AbstractTensor() override = default; - MS_DECLARE_PARENT(AbstractTensor, AbstractBase) + MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) TypePtr BuildType() const override; BaseShapePtr BuildShape() const override; @@ -271,9 +290,7 @@ class AbstractTensor : public AbstractBase { bool operator==(const AbstractTensor &other) const; bool operator==(const AbstractBase &other) const override; - ShapePtr shape() const; std::string ToString() const override; - const AbstractBasePtr element() const { return element_; } std::size_t hash() const override { auto value = GetValueTrack(); auto hash_sum = hash_combine(tid(), element_->hash()); @@ -285,9 +302,6 @@ class AbstractTensor : public AbstractBase { } return hash_sum; } - - private: - AbstractBasePtr element_; }; using AbstractTensorPtr = std::shared_ptr; using AbstractTensorPtrList = std::vector; @@ -585,6 +599,35 @@ struct AbstractBasePtrListEqual { std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); + +// IndexedSlices +class AbstractIndexedSlices : public AbstractUndetermined { + public: + explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) + : AbstractUndetermined(element, shape) {} + AbstractIndexedSlices(const TypePtr &element_type, const std::vector &shape) + : AbstractUndetermined(element_type, shape) {} + ~AbstractIndexedSlices() override = default; + MS_DECLARE_PARENT(AbstractIndexedSlices, 
+
+  const AbstractTensorPtr indices() const { return indices_; }
+  const AbstractTensorPtr values() const { return values_; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
+  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
+  TypePtr BuildType() const override;
+  AbstractBasePtr Clone() const override;
+  AbstractBasePtr Broaden() const override;
+  AbstractBasePtr BroadenWithShape() const;
+
+  std::string ToString() const override;
+
+ private:
+  AbstractTensorPtr indices_;
+  AbstractTensorPtr values_;
+  AbstractTuplePtr dense_shape_;
+};
 }  // namespace abstract
 }  // namespace mindspore
 #endif  // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_
diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h
index c7a004ac44e47b7432a78fc6e6a5f02443281095..f6430eda84c71aaf0acd5f2e97c46fb339dec796 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.h
@@ -58,6 +58,20 @@ class Evaluator : public Base {
     return args_spec_list;
   }

+  virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
+    auto is_abstract = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](auto &arg) {
+      return arg->BuildType()->type_id() == kObjectTypeUndeterminedType;
+    });
+    if (is_abstract) {
+      MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
+      return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
+    }
+    return nullptr;
+  }
+
   std::string ToString() const override { return identifier_; }

   virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }
diff --git a/mindspore/ccsrc/pipeline/static_analysis/param_validator.h b/mindspore/ccsrc/pipeline/static_analysis/param_validator.h
index ecb9529a586269d0eceb5714a3ff9df06d598eaa..2f5729aa7348857d50040405a4ba50812ba9ea37 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/param_validator.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/param_validator.h
@@ -66,6 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
 ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
+ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)

 template <typename T>
 std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index bf1f319ae283519b9b53ce395e7236d23e55dc60..99dc0859893113ea0a20e4785594345d3ea7bd68 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -36,6 +36,7 @@
 #include "pipeline/parse/resolve.h"
 #include "ir/tensor.h"
 #include "utils/convert_utils.h"
+#include "utils/context/ms_context.h"
 #include "pipeline/parse/data_converter.h"
 #include "pipeline/static_analysis/param_validator.h"
 #include "common/utils.h"
@@ -132,6 +133,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimControlDepend, {InferImplControlDepend, true}},
     // Debug
     {prim::kPrimDebug, {InferImplDebug, true}},
+    // IndexedSlices
+    {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}},
+    {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
+    {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
+    {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
+    {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
   };
   return prim_eval_implement_map;
 }
@@ -139,6 +146,16 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
 using mindspore::parse::PyObjectWrapper;

 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag && prim_ != prim::kPrimMakeTuple && prim_ != prim::kPrimSwitch) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   prim_->BeginRecordAddAttr();
   AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
   prim_->EndRecordAddAttr();
@@ -485,6 +502,16 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
 }  // end anonymous namespace

 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
   const auto &iter = cache_->find(args);
@@ -512,6 +539,16 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
 }

 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse_flag = context->enable_sparse_flag();
+  if (enable_sparse_flag) {
+    auto ret_abstract = AbstractEval(args);
+    if (ret_abstract != nullptr) {
+      MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
+      return ret_abstract;
+    }
+  }
   // If func_desc_.retval type is a super class of the parameter type, make the retval type the parameter type.
   if (nargs_ != args.size()) {
     MS_LOG(ERROR) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
@@ -871,6 +908,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     auto ref_value = ref_abs->ref();
     MS_EXCEPTION_IF_NULL(ref_value);
     ret->set_sparse_grad(ref_value->sparse_grad());
+    ret->set_has_indexed_slices_grad(ref_value->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
   }
@@ -886,6 +924,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
     abs_scalar->set_sparse_grad(x->sparse_grad());
+    abs_scalar->set_has_indexed_slices_grad(x->has_indexed_slices_grad());
     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
   }
 };
@@ -897,6 +936,16 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse_flag = context->enable_sparse_flag();
+    if (enable_sparse_flag) {
+      auto ret_abstract = AbstractEval(args_spec_list);
+      if (ret_abstract != nullptr) {
+        MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
+        return ret_abstract;
+      }
+    }
     // Inputs: data, item
     if (args_spec_list.size() != 2) {
       MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h
index 5954179aa5d6c9920d6d48b52fbc6fa2aba6202a..1346dba2a2b8cba939a4c4e0e58e538c1ff25f32 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h
@@ -350,6 +350,17 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
 void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
+
+AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                           const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
index 9da148d2a738fe72eb146423d7bf10155df1d7cb..54165766803fc280701350d681dbba29efbfd129 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
@@ -228,6 +228,10 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
     MS_LOG(EXCEPTION) <<
"func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); } + if (maybe_func->BuildType()->type_id() == kObjectTypeUndeterminedType) { + MS_LOG(DEBUG) << "EvalCNode eval Undetermined"; + return std::make_shared(maybe_func->Clone(), std::make_shared()); + } AbstractFunctionPtr func = dyn_cast(maybe_func); if (func == nullptr) { MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return not AbstractFunction: " << maybe_func->ToString() diff --git a/mindspore/ccsrc/pipeline/validator.cc b/mindspore/ccsrc/pipeline/validator.cc index 4866d43b93035782f6dd8fb319fbe8a23d323f86..bbca3c8721ece7c656a1ccf6f882113068c1b924 100644 --- a/mindspore/ccsrc/pipeline/validator.cc +++ b/mindspore/ccsrc/pipeline/validator.cc @@ -32,6 +32,7 @@ using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractError; using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractIndexedSlices; using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractScalar; @@ -93,7 +94,8 @@ void ValidateAbstract(const AnfNodePtr &node) { } if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa() || ptrBase->isa()) { + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa()) { return; } diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index d385ec7a3f69fcd16402970f26b12338d1147c6d..3d367b90e202faeecf8ba4e98016ca065746fb51 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -89,6 +89,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { max_device_memory_ = kDefaultMaxDeviceMemory; print_file_path_ = ""; enable_graph_kernel_ = false; + enable_sparse_flag_ = false; } std::shared_ptr MsContext::GetInstance() { diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 9afe1fa5aaca61ba10849ba8050558d71d1b7ee0..3bca16f8ee30b626d70400ce77524d7a51d0af13 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -161,6 +161,9 @@ class MsContext { void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; } bool enable_graph_kernel() const { return enable_graph_kernel_; } + bool enable_sparse_flag() const { return enable_sparse_flag_; } + void set_enable_sparse_flag(bool enable_sparse_flag) { enable_sparse_flag_ = enable_sparse_flag; } + private: MsContext(const std::string &backend_policy, const std::string &target); void GetGeOptions(std::map *ge_options) const; @@ -204,6 +207,7 @@ class MsContext { float max_device_memory_; std::string print_file_path_; bool enable_graph_kernel_; + bool enable_sparse_flag_; }; } // namespace mindspore diff --git a/mindspore/common/__init__.py b/mindspore/common/__init__.py index ead8aee556e7765949000ea7ad47b8d46088c339..c896805d75af14b98101556a10622996e3b4e946 100644 --- a/mindspore/common/__init__.py +++ b/mindspore/common/__init__.py @@ -17,10 +17,10 @@ from . 
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor
+from .tensor import MetaTensor, Tensor, IndexedSlices

 __all__ = [
-    "MetaTensor", "Tensor",  # tensor
+    "MetaTensor", "Tensor", "IndexedSlices",  # tensor
     'ms_function',  # api
     'Parameter', 'ParameterTuple',  # parameter
     "dtype"
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 773f6a99a6c6484bf544516269cdb1bef5d57eb4..571cc9cb40d951800b1958ac279049550e42ad24 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -52,13 +52,16 @@ class Parameter:
         layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in parallel mode,
             broadcast and gradients communication would not be applied on parameters. Default: False.
         sparse_grad (str): Set if the parameter's gradient is sparse. Default: empty.
+        has_indexed_slices_grad (bool): Set if the parameter's gradient is indexed slices. Default: False.
     """

-    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False, sparse_grad=""):
+    def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
+                 sparse_grad="", has_indexed_slices_grad=False):
         self.set_parameter_data(default_input)
         self.name = name
         self.requires_grad = requires_grad
         self.layerwise_parallel = layerwise_parallel
         self.sparse_grad = sparse_grad
+        self.has_indexed_slices_grad = has_indexed_slices_grad
         self._is_init = False
         self._sliced = False
         self.clone_info = _CloneInfo()
@@ -186,6 +189,17 @@ class Parameter:
             raise TypeError("`sparse_grad` parameter must be str type")
         self._sparse_grad = value

+    @property
+    def has_indexed_slices_grad(self):
+        """Return whether the parameter's gradient is indexed slices."""
+        return self._has_indexed_slices_grad
+
+    @has_indexed_slices_grad.setter
+    def has_indexed_slices_grad(self, value=False):
+        if not isinstance(value, bool):
+            raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
+        self._has_indexed_slices_grad = value
+
     @property
     def data(self):
         return self.default_input
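For orientation, a minimal sketch of the new flag in use (the parameter name and shapes here are illustrative, not part of the patch; the tests further below use the same pattern):

import numpy as np
from mindspore import Tensor, Parameter

# Mark the parameter so its gradient may be produced as IndexedSlices
# (indices/values/dense_shape) instead of a dense tensor.
embedding_table = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)),
                            name="embedding_table", has_indexed_slices_grad=True)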
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 92c600520f9176096e17d362e8ef14260eb51c76..4bb845af5536091cf012e13c2cc3fa79c9e2102a 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry

-__all__ = ['Tensor', 'MetaTensor']
+__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']

 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)
@@ -214,3 +214,8 @@ class Tensor(Tensor_):
             raise TypeError("init_flag must be bool.")
         self.set_init_flag(value)
         self._init_flag = value
+
+
+class IndexedSlices:
+    # IndexedSlices can only be constructed inside a graph (see
+    # tests/ut/python/ir/test_indexed_slices.py below); the eager Python
+    # constructor is intentionally left unimplemented.
+    def __init__(self, indices, values, dense_shape):
+        raise NotImplementedError
diff --git a/mindspore/context.py b/mindspore/context.py
index 070544c5291258af7f394ec142227eec5015623b..b5be6c32132bfa9b03ab35d002b3ac5229958069 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -355,6 +355,14 @@ class _Context:
     def check_bprop(self, check_bprop_flag):
         self._context_handle.set_check_bprop_flag(check_bprop_flag)

+    @property
+    def enable_sparse(self):
+        return self._context_handle.get_enable_sparse_flag()
+
+    @enable_sparse.setter
+    def enable_sparse(self, enable_sparse_flag):
+        self._context_handle.set_enable_sparse_flag(enable_sparse_flag)
+
     @property
     def max_device_memory(self):
         return self._context_handle.get_max_device_memory()
@@ -510,7 +518,8 @@ def reset_auto_parallel_context():
                 save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
                 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                 enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
-                enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str)
+                enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
+                enable_sparse=bool)
 def set_context(**kwargs):
     """
     Sets context for running environment.
@@ -567,6 +576,7 @@ def set_context(**kwargs):
             The format is "xxGB". Default: "1024GB".
         print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
             a file by default, and printing to the screen is turned off.
+        enable_sparse (bool): Whether to enable the sparse feature. Default: False.

     Raises:
         ValueError: If input key is not an attribute in context.
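A quick sketch of switching the feature on, mirroring the test setup later in this patch (graph mode assumed):

from mindspore import context

# enable_sparse defaults to False; the Undetermined short-circuit added to the
# evaluators above only runs when this flag is set.
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)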
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 840c4e745e180964cf9096229b473f3ad6aed6a8..a5c3165ab105072f07365bb15b874d63203b9fbc 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -153,6 +153,14 @@
 shape_mul = Primitive("shape_mul")
 # a primitive to compare between tuple.
 stop_gradient = Primitive("stop_gradient")
+
+
+make_indexed_slices = Primitive('MakeIndexedSlices')
+indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
+indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
+indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
+is_indexed_slices = Primitive('IsIndexedSlices')
+

 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index ad5e86cd9f03fc8fb089f5f70c01bd58820d085a..ced88adec6217548d757a30624f56bbfdd2f60cc 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -564,7 +564,7 @@ class SparseGatherV2(GatherV2):
         >>> input_params = Tensor(np.array([[1, 2, 7, 42], [3, 4, 54, 22], [2, 2, 55, 3]]), mindspore.float32)
         >>> input_indices = Tensor(np.array([1, 2]), mindspore.int32)
         >>> axis = 1
-        >>> out = P.GatherV2()(input_params, input_indices, axis)
+        >>> out = P.SparseGatherV2()(input_params, input_indices, axis)
     """
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index ebbcdf6f7c53bd13df2ae4147337e2214af595fc..bc8561f1711095b9c661f903dc16253f6c868d44 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -603,5 +603,18 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
   ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
   ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
 }
+
+TEST_F(TestOptLib, test_indexed_slices) {
+  FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices");
+  FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices");
+  FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values");
+  FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values");
+  FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape");
+  FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_});
+  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
index af8cab902c8b085e063509033a8e32af9c211553..22e2535819b0b7d0ad1dab48a77b5efb7e703113 100644
--- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -1130,3 +1130,38 @@ def test_adjust_allreduce_mul_add(tag):
         return Mul(AllReduce(AddN((Mul(z, z), x))), y)

     return fns[tag]
+
+
+def test_indexed_slices(tag):
+    """ test_indexed_slices: a getter applied to MakeIndexedSlices folds to the matching input """
+    fns = FnDict()
+    make_indexed_slices = Primitive('MakeIndexedSlices')
+    indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
+    indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
+    indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
+
+    @fns
+    def before_get_indices(x, y, z):
+        return indexed_slices_get_indices(make_indexed_slices(x, y, z))
+
+    @fns
+    def after_get_indices(x, y, z):
+        return x
+
+    @fns
+    def before_get_values(x, y, z):
+        return indexed_slices_get_values(make_indexed_slices(x, y, z))
+
+    @fns
+    def after_get_values(x, y, z):
+        return y
+
+    @fns
+    def before_get_dense_shape(x, y, z):
+        return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z))
+
+    @fns
+    def after_get_dense_shape(x, y, z):
+        return z
+
+    return fns[tag]
diff --git a/tests/ut/python/ir/test_indexed_slices.py b/tests/ut/python/ir/test_indexed_slices.py
new file mode 100644
index 0000000000000000000000000000000000000000..86901830907b7c19d7a6ff3d4c512d9a0d4b860a
--- /dev/null
+++ b/tests/ut/python/ir/test_indexed_slices.py
@@ -0,0 +1,290 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+@File  : test_indexed_slices.py
+@Author:
+@Date  : 2020-06-08
+@Desc  : test MindSpore IndexedSlices operations
+"""
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.ops import operations as P
+from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
+from mindspore.ops.primitive import constexpr
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore import Tensor, IndexedSlices, context
+from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common import dtype as mstype
+from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Rel
+from mindspore.nn import Optimizer
+from mindspore.nn import TrainOneStepCell, WithLossCell
+
+reduce_sum = P.ReduceSum()
+unsorted_segment_sum = P.UnsortedSegmentSum()
+transpose = P.Transpose()
+shape_op = P.Shape()
+reshape = P.Reshape()
+size_op = P.Size()
+invert_permutation = P.InvertPermutation()
+logical_and = P.LogicalAnd()
+context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
+
+@constexpr
+def _generate_shape_index(out_shape, indices_shape, axis):
+    out_rank = len(out_shape)
+    ind_rank = len(indices_shape)
+    if axis < 0:
+        axis += out_rank - ind_rank + 1
+    perm_part1 = tuple(range(axis, axis + ind_rank))
+    index = tuple(range(out_rank))
+    perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
+    return perm
+
+@constexpr
+def _generate_inverse_index(x_shape, axis):
+    x_rank = len(x_shape)
+    index = tuple(range(x_rank))
+    if axis < 0:
+        axis += x_rank
+    perm = index[1:1 + axis] + (0,) + index[1 + axis:]
+    return perm
+
+class MySparseGatherV2(P.GatherV2):
+    """
+    For test
+    """
+
+@bprop_getters.register(MySparseGatherV2)
+def get_bprop_sparse_gather_v2(self):
+    """Generate bprop for MySparseGatherV2"""
+
+    def bprop(x, indices, axis, out, dout):
+        x_shp = shape_op(x)
+        if axis == 0:
+            indices_size = (size_op(indices),)
+            x_tail_shp = x_shp[1:]
+            values_shape = indices_size + x_tail_shp
+            values = reshape(dout, values_shape)
+            indices = reshape(indices, indices_size)
+            # Sparse path: return the gradient of x as an IndexedSlices triple.
+            return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
+        if F.rank(dout) == 0:
+            dout = P.ExpandDims()(dout, -1)
+        if F.rank(indices) == 0:
+            indices = P.ExpandDims()(indices, -1)
+        out_shp = shape_op(dout)
+        ind_shp = shape_op(indices)
+        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
+        perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
+        values_transpose = transpose(dout, perm_1)
+        params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
+        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
+        perm_2 = _generate_inverse_index(x_shp, axis)
+        params_grad = transpose(params_grad, perm_2)
+        return params_grad, zeros_like(indices), zeros_like(axis)
+
+    return bprop
+
+adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
+@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
+                           "Tensor", "Tensor", "Tensor", "Undetermined", "Bool")
+def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
+    # For IndexedSlices gradients, this test simply returns the dense values.
+    if gradient.is_indexed_slices():
+        return gradient.values()
+    op_mul = P.Mul()
+    op_square = P.Square()
+    op_sqrt = P.Sqrt()
+    op_cast = P.Cast()
+    op_reshape = P.Reshape()
+    op_shape = P.Shape()
+
+    param_fp32 = op_cast(param, mstype.float32)
+    m_fp32 = op_cast(m, mstype.float32)
+    v_fp32 = op_cast(v, mstype.float32)
+    gradient_fp32 = op_cast(gradient, mstype.float32)
+
+    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
+
+    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                            - beta2, op_square(gradient_fp32))
+
+    update = next_m / (op_sqrt(next_v) + eps)
+    if decay_flag:
+        update = update + op_mul(weight_decay_tensor, param_fp32)
+
+    update_with_lr = op_mul(lr, update)
+    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
+
+    next_v = F.depend(next_v, F.assign(param, next_param))
+    next_v = F.depend(next_v, F.assign(m, next_m))
+    next_v = F.depend(next_v, F.assign(v, next_v))
+    return next_v
+
+
+def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
+    """Check the type of inputs."""
+    validator.check_value_type("beta1", beta1, [float], prim_name)
+    validator.check_value_type("beta2", beta2, [float], prim_name)
+    validator.check_value_type("eps", eps, [float], prim_name)
+    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
+    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
+    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
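+
+
+# Note: the optimizer below maps adam_opt_for_map over each
+# (param, m, v, gradient) tuple; because the gradient slot is registered as
+# "Undetermined" above, the same graph accepts dense Tensor gradients as well
+# as IndexedSlices gradients.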
+class AdamWeightDecaySparse(Optimizer):
+    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
+                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
+        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
+        if self.is_group:
+            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
+        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
+        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
+        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
+        self.eps = Tensor(np.array([eps]).astype(np.float32))
+        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))
+
+        self.params = self.parameters
+        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
+        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
+        self.decay_flag = tuple(decay_filter(x) for x in self.params)
+        self.map = C.Map()
+
+    def construct(self, gradients):
+        lr = self.get_lr()
+        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
+                                              self.weight_decay_tensor),
+                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
+        return updated_velocity
+
+
+def test_indexed_slices_make_indexed_slices():
+    class MakeIndexedSlices(nn.Cell):
+        def __init__(self):
+            super(MakeIndexedSlices, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            ret = (IndexedSlices(indices, values, self.dense_shape),)
+            return ret[0].is_indexed_slices()
+    indices = Tensor([[0, 0], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    MakeIndexedSlices()(indices, values)
+
+
+def test_indexed_slices_attr():
+    class IndexedSlicesGetAttr(nn.Cell):
+        def __init__(self):
+            super(IndexedSlicesGetAttr, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            x = IndexedSlices(indices, values, self.dense_shape)
+            return x.values(), x.indices(), x.dense_shape()
+    indices = Tensor([[0, 0], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    IndexedSlicesGetAttr()(indices, values)
+
+
+def test_indexed_slices_sparse_gatherv2_grad_all():
+    grad_all = C.GradOperation('get_all', get_all=True)
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+        def construct(self, x, y):
+            grad = grad_all(self.network)(x, y)
+            return grad, grad[0].is_indexed_slices(), grad[1].is_indexed_slices()
+    class SparseGatherV2(nn.Cell):
+        def __init__(self):
+            super(SparseGatherV2, self).__init__()
+            self.sparse_gatherv2 = MySparseGatherV2()
+            self.axis = 0
+        def construct(self, params, indices):
+            return self.sparse_gatherv2(params, indices, self.axis)
+    params = Tensor(np.ones([3, 1, 2]).astype(np.int32))
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    GradWrap(SparseGatherV2())(params, indices)
+
+
+def test_indexed_slices_sparse_gatherv2_grad_with_param():
+    grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+            self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
+        def construct(self, x):
+            weights = self.weights
+            grad = grad_by_list(self.network, weights)(x)
+            x = grad[0]
+            return x.is_indexed_slices(), x.values(), x.indices(), x.dense_shape()
+    class SparseGatherV2(nn.Cell):
+        def __init__(self):
+            super(SparseGatherV2, self).__init__()
+            self.sparse_gatherv2 = MySparseGatherV2()
+            self.axis = 0
+            self.params = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.int32)),
+                                    name="params", has_indexed_slices_grad=True)
+        def construct(self, indices):
+            return self.sparse_gatherv2(self.params, indices, self.axis)
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    network = GradWrap(SparseGatherV2())
+    network(indices)
+
+
+def test_indexed_slices_is_indexed_slices():
+    class MakeIndexedSlices(nn.Cell):
+        def __init__(self):
+            super(MakeIndexedSlices, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            indexed_slices = IndexedSlices(indices, values, self.dense_shape)
+            ret = indexed_slices.is_indexed_slices()
+            return ret
+    indices = Tensor([[0, 0], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    MakeIndexedSlices()(indices, values)
+
+
+def test_indexed_slices_env_get():
+    class Loss(nn.Cell):
+        def __init__(self):
+            super(Loss, self).__init__()
+        def construct(self, base, target):
+            return base
+    class NetWithSparseGatherV2(nn.Cell):
+        def __init__(self):
+            super(NetWithSparseGatherV2, self).__init__()
+            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1",
+                                has_indexed_slices_grad=True)
+            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
+            self.gatherv2 = MySparseGatherV2()
+            self.axis = 0
+        def construct(self, indices):
+            return self.gatherv2(self.w1, indices, self.axis) * self.w2
+
+    inputs = Tensor(np.array([0, 1]).astype(np.int32))
+    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
+    net = NetWithSparseGatherV2()
+    net.set_train()
+    loss = Loss()
+    optimizer = AdamWeightDecaySparse(net.trainable_params())
+
+    net_with_loss = WithLossCell(net, loss)
+    train_network = TrainOneStepCell(net_with_loss, optimizer)
+    train_network(inputs, label)
diff --git a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
index 5222f920ba6078bab95d67bb59a8bc9ea030e520..7f9f341a931d85c3d5677c4116f44c6e24966caa 100644
--- a/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
+++ b/tests/ut/python/nn/optim/test_adam_with_tuple_grad.py
@@ -155,7 +155,7 @@ def test_AdamWeightDecaySparse():
         def __init__(self):
             super(NetWithSparseGatherV2, self).__init__()
             self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1", sparse_grad="sparse_key_w1")
-            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2", sparse_grad="sparse_key_w2")
+            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
             self.gatherv2 = P.SparseGatherV2()
             self.axis = 0
         def construct(self, indices):
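For reference, a minimal end-to-end sketch of the user-facing API this patch adds, distilled from the tests above (the cell name and shapes are illustrative; graph mode and enable_sparse are required):

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, IndexedSlices, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class BuildIndexedSlices(nn.Cell):
    """Constructs an IndexedSlices in graph mode and reads its parts back."""
    def __init__(self, dense_shape):
        super(BuildIndexedSlices, self).__init__()
        self.dense_shape = dense_shape
    def construct(self, indices, values):
        x = IndexedSlices(indices, values, self.dense_shape)
        # These getters lower to the IndexedSlicesGet* primitives registered above.
        return x.indices(), x.values(), x.dense_shape()

indices = Tensor([[0, 0], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
BuildIndexedSlices((3, 4))(indices, values)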