Commit bf8651e2 authored by mindspore-ci-bot, committed by Gitee

!4031 fix cell bprop

Merge pull request !4031 from riemann_penn/fix_cell_bprop
@@ -17,7 +17,7 @@
 """Resources for ast tree parse."""
 import ast
 import math
-from mindspore import IndexedSlices, SparseTensor
+from mindspore import RowTensor, SparseTensor
 from mindspore.ops.composite import multitype_ops
 from mindspore.ops import functional as F, composite as C
 from . import standard_method as M
@@ -140,6 +140,6 @@ convert_object_map = {
     math.tan: NO_IMPLEMENT,
     # user defined
-    IndexedSlices: F.make_indexed_slices,
+    RowTensor: F.make_row_tensor,
     SparseTensor: F.make_sparse_tensor,
 }
@@ -120,7 +120,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
         type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
       }
     }
-  } else if (type->isa<IndexedSlicesType>()) {
+  } else if (type->isa<RowTensorType>()) {
     // Do Nothing
   } else if (type->isa<UndeterminedType>()) {
     // Do Nothing
...
@@ -174,12 +174,11 @@ inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_Virtua
 inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
 inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
-// IndexedSlices
-inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
-inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
-inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
-inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
-inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
+// RowTensor
+inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared<Primitive>("MakeRowTensor");
+inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared<Primitive>("RowTensorGetValues");
+inline const PrimitivePtr kPrimRowTensorGetIndices = std::make_shared<Primitive>("RowTensorGetIndices");
+inline const PrimitivePtr kPrimRowTensorGetDenseShape = std::make_shared<Primitive>("RowTensorGetDenseShape");
 // SparseTensor
 inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
...
@@ -340,8 +340,8 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
   return std::make_shared<AbstractScalar>(kAnyValue, kBool);
 }
-AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list) {
   // Inputs: two tensors and a tuple.
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 3);
@@ -393,41 +393,41 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
                         << "th dimension of values " << values_shp[i] << ", but got " << dense_shape_vec[i];
     }
   }
-  auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
+  auto ret = std::make_shared<AbstractRowTensor>(values->element()->BuildType(), dense_shape_vec);
   ret->set_indices(indices);
   ret->set_values(values);
   ret->set_dense_shape(dense_shape);
   return ret;
 }
-AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list) {
   // Inputs: two tensors and a tuple.
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 1);
-  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(indexed_slices->values());
-  return indexed_slices->values();
+  auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(row_tensor->values());
+  return row_tensor->values();
 }
-AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                              const AbstractBasePtrList &args_spec_list) {
   // Inputs: two tensors and a tuple.
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 1);
-  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(indexed_slices->indices());
-  return indexed_slices->indices();
+  auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(row_tensor->indices());
+  return row_tensor->indices();
 }
-AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                 const AbstractBasePtrList &args_spec_list) {
   // Inputs: two tensors and a tuple.
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 1);
-  auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
-  MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
-  return indexed_slices->dense_shape();
+  auto row_tensor = CheckArg<AbstractRowTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(row_tensor->dense_shape());
+  return row_tensor->dense_shape();
 }
 AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
...
@@ -32,9 +32,9 @@ namespace opt {
 using mindspore::abstract::AbstractAttribute;
 using mindspore::abstract::AbstractClass;
 using mindspore::abstract::AbstractDictionary;
-using mindspore::abstract::AbstractIndexedSlices;
 using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
+using mindspore::abstract::AbstractRowTensor;
 using mindspore::abstract::AbstractScalar;
 using mindspore::abstract::AbstractSparseTensor;
 using mindspore::abstract::AbstractTuple;
@@ -81,10 +81,10 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
     return std::make_shared<AbstractTuple>(abstract_list);
   }
-  if (t->isa<AbstractIndexedSlices>()) {
-    auto abs_indexed_slices = dyn_cast<AbstractIndexedSlices>(t);
-    std::vector<AbstractBasePtr> abstract_list{abs_indexed_slices->indices(), abs_indexed_slices->values(),
-                                               abs_indexed_slices->dense_shape()};
+  if (t->isa<AbstractRowTensor>()) {
+    auto abs_row_tensor = dyn_cast<AbstractRowTensor>(t);
+    std::vector<AbstractBasePtr> abstract_list{abs_row_tensor->indices(), abs_row_tensor->values(),
+                                               abs_row_tensor->dense_shape()};
     return std::make_shared<AbstractTuple>(abstract_list);
   }
@@ -455,16 +455,16 @@ bool CleanAfterOptA(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager
     } else if (IsValueNode<ValueList>(node)) {
       new_node = ConvertValueListNodeToValueTupleNode(node->cast<ValueNodePtr>());
     } else if (IsPrimitiveCNode(node, prim::kPrimMakeSparseTensor) ||
-               IsPrimitiveCNode(node, prim::kPrimMakeIndexedSlices)) {
+               IsPrimitiveCNode(node, prim::kPrimMakeRowTensor)) {
       new_node = ConvertMakeSparseToMakeTuple(cnode);
     } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetIndices) ||
-               IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetIndices)) {
+               IsPrimitiveCNode(node, prim::kPrimRowTensorGetIndices)) {
      new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 0);
     } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetValues) ||
-               IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetValues)) {
+               IsPrimitiveCNode(node, prim::kPrimRowTensorGetValues)) {
       new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 1);
     } else if (IsPrimitiveCNode(node, prim::kPrimSparseTensorGetDenseShape) ||
-               IsPrimitiveCNode(node, prim::kPrimIndexedSlicesGetDenseShape)) {
+               IsPrimitiveCNode(node, prim::kPrimRowTensorGetDenseShape)) {
       new_node = ConvertSparseGetAttrToTupleGetItem(cnode, 2);
     }
...
@@ -43,7 +43,7 @@
 #include "frontend/optimizer/irpass/transpose_eliminate.h"
 #include "frontend/optimizer/irpass/value_based_eliminate.h"
 #include "frontend/optimizer/opt.h"
-#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
+#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
 namespace mindspore {
@@ -157,10 +157,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   mark_interface_fusion_ =
     MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
-  // IndexedSlices Eliminate
-  indexed_slices_eliminate_ = MakeSubstitution(
-    std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
-    {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
+  // RowTensor Eliminate
+  row_tensor_eliminate_ = MakeSubstitution(
+    std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
+    {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape});
   // SparseTensor Eliminate
   sparse_tensor_eliminate_ = MakeSubstitution(
...
@@ -105,8 +105,8 @@ class OptimizeIRPassLib {
   // Fusion
   SubstitutionPtr mark_interface_fusion_;
-  // IndexedSlices Eliminate
-  SubstitutionPtr indexed_slices_eliminate_;
+  // RowTensor Eliminate
+  SubstitutionPtr row_tensor_eliminate_;
   // SparseTensor Eliminate
   SubstitutionPtr sparse_tensor_eliminate_;
...
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
-#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_
 #include <vector>
 #include <algorithm>
@@ -28,24 +28,24 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
-// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
-// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
-// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
-class IndexedSlicesEliminater : public AnfVisitor {
+// {prim::kPrimRowTensorGetIndices, {prim::kPrimMakeRowTensor, Xs}}
+// {prim::kPrimRowTensorGetValues, {prim::kPrimMakeRowTensor, Xs}}
+// {prim::kPrimRowTensorGetDenseShape, {prim::kPrimMakeRowTensor, Xs}}
+class RowTensorEliminater : public AnfVisitor {
  public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
-    AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
+    AnfVisitor::Match(prim::kPrimRowTensorGetIndices, {IsCNode})(node);
    if (is_match_) {
      return tuple_->input(1);
    }
-    AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
+    AnfVisitor::Match(prim::kPrimRowTensorGetValues, {IsCNode})(node);
    if (is_match_) {
      return tuple_->input(2);
    }
-    AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
+    AnfVisitor::Match(prim::kPrimRowTensorGetDenseShape, {IsCNode})(node);
    if (is_match_) {
      return tuple_->input(3);
@@ -54,7 +54,7 @@ class IndexedSlicesEliminater : public AnfVisitor {
  }
  void Visit(const CNodePtr &cnode) override {
-    if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeRowTensor)) {
      tuple_ = cnode;
      is_match_ = true;
    }
@@ -72,4 +72,4 @@ class IndexedSlicesEliminater : public AnfVisitor {
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ROW_TENSOR_ELIMINATE_H_
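Read from the Python side, this pass means that a RowTensor which is built and immediately unpacked inside a compiled graph leaves no MakeRowTensor node behind: each getter collapses to the corresponding input (indices, values, or dense_shape). A minimal illustrative cell that the pass reduces to an identity on `indices` (this sketch is not part of the commit):

    import mindspore.nn as nn
    from mindspore import RowTensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class TakeIndices(nn.Cell):
        """RowTensorGetIndices(MakeRowTensor(indices, values, dense_shape)) folds to `indices`."""
        def __init__(self, dense_shape):
            super(TakeIndices, self).__init__()
            self.dense_shape = dense_shape

        def construct(self, indices, values):
            return RowTensor(indices, values, self.dense_shape).indices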
@@ -170,7 +170,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.replace_refkey_by_param_,
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
-    irpass.indexed_slices_eliminate_,
+    irpass.row_tensor_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},
...
@@ -132,11 +132,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimControlDepend, {InferImplControlDepend, true}},
     // Debug
     {prim::kPrimDebug, {InferImplDebug, true}},
-    // IndexedSlices
-    {prim::kPrimMakeIndexedSlices, {InferImplMakeIndexedSlices, true}},
-    {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
-    {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
-    {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
+    // RowTensor
+    {prim::kPrimMakeRowTensor, {InferImplMakeRowTensor, true}},
+    {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}},
+    {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}},
+    {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}},
     // SparseTensor
     {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
     {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
@@ -402,8 +402,8 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     }
     dic["dtype"] = arg_tensor->BuildType();
     dic["value"] = BuildValue(arg_tensor->BuildValue());
-  } else if (abs_base->isa<AbstractIndexedSlices>()) {
-    auto arg = dyn_cast<AbstractIndexedSlices>(abs_base);
+  } else if (abs_base->isa<AbstractRowTensor>()) {
+    auto arg = dyn_cast<AbstractRowTensor>(abs_base);
     dic["shape"] = arg->shape()->shape();
     dic["dtype"] = arg->BuildType();
     dic["value"] = BuildValue(arg->BuildValue());
...
@@ -348,14 +348,14 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
 AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                           const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                                const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                                 const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                                    const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                       const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplRowTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                            const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                           const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
...
@@ -32,9 +32,9 @@ using mindspore::abstract::AbstractBase;
 using mindspore::abstract::AbstractClass;
 using mindspore::abstract::AbstractError;
 using mindspore::abstract::AbstractFunction;
-using mindspore::abstract::AbstractIndexedSlices;
 using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
+using mindspore::abstract::AbstractRowTensor;
 using mindspore::abstract::AbstractScalar;
 using mindspore::abstract::AbstractSparseTensor;
 using mindspore::abstract::AbstractTensor;
@@ -95,7 +95,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
   }
   if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
-      ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
+      ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractRowTensor>() ||
       ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
     return;
   }
...
@@ -136,8 +136,7 @@ REGISTER_PYBIND_DEFINE(
                            TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
                            return data;
                          }));
-  (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
-    .def(py::init());
+  (void)py::class_<RowTensorType, Type, std::shared_ptr<RowTensorType>>(m_sub, "RowTensorType").def(py::init());
   (void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
     .def(py::init());
   (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
...
@@ -17,10 +17,10 @@ from . import dtype
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
+from .tensor import MetaTensor, Tensor, RowTensor, SparseTensor
 __all__ = [
-    "MetaTensor", "Tensor", "IndexedSlices", "SparseTensor",  # tensor
+    "MetaTensor", "Tensor", "RowTensor", "SparseTensor",  # tensor
     'ms_function',  # api
     'Parameter', 'ParameterTuple',  # parameter
     "dtype"
...
@@ -99,7 +99,7 @@ slice_type = typing.Slice
 ellipsis_type = typing.TypeEllipsis
 list_type = typing.List
 tuple_type = typing.Tuple
-index_slices = typing.IndexedSlicesType()
+index_slices = typing.RowTensorType()
 sparse_tensor = typing.SparseTensorType()
 undetermined = typing.UndeterminedType()
...
@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
-__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
+__all__ = ['Tensor', 'MetaTensor', 'RowTensor', 'SparseTensor']
 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)
@@ -267,20 +267,20 @@ class Tensor(Tensor_):
         return tensor_operator_registry.get('any')(keep_dims)(self, axis)
-class IndexedSlices:
+class RowTensor:
     """
     A sparse representation of a set of tensor slices at given indices.
-    An IndexedSlices is typically used to represent a subset of a larger
+    An RowTensor is typically used to represent a subset of a larger
     tensor dense of shape [L0, D1, .. , DN] where L0 >> D0.
     The values in indices are the indices in the first dimension of the slices
     that have been extracted from the larger tensor.
-    The dense tensor dense represented by an IndexedSlices slices has
+    The dense tensor dense represented by an RowTensor slices has
    `dense[slices.indices[i], :, :, :, ...] = slices.values[i, :, :, :, ...]`.
-    IndexedSlices can only be used in the `Cell`'s construct method.
+    RowTensor can only be used in the `Cell`'s contruct method.
     It is not supported in pynative mode at the moment.
@@ -291,7 +291,7 @@ class IndexedSlices:
             of the corresponding dense tensor.
     Returns:
-        IndexedSlices, composed of `indices`, `values`, and `dense_shape`.
+        RowTensor, composed of `indices`, `values`, and `dense_shape`.
     Examples:
         >>> class Net(nn.Cell):
@@ -299,8 +299,8 @@ class IndexedSlices:
        >>>         super(Net, self).__init__()
        >>>         self.dense_shape = dense_shape
        >>>     def construct(self, indices, values):
-        >>>         x = IndexedSlices(indices, values, self.dense_shape)
-        >>>         return x.values(), x.indices(), x.dense_shape()
+        >>>         x = RowTensor(indices, values, self.dense_shape)
+        >>>         return x.values, x.indices, x.dense_shape
        >>>
        >>> indices = Tensor([0])
        >>> values = Tensor([[1, 2]], dtype=ms.float32)
@@ -308,17 +308,20 @@ class IndexedSlices:
     """
     def __init__(self, indices, values, dense_shape):
-        "Init IndexedSlices"
+        "Init RowTensor"
         self.__indices = indices
         self.__values = values
         self.__dense_shape = dense_shape
+    @property
     def indices(self):
         return self.__indices
+    @property
     def values(self):
         return self.__values
+    @property
     def dense_shape(self):
         return self.__dense_shape
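Taken together, the changes in this class move user code from method-style accessors (`x.values()`) to plain properties (`x.values`). A minimal graph-mode sketch against the new API, following the docstring example above (the explicit dtypes added here are an assumption, not part of the diff):

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import RowTensor, Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self, dense_shape):
            super(Net, self).__init__()
            self.dense_shape = dense_shape  # full shape of the dense tensor, e.g. (3, 2)

        def construct(self, indices, values):
            x = RowTensor(indices, values, self.dense_shape)
            # before this commit: x.values(), x.indices(), x.dense_shape()
            return x.values, x.indices, x.dense_shape

    indices = Tensor([0], dtype=ms.int32)        # rows of the dense tensor that are present
    values = Tensor([[1, 2]], dtype=ms.float32)  # one slice per index, shape (1, 2)
    out = Net((3, 2))(indices, values)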
@@ -353,7 +356,7 @@ class SparseTensor:
        >>>         self.dense_shape = dense_shape
        >>>     def construct(self, indices, values):
        >>>         x = SparseTensor(indices, values, self.dense_shape)
-        >>>         return x.values(), x.indices(), x.dense_shape()
+        >>>         return x.values, x.indices, x.dense_shape
        >>>
        >>> indices = Tensor([[0, 1], [1, 2]])
        >>> values = Tensor([1, 2], dtype=ms.float32)
@@ -366,11 +369,14 @@
         self.__values = values
         self.__dense_shape = dense_shape
+    @property
     def indices(self):
         return self.__indices
+    @property
     def values(self):
         return self.__values
+    @property
     def dense_shape(self):
         return self.__dense_shape
@@ -1050,16 +1050,16 @@ bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const
   return AbstractBasePtrListDeepEqual(lhs, rhs);
 }
-// IndexedSlices
-TypePtr AbstractIndexedSlices::BuildType() const {
+// RowTensor
+TypePtr AbstractRowTensor::BuildType() const {
   MS_EXCEPTION_IF_NULL(element());
   TypePtr element_type = element()->BuildType();
-  return std::make_shared<IndexedSlicesType>(element_type);
+  return std::make_shared<RowTensorType>(element_type);
 }
-AbstractBasePtr AbstractIndexedSlices::Clone() const {
+AbstractBasePtr AbstractRowTensor::Clone() const {
   MS_EXCEPTION_IF_NULL(element());
-  auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
+  auto clone = std::make_shared<AbstractRowTensor>(element()->Clone());
   ShapePtr shp = shape();
   clone->set_shape(shp->Clone());
   clone->set_value(GetValueTrack());
@@ -1069,9 +1069,9 @@ AbstractBasePtr AbstractIndexedSlices::Clone() const {
   return clone;
 }
-AbstractBasePtr AbstractIndexedSlices::Broaden() const {
+AbstractBasePtr AbstractRowTensor::Broaden() const {
   MS_EXCEPTION_IF_NULL(element());
-  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
   auto shp = shape();
   broaden->set_shape(shp->Clone());
   broaden->set_value(kAnyValue);
@@ -1081,9 +1081,9 @@ AbstractBasePtr AbstractIndexedSlices::Broaden() const {
   return broaden;
 }
-AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
+AbstractBasePtr AbstractRowTensor::BroadenWithShape() const {
   MS_EXCEPTION_IF_NULL(element());
-  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
+  auto broaden = std::make_shared<AbstractRowTensor>(element()->Broaden());
   auto shp = shape()->Clone();
   shp->Broaden();
   broaden->set_shape(shp);
@@ -1094,7 +1094,7 @@ AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
   return broaden;
 }
-std::string AbstractIndexedSlices::ToString() const {
+std::string AbstractRowTensor::ToString() const {
   std::ostringstream buffer;
   BaseShapePtr shape_track = GetShapeTrack();
   MS_EXCEPTION_IF_NULL(shape_track);
...
@@ -593,15 +593,15 @@ struct AbstractBasePtrListEqual {
 std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
 bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
-// IndexedSlices
-class AbstractIndexedSlices : public AbstractUndetermined {
+// RowTensor
+class AbstractRowTensor : public AbstractUndetermined {
  public:
-  explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+  explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
       : AbstractUndetermined(element, shape) {}
-  AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape)
+  AbstractRowTensor(const TypePtr &element_type, const std::vector<int> &shape)
       : AbstractUndetermined(element_type, shape) {}
-  ~AbstractIndexedSlices() override = default;
-  MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
+  ~AbstractRowTensor() override = default;
+  MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined)
   const AbstractTensorPtr indices() const { return indices_; }
   void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
...
@@ -66,7 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
 ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
-ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
+ABSTRACT_REPORT_NAME_TRAITS(RowTensor)
 ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
 ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
...
@@ -179,40 +179,40 @@ bool TensorType::operator==(const Type &other) const {
   return *element_type_ == *other_elem_type;
 }
-TypePtr IndexedSlicesType::DeepCopy() const {
+TypePtr RowTensorType::DeepCopy() const {
   MS_EXCEPTION_IF_NULL(element_type_);
   if (IsGeneric()) {
-    return std::make_shared<IndexedSlicesType>();
+    return std::make_shared<RowTensorType>();
   }
-  return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
+  return std::make_shared<RowTensorType>(element_type_->DeepCopy());
 }
-std::string IndexedSlicesType::ToReprString() const {
+std::string RowTensorType::ToReprString() const {
   if (element_type_ == nullptr) {
-    return "IndexedSlices";
+    return "RowTensor";
   }
-  return "IndexedSlices[" + element_type_->ToReprString() + "]";
+  return "RowTensor[" + element_type_->ToReprString() + "]";
 }
-std::string IndexedSlicesType::ToString() const {
+std::string RowTensorType::ToString() const {
   if (element_type_ == nullptr) {
-    return "IndexedSlices";
+    return "RowTensor";
   }
-  return "IndexedSlices[" + element_type_->ToString() + "]";
+  return "RowTensor[" + element_type_->ToString() + "]";
 }
-std::string IndexedSlicesType::DumpText() const {
+std::string RowTensorType::DumpText() const {
   if (element_type_ == nullptr) {
-    return "IndexedSlices";
+    return "RowTensor";
   }
-  return "IndexedSlices[" + element_type_->DumpText() + "]";
+  return "RowTensor[" + element_type_->DumpText() + "]";
 }
-bool IndexedSlicesType::operator==(const Type &other) const {
+bool RowTensorType::operator==(const Type &other) const {
   if (!IsSameObjectType(*this, other)) {
     return false;
   }
-  auto other_elem_type = static_cast<const IndexedSlicesType &>(other).element_type_;
+  auto other_elem_type = static_cast<const RowTensorType &>(other).element_type_;
   if (element_type_ == nullptr && other_elem_type == nullptr) {
     return true;
   } else if (element_type_ == nullptr || other_elem_type == nullptr) {
...
@@ -154,15 +154,15 @@ class TensorType : public Object {
 };
 using TensorTypePtr = std::shared_ptr<TensorType>;
-class IndexedSlicesType : public Object {
+class RowTensorType : public Object {
  public:
-  IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
-  explicit IndexedSlicesType(const TypePtr &ele)
-      : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
-  ~IndexedSlicesType() override = default;
-  MS_DECLARE_PARENT(IndexedSlicesType, Object)
-  TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
+  RowTensorType() : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType) {}
+  explicit RowTensorType(const TypePtr &ele)
+      : Object(kObjectTypeRowTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
+  ~RowTensorType() override = default;
+  MS_DECLARE_PARENT(RowTensorType, Object)
+  TypeId generic_type_id() const override { return kObjectTypeRowTensorType; }
   const TypePtr element() const { return element_type_; }
   void set_element(const TypePtr &element_type) { element_type_ = element_type; }
@@ -175,7 +175,7 @@ class IndexedSlicesType : public Object {
  private:
   TypePtr element_type_;
 };
-using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
+using RowTensorTypePtr = std::shared_ptr<RowTensorType>;
 class SparseTensorType : public Object {
  public:
...
@@ -115,8 +115,8 @@ const char *ObjectIdLabel(const TypeId &v) {
       return "kObjectTypeKeyword";
     case kObjectTypeTensorType:
       return "kObjectTypeTensorType";
-    case kObjectTypeIndexedSlicesType:
-      return "kObjectTypeIndexedSlicesType";
+    case kObjectTypeRowTensorType:
+      return "kObjectTypeRowTensorType";
     case kObjectTypeSparseTensorType:
       return "kObjectTypeSparseTensorType";
     case kObjectTypeUndeterminedType:
...
@@ -50,7 +50,7 @@ enum TypeId : int {
   kObjectTypeSlice,
   kObjectTypeKeyword,
   kObjectTypeTensorType,
-  kObjectTypeIndexedSlicesType,
+  kObjectTypeRowTensorType,
   kObjectTypeSparseTensorType,
   kObjectTypeUndeterminedType,
   kObjectTypeClass,
...
@@ -190,9 +190,9 @@ TypePtr TensorStrToType(const std::string &type_name) {
   return type;
 }
-TypePtr IndexedSlicesStrToType(const std::string &type_name) {
-  if (type_name == "IndexedSlices") {
-    return std::make_shared<IndexedSlicesType>();
+TypePtr RowTensorStrToType(const std::string &type_name) {
+  if (type_name == "RowTensor") {
+    return std::make_shared<RowTensorType>();
   }
   auto start = type_name.find_first_of('[') + 1;
   auto end = type_name.find_last_of(']');
@@ -204,7 +204,7 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
   if (element_type == nullptr) {
     return nullptr;
   }
-  return std::make_shared<IndexedSlicesType>(element_type);
+  return std::make_shared<RowTensorType>(element_type);
 }
 TypePtr SparseTensorStrToType(const std::string &type_name) {
@@ -364,8 +364,8 @@ TypePtr StringToType(const std::string &type_name) {
     type = TensorStrToType(type_name);
   } else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
     type = UndeterminedStrToType(type_name);
-  } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
-    type = IndexedSlicesStrToType(type_name);
+  } else if (type_name.compare(0, strlen("RowTensor"), "RowTensor") == 0) {
+    type = RowTensorStrToType(type_name);
   } else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
     type = SparseTensorStrToType(type_name);
   } else if (type_name.compare(0, strlen("List"), "List") == 0) {
@@ -446,7 +446,7 @@ const TypePtr kTypeExternal = std::make_shared<External>();
 const TypePtr kTypeEnv = std::make_shared<EnvType>();
 const TypePtr kTypeType = std::make_shared<TypeType>();
 const TypePtr kTensorType = std::make_shared<TensorType>();
-const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
+const TypePtr kRowTensorType = std::make_shared<RowTensorType>();
 const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
 const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
 const TypePtr kString = std::make_shared<String>();
...
@@ -85,13 +85,13 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d
 @_adam_opt.register("Function", "Function", "Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
-                    "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor", "Bool")
+                    "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
 def _run_opt_with_sparse(opt, sparse_opt, push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr,
                          gradient, params, moment1, moment2, ps_parameter):
     """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
-    indices = gradient.indices()
-    values = gradient.values()
+    indices = gradient.indices
+    values = gradient.values
     if ps_parameter:
         op_shape = P.Shape()
         shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
...
@@ -24,13 +24,13 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
 @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor",
-                    "IndexedSlices", "Tensor", "Tensor", "Bool")
+                    "RowTensor", "Tensor", "Tensor", "Bool")
 def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear,
                                 gradient, weight, moment, ps_parameter):
     """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
     success = True
-    indices = gradient.indices()
-    values = gradient.values()
+    indices = gradient.indices
+    values = gradient.values
     if ps_parameter:
         op_shape = P.Shape()
         shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
...
@@ -28,13 +28,13 @@ _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
 @_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor",
-                         "IndexedSlices", "Tensor", "Tensor", "Tensor")
+                         "RowTensor", "Tensor", "Tensor", "Tensor")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                          moment1, moment2):
     """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
     success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient.values(), gradient.indices()))
+                                           eps, gradient.values, gradient.indices))
     return success
...
@@ -23,7 +23,7 @@ from mindspore.nn.cell import Cell
 from mindspore.nn.layer.container import CellList
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
-from mindspore.common.tensor import Tensor, IndexedSlices
+from mindspore.common.tensor import Tensor, RowTensor
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
@@ -493,14 +493,14 @@ op_gather = P.GatherV2()
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
-@_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices")
+@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
 def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        indices = gradient.indices()
-        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values()))
-        shape = gradient.dense_shape()
-        return IndexedSlices(indices, values, shape)
+        indices = gradient.indices
+        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values))
+        shape = gradient.dense_shape
+        return RowTensor(indices, values, shape)
     return gradient
@@ -523,12 +523,12 @@ def tensor_grad_scale(scale, grad):
     return grad * scale
-@_grad_scale.register("Number", "IndexedSlices")
+@_grad_scale.register("Number", "RowTensor")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
-    return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())
+    return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
 class _ConvertToCell(LearningRateSchedule):
...
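The same change to the `.register` type string ("IndexedSlices" becomes "RowTensor") and to the accessors recurs in every sparse branch above and in the optimizer files that follow. A minimal sketch of declaring such an overloaded branch against the new name (the graph and function names below are illustrative, not taken from this file):

    from mindspore import RowTensor
    from mindspore.ops import composite as C

    _scale_grad = C.MultitypeFuncGraph("scale_grad")

    @_scale_grad.register("Number", "Tensor")
    def _scale_dense(scale, grad):
        """Dense branch: an ordinary tensor gradient is scaled directly."""
        return grad * scale

    @_scale_grad.register("Number", "RowTensor")
    def _scale_sparse(scale, grad):
        """Sparse branch: rebuild the RowTensor with scaled values; indices and dense_shape pass through."""
        return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)

Inside a compiled construct the graph picks the branch from the runtime type of the gradient, so the same call site handles both dense Tensor and RowTensor gradients.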
@@ -22,12 +22,12 @@ from .optimizer import Optimizer
 _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
-@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
+@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor",
                                  "Tensor")
 def _tensor_run_opt_with_sparse(opt, sparse_opt, l1, l2, learning_rate, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
-    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
+    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values, gradient.indices))
     return success
...
@@ -49,6 +49,6 @@ class SparseToDense(Cell):
         self.sparse_to_dense = P.SparseToDense()
     def construct(self, sparse_tensor):
-        return self.sparse_to_dense(sparse_tensor.indices(),
-                                    sparse_tensor.values(),
-                                    sparse_tensor.dense_shape())
+        return self.sparse_to_dense(sparse_tensor.indices,
+                                    sparse_tensor.values,
+                                    sparse_tensor.dense_shape)
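A usage sketch for the updated cell, with the SparseTensor now read through properties; the indices and values follow the SparseTensor docstring example above, and exposing the cell as `nn.SparseToDense` is an assumption of this sketch:

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import SparseTensor, Tensor, context

    context.set_context(mode=context.GRAPH_MODE)

    class Net(nn.Cell):
        def __init__(self, dense_shape):
            super(Net, self).__init__()
            self.dense_shape = dense_shape
            self.sparse_to_dense = nn.SparseToDense()  # assumed export path for the cell shown above

        def construct(self, indices, values):
            sparse = SparseTensor(indices, values, self.dense_shape)
            return self.sparse_to_dense(sparse)

    indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)  # positions of the non-zero entries
    values = Tensor([1, 2], dtype=ms.float32)           # one value per index pair
    dense = Net((3, 4))(indices, values)                # 3 x 4 dense tensor with two non-zeros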
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from mindspore import context from mindspore import context
from mindspore.nn.cell import Cell from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.common.tensor import IndexedSlices from mindspore.common.tensor import RowTensor
from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, AllGather from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._auto_parallel_context import auto_parallel_context
...@@ -103,7 +103,7 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter, ...@@ -103,7 +103,7 @@ def _tensors_allreduce_ps(degree, mean, allgather, allreduce, allreduce_filter,
return grad return grad
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices") @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad): def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce_filter, grad):
""" """
Apply allgather on gradient instead of allreduce for sparse feature. Apply allgather on gradient instead of allreduce for sparse feature.
...@@ -118,21 +118,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce ...@@ -118,21 +118,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce, allreduce
grad (tuple): The indices, gradient tensor and tensor_shape before operation. grad (tuple): The indices, gradient tensor and tensor_shape before operation.
Returns: Returns:
IndexedSlices, the gradient after operation. RowTensor, the gradient after operation.
""" """
if allreduce_filter: if allreduce_filter:
indices = allgather(grad.indices()) indices = allgather(grad.indices)
dout = allgather(grad.values()) dout = allgather(grad.values)
if mean: if mean:
degree = F.scalar_cast(degree, F.dtype(grad.values())) degree = F.scalar_cast(degree, F.dtype(grad.values))
cast_op = P.Cast() cast_op = P.Cast()
mul_op = P.Mul() mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
grad = IndexedSlices(indices, dout, grad.dense_shape()) grad = RowTensor(indices, dout, grad.dense_shape)
return grad return grad
@reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "IndexedSlices", "Bool") @reduce_opt.register("Number", "Bool", "Function", "Function", "Bool", "RowTensor", "Bool")
def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter): def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allreduce_filter, grad, ps_parameter):
""" """
Apply allgather on gradient instead of allreduce for sparse feature. Apply allgather on gradient instead of allreduce for sparse feature.
...@@ -148,20 +148,20 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred ...@@ -148,20 +148,20 @@ def _tensors_allreduce_with_sparse_ps(degree, mean, allgather, allreduce, allred
ps_parameter (bool): Use parameter server or not. ps_parameter (bool): Use parameter server or not.
Returns: Returns:
IndexedSlices, the gradient after operation. RowTensor, the gradient after operation.
""" """
if ps_parameter: if ps_parameter:
return grad return grad
if allreduce_filter: if allreduce_filter:
indices = allgather(grad.indices()) indices = allgather(grad.indices)
dout = allgather(grad.values()) dout = allgather(grad.values)
if mean: if mean:
degree = F.scalar_cast(degree, F.dtype(grad.values())) degree = F.scalar_cast(degree, F.dtype(grad.values))
cast_op = P.Cast() cast_op = P.Cast()
mul_op = P.Mul() mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout))) dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
grad = IndexedSlices(indices, dout, grad.dense_shape()) grad = RowTensor(indices, dout, grad.dense_shape)
return grad return grad
...@@ -182,18 +182,18 @@ def _tensors_get_datatype(grad): ...@@ -182,18 +182,18 @@ def _tensors_get_datatype(grad):
return F.dtype(grad) return F.dtype(grad)
@_get_datatype.register("IndexedSlices") @_get_datatype.register("RowTensor")
def _tensors_get_datatype_with_sparse(grad): def _tensors_get_datatype_with_sparse(grad):
""" """
Acquire gradient datatype. Acquire gradient datatype.
Args: Args:
grad (IndexedSlices): The gradient before operation. grad (RowTensor): The gradient before operation.
Returns: Returns:
mstype, the datatype of gradient. mstype, the datatype of gradient.
""" """
return F.dtype(grad.values()) return F.dtype(grad.values)
_cast_datatype = C.MultitypeFuncGraph("_cast_datatype") _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")
...@@ -214,20 +214,20 @@ def _tensors_cast_datatype(datatype, grad): ...@@ -214,20 +214,20 @@ def _tensors_cast_datatype(datatype, grad):
return F.cast(grad, datatype) return F.cast(grad, datatype)
@_cast_datatype.register("TypeType", "IndexedSlices") @_cast_datatype.register("TypeType", "RowTensor")
def _tensors_cast_datatype_with_sparse(datatype, grad): def _tensors_cast_datatype_with_sparse(datatype, grad):
""" """
Cast gradient to datatype. Cast gradient to datatype.
Args: Args:
datatype (mstype): the destination datatype of gradient. datatype (mstype): the destination datatype of gradient.
grad (IndexedSlices): The gradient before operation. grad (RowTensor): The gradient before operation.
Returns: Returns:
IndexedSlices, the gradient after operation. RowTensor, the gradient after operation.
""" """
dout = F.cast(grad.values(), datatype) dout = F.cast(grad.values, datatype)
return IndexedSlices(grad.indices(), dout, grad.dense_shape()) return RowTensor(grad.indices, dout, grad.dense_shape)
class DistributedGradReducer(Cell): class DistributedGradReducer(Cell):
......
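For context, the overloads above gather a RowTensor gradient's indices and values across devices (instead of allreducing a dense tensor) and, when mean is enabled, rescale the gathered values by 1/degree. A minimal sketch of that logic outside the MultitypeFuncGraph machinery; the allgather operator, degree, and mean flag are assumed to come from the surrounding DistributedGradReducer setup, and the helper name is illustrative:

    from mindspore import RowTensor
    from mindspore.ops import functional as F
    from mindspore.ops import operations as P

    def _allgather_row_tensor_grad(allgather, degree, mean, grad):
        # Gather the local slices from every device.
        indices = allgather(grad.indices)
        values = allgather(grad.values)
        if mean:
            # Rescale so the result matches a dense allreduce-with-mean.
            scale = P.Cast()(F.scalar_to_array(1.0 / degree), F.dtype(values))
            values = P.Mul()(values, scale)
        return RowTensor(indices, values, grad.dense_shape)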
...@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer ...@@ -18,7 +18,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
from ..cell import Cell from ..cell import Cell
from ...common import Tensor, IndexedSlices from ...common import Tensor, RowTensor
from ...common.parameter import Parameter from ...common.parameter import Parameter
from ...ops import functional as F from ...ops import functional as F
from ...ops import composite as C from ...ops import composite as C
...@@ -35,11 +35,11 @@ reciprocal = P.Reciprocal() ...@@ -35,11 +35,11 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad): def tensor_grad_scale(scale, grad):
return grad * F.cast(reciprocal(scale), F.dtype(grad)) return grad * F.cast(reciprocal(scale), F.dtype(grad))
@_grad_scale.register("Tensor", "IndexedSlices") @_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_indexed_slices(scale, grad): def tensor_grad_scale_row_tensor(scale, grad):
return IndexedSlices(grad.indices(), return RowTensor(grad.indices,
grad.values() * F.cast(reciprocal(scale), F.dtype(grad.values())), grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
grad.dense_shape()) grad.dense_shape)
_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") _grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus() grad_overflow = P.FloatStatus()
......
...@@ -27,7 +27,7 @@ from .grad_base import bprop_getters ...@@ -27,7 +27,7 @@ from .grad_base import bprop_getters
from ..primitive import constexpr from ..primitive import constexpr
from ... import context from ... import context
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import IndexedSlices from ...common.tensor import RowTensor
reduce_sum = P.ReduceSum() reduce_sum = P.ReduceSum()
unsorted_segment_sum = P.UnsortedSegmentSum() unsorted_segment_sum = P.UnsortedSegmentSum()
...@@ -75,12 +75,12 @@ def dout_cast_number(dout, x): ...@@ -75,12 +75,12 @@ def dout_cast_number(dout, x):
dx = cast(dout, get_dtype(x)) dx = cast(dout, get_dtype(x))
return dx return dx
@dout_cast.register("IndexedSlices", "Tensor") @dout_cast.register("RowTensor", "Tensor")
def dout_cast_indexed_slices(dout, x): def dout_cast_row_tensor(dout, x):
cast = P.Cast() cast = P.Cast()
get_dtype = P.DType() get_dtype = P.DType()
values = cast(dout.values(), get_dtype(x)) values = cast(dout.values, get_dtype(x))
return IndexedSlices(dout.indices(), values, dout.dense_shape()) return RowTensor(dout.indices, values, dout.dense_shape)
@bprop_getters.register(P.Cast) @bprop_getters.register(P.Cast)
...@@ -240,7 +240,7 @@ def get_bprop_embedding_lookup(self): ...@@ -240,7 +240,7 @@ def get_bprop_embedding_lookup(self):
actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
# Reshape the 'actual_dout' on device # Reshape the 'actual_dout' on device
actual_dout = reshape_op(dout, actual_dout_shape_changed) actual_dout = reshape_op(dout, actual_dout_shape_changed)
return IndexedSlices(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) return RowTensor(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
return bprop_sparse return bprop_sparse
...@@ -369,7 +369,7 @@ def get_bprop_sparse_gather_v2(self): ...@@ -369,7 +369,7 @@ def get_bprop_sparse_gather_v2(self):
values_shape = indices_size + x_tail_shp values_shape = indices_size + x_tail_shp
values = reshape(dout, values_shape) values = reshape(dout, values_shape)
indices = reshape(indices, indices_size) indices = reshape(indices, indices_size)
return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis) return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
if F.rank(dout) == 0: if F.rank(dout) == 0:
dout = P.ExpandDims()(dout, -1) dout = P.ExpandDims()(dout, -1)
if F.rank(indices) == 0: if F.rank(indices) == 0:
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from .. import operations as P from .. import operations as P
from ...common.tensor import IndexedSlices from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp, _GetTensorSlice, _MirrorOperator, ReduceOp,
...@@ -47,9 +47,9 @@ def get_bprop_all_reduce(self): ...@@ -47,9 +47,9 @@ def get_bprop_all_reduce(self):
if F.issubclass_(F.typeof(dout), mstype.tensor): if F.issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce_grad(dout) dx = all_reduce_grad(dout)
else: else:
indices = all_gather(dout.indices()) indices = all_gather(dout.indices)
grad = all_gather(dout.values()) grad = all_gather(dout.values)
dx = IndexedSlices(indices, grad, dout.dense_shape()) dx = RowTensor(indices, grad, dout.dense_shape)
return (dx,) return (dx,)
else: else:
...@@ -60,12 +60,12 @@ def get_bprop_all_reduce(self): ...@@ -60,12 +60,12 @@ def get_bprop_all_reduce(self):
z = cast(z, dtype(dx)) z = cast(z, dtype(dx))
dx = mul(dx, z) dx = mul(dx, z)
else: else:
indices = all_gather(dout.indices()) indices = all_gather(dout.indices)
grad = all_gather(dout.values()) grad = all_gather(dout.values)
z = equal(x, out) z = equal(x, out)
z = cast(z, dtype(grad)) z = cast(z, dtype(grad))
grad = mul(grad, z) grad = mul(grad, z)
dx = IndexedSlices(indices, grad, dout.dense_shape()) dx = RowTensor(indices, grad, dout.dense_shape)
return (dx,) return (dx,)
return bprop return bprop
...@@ -195,19 +195,19 @@ def get_bprop_mirror_operator(self): ...@@ -195,19 +195,19 @@ def get_bprop_mirror_operator(self):
num = F.scalar_cast(dev_num, F.dtype(dx)) num = F.scalar_cast(dev_num, F.dtype(dx))
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx))) dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
else: else:
indices = all_gather(dout.indices()) indices = all_gather(dout.indices)
grad = all_gather(dout.values()) grad = all_gather(dout.values)
float_one = F.scalar_cast(1.0, F.dtype(grad)) float_one = F.scalar_cast(1.0, F.dtype(grad))
num = F.scalar_cast(dev_num, F.dtype(grad)) num = F.scalar_cast(dev_num, F.dtype(grad))
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad))) grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
dx = IndexedSlices(indices, grad, dout.dense_shape()) dx = RowTensor(indices, grad, dout.dense_shape)
else: else:
if F.issubclass_(F.typeof(dout), mstype.tensor): if F.issubclass_(F.typeof(dout), mstype.tensor):
dx = all_reduce(dout) dx = all_reduce(dout)
else: else:
indices = all_gather(dout.indices()) indices = all_gather(dout.indices)
grad = all_gather(dout.values()) grad = all_gather(dout.values)
dx = IndexedSlices(indices, grad, dout.dense_shape()) dx = RowTensor(indices, grad, dout.dense_shape)
return (dx,) return (dx,)
return bprop return bprop
......
...@@ -152,10 +152,10 @@ shape_mul = Primitive("shape_mul") ...@@ -152,10 +152,10 @@ shape_mul = Primitive("shape_mul")
# a primitive to compare between tuple. # a primitive to compare between tuple.
stop_gradient = Primitive("stop_gradient") stop_gradient = Primitive("stop_gradient")
make_indexed_slices = Primitive('MakeIndexedSlices') make_row_tensor = Primitive('MakeRowTensor')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues') row_tensor_get_values = Primitive('RowTensorGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') row_tensor_get_indices = Primitive('RowTensorGetIndices')
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
make_sparse_tensor = Primitive('MakeSparseTensor') make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues') sparse_tensor_get_values = Primitive('SparseTensorGetValues')
......
...@@ -389,8 +389,8 @@ class CheckBprop(PrimitiveWithInfer): ...@@ -389,8 +389,8 @@ class CheckBprop(PrimitiveWithInfer):
validator.check_value_type('grads', xshapes, (tuple,), tips) validator.check_value_type('grads', xshapes, (tuple,), tips)
validator.check_value_type('params', yshapes, (tuple,), tips) validator.check_value_type('params', yshapes, (tuple,), tips)
if len(xshapes) < len(yshapes): if len(xshapes) < len(yshapes):
raise TypeError(f"{tips}, the size of output should be {len(yshapes)}," raise ValueError(f"{tips}, the size of output should be {len(yshapes)},"
f" but got {len(xshapes)}.") f" but got {len(xshapes)}.")
checking_range = len(yshapes) checking_range = len(yshapes)
for i in range(checking_range): for i in range(checking_range):
xshape = xshapes[i] xshape = xshapes[i]
...@@ -398,8 +398,8 @@ class CheckBprop(PrimitiveWithInfer): ...@@ -398,8 +398,8 @@ class CheckBprop(PrimitiveWithInfer):
if not xshape or not yshape: if not xshape or not yshape:
continue continue
if xshape != yshape: if xshape != yshape:
raise TypeError(f"{tips}, the shape of {i}th output should be {yshape}," raise ValueError(f"{tips}, the shape of {i}th output should be {yshape},"
f" but got {xshape}.") f" but got {xshape}.")
return xshapes return xshapes
def infer_dtype(self, xdtypes, ydtypes): def infer_dtype(self, xdtypes, ydtypes):
...@@ -407,8 +407,8 @@ class CheckBprop(PrimitiveWithInfer): ...@@ -407,8 +407,8 @@ class CheckBprop(PrimitiveWithInfer):
validator.check_value_type('grads', xdtypes, (tuple,), tips) validator.check_value_type('grads', xdtypes, (tuple,), tips)
validator.check_value_type('params', ydtypes, (tuple,), tips) validator.check_value_type('params', ydtypes, (tuple,), tips)
if len(xdtypes) < len(ydtypes): if len(xdtypes) < len(ydtypes):
raise TypeError(f"{tips}, the size of output should be {len(ydtypes)}," raise ValueError(f"{tips}, the size of output should be {len(ydtypes)},"
f" but got {len(xdtypes)}.") f" but got {len(xdtypes)}.")
checking_range = len(ydtypes) checking_range = len(ydtypes)
for i in range(checking_range): for i in range(checking_range):
xdtype = xdtypes[i] xdtype = xdtypes[i]
......
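The CheckBprop changes above also make the contract for a hand-written bprop explicit: it must return a gradient for every construct input, each with the same shape and dtype as that input, and the size and shape mismatches are now reported with ValueError instead of TypeError. A minimal sketch of a bprop that satisfies the check when check_bprop is enabled, mirroring the MulAdd pattern used in the tests below (names are illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import context, Tensor
    from mindspore.ops import composite as C

    context.set_context(check_bprop=True)

    class MulAddChecked(nn.Cell):
        def construct(self, x, y):
            return 2 * x + y

        def bprop(self, x, y, out, dout):
            # One gradient per input, each with the same shape and dtype as that input.
            return 2 * dout, dout

    x = Tensor(np.ones([2, 2]).astype(np.float32))
    y = Tensor(np.ones([2, 2]).astype(np.float32))
    C.grad_all(MulAddChecked())(x, y)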
...@@ -19,25 +19,16 @@ import pytest ...@@ -19,25 +19,16 @@ import pytest
import mindspore as ms import mindspore as ms
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Parameter from mindspore import Parameter, ParameterTuple
from mindspore import context from mindspore import context
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops import operations as P from mindspore.ops import operations as P
from .....mindspore_test_framework.utils.bprop_util import bprop
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def setup_module(module):
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def teardown_module(module):
context.set_context(device_target="Ascend")
class MulAdd(nn.Cell): class MulAdd(nn.Cell):
def __init__(self):
super(MulAdd, self).__init__()
def construct(self, x, y): def construct(self, x, y):
return 2 * x + y return 2 * x + y
...@@ -45,7 +36,9 @@ class MulAdd(nn.Cell): ...@@ -45,7 +36,9 @@ class MulAdd(nn.Cell):
# In this test case, the user-defined bprop is deliberately wrong, to distinguish it from the AD result # In this test case, the user-defined bprop is deliberately wrong, to distinguish it from the AD result
return 2 * dout, 2 * y return 2 * dout, 2 * y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add(): def test_grad_mul_add():
mul_add = MulAdd() mul_add = MulAdd()
x = Tensor(1, dtype=ms.int32) x = Tensor(1, dtype=ms.int32)
...@@ -62,7 +55,9 @@ class InlineMulADD(nn.Cell): ...@@ -62,7 +55,9 @@ class InlineMulADD(nn.Cell):
def construct(self, x, y): def construct(self, x, y):
return self.mul_add(x, y) + x + self.param * y return self.mul_add(x, y) + x + self.param * y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_mul_add(): def test_grad_inline_mul_add():
inline_mul_add = InlineMulADD() inline_mul_add = InlineMulADD()
x = Tensor(1, dtype=ms.int32) x = Tensor(1, dtype=ms.int32)
...@@ -83,7 +78,9 @@ class WithParameter(nn.Cell): ...@@ -83,7 +78,9 @@ class WithParameter(nn.Cell):
# In this test case, the user-defined bprop is deliberately wrong, to distinguish it from the AD result # In this test case, the user-defined bprop is deliberately wrong, to distinguish it from the AD result
return self.param1 * self.param2 * dout, 2 * y return self.param1 * self.param2 * dout, 2 * y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param(): def test_with_param():
with_param = WithParameter() with_param = WithParameter()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
...@@ -91,20 +88,21 @@ def test_with_param(): ...@@ -91,20 +88,21 @@ def test_with_param():
class WithNoBprop(nn.Cell): class WithNoBprop(nn.Cell):
def __init__(self):
super(WithNoBprop, self).__init__()
def construct(self, x, y): def construct(self, x, y):
return 2 * x + y return 2 * x + y
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_no_bprop(): def test_with_no_bprop():
with_no_bprop = WithNoBprop() with_no_bprop = WithNoBprop()
x = Tensor(1, dtype=ms.int32) x = Tensor(1, dtype=ms.int32)
y = Tensor(2, dtype=ms.int32) y = Tensor(2, dtype=ms.int32)
C.grad_all(with_no_bprop)(x, y) assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_1(): def test_grad_in_bprop_1():
class GradInBprop_1(nn.Cell): class GradInBprop_1(nn.Cell):
def __init__(self): def __init__(self):
...@@ -140,7 +138,9 @@ def test_grad_in_bprop_1(): ...@@ -140,7 +138,9 @@ def test_grad_in_bprop_1():
assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all() assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_2(): def test_grad_in_bprop_2():
class GradInBprop_1(nn.Cell): class GradInBprop_1(nn.Cell):
def __init__(self): def __init__(self):
...@@ -179,7 +179,9 @@ def test_grad_in_bprop_2(): ...@@ -179,7 +179,9 @@ def test_grad_in_bprop_2():
assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all() assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all() assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_in_bprop_3(): def test_grad_in_bprop_3():
class GradInBprop_1(nn.Cell): class GradInBprop_1(nn.Cell):
def __init__(self): def __init__(self):
...@@ -230,7 +232,9 @@ class OneInputBprop(nn.Cell): ...@@ -230,7 +232,9 @@ class OneInputBprop(nn.Cell):
def bprop(self, x, out, dout): def bprop(self, x, out, dout):
return (5 * x,) return (5 * x,)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_one_input_bprop(): def test_grad_one_input_bprop():
net = OneInputBprop() net = OneInputBprop()
input1 = Tensor(np.ones([2, 2]).astype(np.float32)) input1 = Tensor(np.ones([2, 2]).astype(np.float32))
...@@ -239,9 +243,6 @@ def test_grad_one_input_bprop(): ...@@ -239,9 +243,6 @@ def test_grad_one_input_bprop():
class TwoInput(nn.Cell): class TwoInput(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y): def construct(self, x, y):
return x * y return x * y
...@@ -258,12 +259,17 @@ class InlineBpropTwoInput(nn.Cell): ...@@ -258,12 +259,17 @@ class InlineBpropTwoInput(nn.Cell):
grads = C.grad_all(self.f)(x, y) grads = C.grad_all(self.f)(x, y)
return grads[0] * 2, grads[1] * 2 return grads[0] * 2, grads[1] * 2
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_two_input(): def test_grad_inline_bprop_two_input():
net = InlineBpropTwoInput() net = InlineBpropTwoInput()
input1 = Tensor(np.ones([2, 2]).astype(np.float32)) input1 = Tensor(np.ones([2, 2]).astype(np.float32))
input2 = Tensor(np.ones([2, 2]).astype(np.float32)) input2 = Tensor(np.ones([2, 2]).astype(np.float32))
C.grad_all(net)(input1, input2) grads = C.grad_all(net)(input1, input2)
assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
assert len(grads) == 2
class TwoInputBprop(nn.Cell): class TwoInputBprop(nn.Cell):
...@@ -314,7 +320,9 @@ class InlineMutilTwoInputParameterCell(nn.Cell): ...@@ -314,7 +320,9 @@ class InlineMutilTwoInputParameterCell(nn.Cell):
output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y) output = self.f1(x, y) + self.f2(x, y) + self.f3(x, y) + self.f4(x, y)
return output return output
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_inline_bprop_multi_input(): def test_grad_inline_bprop_multi_input():
net = InlineMutilTwoInputParameterCell() net = InlineMutilTwoInputParameterCell()
input1 = Tensor(np.ones([2, 2]).astype(np.float32)) input1 = Tensor(np.ones([2, 2]).astype(np.float32))
...@@ -335,29 +343,54 @@ class MulAddWithParam(nn.Cell): ...@@ -335,29 +343,54 @@ class MulAddWithParam(nn.Cell):
def construct(self, x): def construct(self, x):
return self.mul_add(self.param, x) return self.mul_add(self.param, x)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_refkey_bprop(): def test_refkey_bprop():
net = MulAddWithParam() grad_by_list = C.GradOperation('get_by_list', get_all=True, get_by_list=True)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
self.weights = ParameterTuple(filter(lambda x: x.requires_grad, network.get_parameters()))
def construct(self, x):
weights = self.weights
grads = grad_by_list(self.network, weights)(x)
return grads
network = GradWrap(MulAddWithParam())
input_data = Tensor(np.array([2, 2], np.float32)) input_data = Tensor(np.array([2, 2], np.float32))
grads = bprop(net, input_data, grads = network(input_data)
grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))),
wrt=['params', 'inputs'],
params=net.trainable_params())
assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all() assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all() assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
class MulAddWithWrongOutputType(nn.Cell): class MulAddWithWrongOutputNum(nn.Cell):
def __init__(self): def construct(self, x, y):
super(MulAddWithWrongOutputType, self).__init__() return 2 * x + y
def bprop(self, x, y, out, dout):
return (2 * dout,)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_num():
context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputNum()
with pytest.raises(TypeError):
C.grad_all(mul_add)(1, 2)
class MulAddWithWrongOutputType(nn.Cell):
def construct(self, x, y): def construct(self, x, y):
return 2 * x + y return 2 * x + y
def bprop(self, x, y, out, dout): def bprop(self, x, y, out, dout):
return 2 * dout, 2 return 2 * dout, 2
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_type(): def test_grad_mul_add_with_wrong_output_type():
context.set_context(check_bprop=True) context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputType() mul_add = MulAddWithWrongOutputType()
...@@ -376,7 +409,9 @@ class MulAddWithWrongOutputShape(nn.Cell): ...@@ -376,7 +409,9 @@ class MulAddWithWrongOutputShape(nn.Cell):
def bprop(self, x, y, out, dout): def bprop(self, x, y, out, dout):
return 2, self.ones return 2, self.ones
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_mul_add_with_wrong_output_shape(): def test_grad_mul_add_with_wrong_output_shape():
context.set_context(check_bprop=True) context.set_context(check_bprop=True)
mul_add = MulAddWithWrongOutputShape() mul_add = MulAddWithWrongOutputShape()
......
...@@ -606,14 +606,14 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) { ...@@ -606,14 +606,14 @@ TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
ASSERT_TRUE(CheckOpt(before2r, after2, patterns)); ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
} }
TEST_F(TestOptLib, test_indexed_slices) { TEST_F(TestOptLib, test_row_tensor) {
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_indices"); FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "before_get_indices");
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_indices"); FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_row_tensor", "after_get_indices");
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_values"); FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_row_tensor", "before_get_values");
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_values"); FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_row_tensor", "after_get_values");
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "before_get_dense_shape"); FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "before_get_dense_shape");
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_indexed_slices", "after_get_dense_shape"); FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_row_tensor", "after_get_dense_shape");
auto patterns = std::vector<SubstitutionPtr>({irpass.indexed_slices_eliminate_}); auto patterns = std::vector<SubstitutionPtr>({irpass.row_tensor_eliminate_});
ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns)); ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns)); ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns)); ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
......
...@@ -1130,17 +1130,17 @@ def test_adjust_allreduce_mul_add(tag): ...@@ -1130,17 +1130,17 @@ def test_adjust_allreduce_mul_add(tag):
return fns[tag] return fns[tag]
def test_indexed_slices(tag): def test_row_tensor(tag):
""" test_add_zero """ """ test_add_zero """
fns = FnDict() fns = FnDict()
make_indexed_slices = Primitive('MakeIndexedSlices') make_row_tensor = Primitive('MakeRowTensor')
indexed_slices_get_values = Primitive('IndexedSlicesGetValues') row_tensor_get_values = Primitive('RowTensorGetValues')
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices') row_tensor_get_indices = Primitive('RowTensorGetIndices')
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape') row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
@fns @fns
def before_get_indices(x, y, z): def before_get_indices(x, y, z):
return indexed_slices_get_indices(make_indexed_slices(x, y, z)) return row_tensor_get_indices(make_row_tensor(x, y, z))
@fns @fns
def after_get_indices(x, y, z): def after_get_indices(x, y, z):
...@@ -1148,7 +1148,7 @@ def test_indexed_slices(tag): ...@@ -1148,7 +1148,7 @@ def test_indexed_slices(tag):
@fns @fns
def before_get_values(x, y, z): def before_get_values(x, y, z):
return indexed_slices_get_values(make_indexed_slices(x, y, z)) return row_tensor_get_values(make_row_tensor(x, y, z))
@fns @fns
def after_get_values(x, y, z): def after_get_values(x, y, z):
...@@ -1156,7 +1156,7 @@ def test_indexed_slices(tag): ...@@ -1156,7 +1156,7 @@ def test_indexed_slices(tag):
@fns @fns
def before_get_dense_shape(x, y, z): def before_get_dense_shape(x, y, z):
return indexed_slices_get_dense_shape(make_indexed_slices(x, y, z)) return row_tensor_get_dense_shape(make_row_tensor(x, y, z))
@fns @fns
def after_get_dense_shape(x, y, z): def after_get_dense_shape(x, y, z):
......
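The after_* graphs are elided in the hunks above. Given the getter-of-constructor elimination being tested, they presumably reduce to returning the matching argument of make_row_tensor; a sketch of that assumption (not part of the diff):

    def after_get_indices(x, y, z):
        # make_row_tensor(x, y, z) is assumed to take (indices, values, dense_shape)
        return x

    def after_get_values(x, y, z):
        return y

    def after_get_dense_shape(x, y, z):
        return z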
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
@File : test_indexed_slices.py @File : test_row_tensor.py
@Author: @Author:
@Date : 2020-06-08 @Date : 2020-06-08
@Desc : test mindspore indexed_slices's operation @Desc : test mindspore row_tensor's operation
""" """
import numpy as np import numpy as np
import pytest import pytest
...@@ -29,7 +29,7 @@ from mindspore.ops import operations as P ...@@ -29,7 +29,7 @@ from mindspore.ops import operations as P
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.primitive import constexpr from mindspore.ops.primitive import constexpr
from mindspore.ops._grad.grad_base import bprop_getters from mindspore.ops._grad.grad_base import bprop_getters
from mindspore import Tensor, IndexedSlices, context from mindspore import Tensor, RowTensor, context
from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
...@@ -122,7 +122,7 @@ def get_bprop_sparse_gather_v2(self): ...@@ -122,7 +122,7 @@ def get_bprop_sparse_gather_v2(self):
values_shape = indices_size + x_tail_shp values_shape = indices_size + x_tail_shp
values = reshape(dout, values_shape) values = reshape(dout, values_shape)
indices = reshape(indices, indices_size) indices = reshape(indices, indices_size)
return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis) return RowTensor(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
if F.rank(dout) == 0: if F.rank(dout) == 0:
dout = P.ExpandDims()(dout, -1) dout = P.ExpandDims()(dout, -1)
if F.rank(indices) == 0: if F.rank(indices) == 0:
...@@ -142,10 +142,10 @@ def get_bprop_sparse_gather_v2(self): ...@@ -142,10 +142,10 @@ def get_bprop_sparse_gather_v2(self):
adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map") adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "IndexedSlices", "Bool") "Tensor", "Tensor", "Tensor", "RowTensor", "Bool")
def _update_run_op_for_map_indexed_slices(beta1, beta2, eps, lr, weight_decay_tensor, param, def _update_run_op_for_map_row_tensor(beta1, beta2, eps, lr, weight_decay_tensor, param,
m, v, gradient, decay_flag): m, v, gradient, decay_flag):
return gradient.values() return gradient.values
@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", @adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
"Tensor", "Tensor", "Tensor", "Tensor", "Bool") "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
...@@ -219,35 +219,35 @@ class AdamWeightDecaySparse(Optimizer): ...@@ -219,35 +219,35 @@ class AdamWeightDecaySparse(Optimizer):
return updated_velocity return updated_velocity
def test_indexed_slices_make_indexed_slices(): def test_row_tensor_make_row_tensor():
class MakeIndexedSlices(nn.Cell): class MakeRowTensor(nn.Cell):
def __init__(self): def __init__(self):
super(MakeIndexedSlices, self).__init__() super(MakeRowTensor, self).__init__()
self.dense_shape = (3, 2) self.dense_shape = (3, 2)
def construct(self, indices, values): def construct(self, indices, values):
ret = (IndexedSlices(indices, values, self.dense_shape),) ret = (RowTensor(indices, values, self.dense_shape),)
return ret[0] return ret[0]
indices = Tensor([1, 2]) indices = Tensor([1, 2])
values = Tensor([[0, 0], [1, 2]], dtype=ms.float32) values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
MakeIndexedSlices()(indices, values) MakeRowTensor()(indices, values)
class IndexedSlicesGetAttr(nn.Cell): class RowTensorGetAttr(nn.Cell):
def __init__(self, dense_shape): def __init__(self, dense_shape):
super(IndexedSlicesGetAttr, self).__init__() super(RowTensorGetAttr, self).__init__()
self.dense_shape = dense_shape self.dense_shape = dense_shape
def construct(self, indices, values): def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape) x = RowTensor(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape() return x.values, x.indices, x.dense_shape
def test_indexed_slices_attr(): def test_row_tensor_attr():
indices = Tensor([0]) indices = Tensor([0])
values = Tensor([[1, 2]], dtype=ms.float32) values = Tensor([[1, 2]], dtype=ms.float32)
IndexedSlicesGetAttr((3, 2))(indices, values) RowTensorGetAttr((3, 2))(indices, values)
def test_indexed_slices_sparse_gatherv2_grad_all(): def test_row_tensor_sparse_gatherv2_grad_all():
grad_all = C.GradOperation('get_all', get_all=True) grad_all = C.GradOperation('get_all', get_all=True)
class GradWrap(nn.Cell): class GradWrap(nn.Cell):
def __init__(self, network): def __init__(self, network):
...@@ -255,7 +255,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): ...@@ -255,7 +255,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
self.network = network self.network = network
def construct(self, x, y): def construct(self, x, y):
grad = grad_all(self.network)(x, y) grad = grad_all(self.network)(x, y)
return grad[0].indices(), grad[0].values(), grad[0].dense_shape() return grad[0].indices, grad[0].values, grad[0].dense_shape
class SparseGatherV2(nn.Cell): class SparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(SparseGatherV2, self).__init__() super(SparseGatherV2, self).__init__()
...@@ -268,7 +268,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): ...@@ -268,7 +268,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
GradWrap(SparseGatherV2())(params, indices) GradWrap(SparseGatherV2())(params, indices)
def test_indexed_slices_sparse_gatherv2_grad_with_pram(): def test_row_tensor_sparse_gatherv2_grad_with_pram():
grad_by_list = C.GradOperation('get_by_list', get_by_list=True) grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
class GradWrap(nn.Cell): class GradWrap(nn.Cell):
def __init__(self, network): def __init__(self, network):
...@@ -279,7 +279,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): ...@@ -279,7 +279,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
weights = self.weights weights = self.weights
grad = grad_by_list(self.network, weights)(x) grad = grad_by_list(self.network, weights)(x)
x = grad[0] x = grad[0]
return x.values(), x.indices(), x.dense_shape() return x.values, x.indices, x.dense_shape
class SparseGatherV2(nn.Cell): class SparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(SparseGatherV2, self).__init__() super(SparseGatherV2, self).__init__()
...@@ -293,7 +293,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): ...@@ -293,7 +293,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
network(indices) network(indices)
def test_indexed_slices_env_get(): def test_row_tensor_env_get():
class Loss(nn.Cell): class Loss(nn.Cell):
def __init__(self): def __init__(self):
super(Loss, self).__init__() super(Loss, self).__init__()
...@@ -321,7 +321,7 @@ def test_indexed_slices_env_get(): ...@@ -321,7 +321,7 @@ def test_indexed_slices_env_get():
train_network(inputs, label) train_network(inputs, label)
def test_indexed_slices_model_train(): def test_row_tensor_model_train():
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self, in_features, out_features): def __init__(self, in_features, out_features):
super(Net, self).__init__() super(Net, self).__init__()
...@@ -347,76 +347,76 @@ def test_indexed_slices_model_train(): ...@@ -347,76 +347,76 @@ def test_indexed_slices_model_train():
model.train(2, dataset, dataset_sink_mode=False) model.train(2, dataset, dataset_sink_mode=False)
def test_indexed_slices_values_dim_greater_than_dense_shape_dim(): def test_row_tensor_values_dim_greater_than_dense_shape_dim():
indices = Tensor(np.array([0, 1], dtype=np.int32)) indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4, 5).astype(np.float32)) values = Tensor(np.random.randn(2, 4, 5).astype(np.float32))
dense_shape = (3, 4) dense_shape = (3, 4)
with pytest.raises(TypeError): with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values) RowTensorGetAttr(dense_shape)(indices, values)
def test_indexed_slices_values_dim_less_than_dense_shape_dim(): def test_row_tensor_values_dim_less_than_dense_shape_dim():
indices = Tensor(np.array([0, 1], dtype=np.int32)) indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4).astype(np.float32)) values = Tensor(np.random.randn(2, 4).astype(np.float32))
dense_shape = (3, 4, 5) dense_shape = (3, 4, 5)
with pytest.raises(TypeError): with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values) RowTensorGetAttr(dense_shape)(indices, values)
def test_indexed_slices_value_and_dense_shape_illegal(): def test_row_tensor_value_and_dense_shape_illegal():
indices = Tensor(np.array([0, 1], dtype=np.int32)) indices = Tensor(np.array([0, 1], dtype=np.int32))
values = Tensor(np.random.randn(2, 4).astype(np.float32)) values = Tensor(np.random.randn(2, 4).astype(np.float32))
dense_shape = (3, 5) dense_shape = (3, 5)
with pytest.raises(TypeError): with pytest.raises(TypeError):
IndexedSlicesGetAttr(dense_shape)(indices, values) RowTensorGetAttr(dense_shape)(indices, values)
class IndexedSlicesValuesDouble(nn.Cell): class RowTensorValuesDouble(nn.Cell):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def construct(self, x): def construct(self, x):
indices = x.indices() indices = x.indices
values = x.values() * 2 values = x.values * 2
dense_shape = x.dense_shape() dense_shape = x.dense_shape
return IndexedSlices(indices, values, dense_shape) return RowTensor(indices, values, dense_shape)
class IndexedSlicesValuesAdd2(nn.Cell): class RowTensorValuesAdd2(nn.Cell):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
def construct(self, x): def construct(self, x):
indices = x.indices() indices = x.indices
values = x.values() + 2 values = x.values + 2
dense_shape = x.dense_shape() dense_shape = x.dense_shape
return IndexedSlices(indices, values, dense_shape) return RowTensor(indices, values, dense_shape)
class IndexedSlicesWithControlIf(nn.Cell): class RowTensorWithControlIf(nn.Cell):
def __init__(self, dense_shape): def __init__(self, dense_shape):
super().__init__() super().__init__()
self.op1 = IndexedSlicesValuesDouble() self.op1 = RowTensorValuesDouble()
self.op2 = IndexedSlicesValuesAdd2() self.op2 = RowTensorValuesAdd2()
self.dense_shape = dense_shape self.dense_shape = dense_shape
def construct(self, a, b, indices, values): def construct(self, a, b, indices, values):
x = IndexedSlices(indices, values, self.dense_shape) x = RowTensor(indices, values, self.dense_shape)
if a > b: if a > b:
x = self.op1(x) x = self.op1(x)
else: else:
x = self.op2(x) x = self.op2(x)
return x.indices(), x.values() return x.indices, x.values
def test_indexed_slices_with_control_flow_if(): def test_row_tensor_with_control_flow_if():
a = Tensor(np.array(0).astype(np.int32)) a = Tensor(np.array(0).astype(np.int32))
b = Tensor(np.array(2).astype(np.int32)) b = Tensor(np.array(2).astype(np.int32))
indices = Tensor(np.array([0, 2]).astype(np.int32)) indices = Tensor(np.array([0, 2]).astype(np.int32))
values = Tensor(np.ones([2, 2]).astype(np.float32)) values = Tensor(np.ones([2, 2]).astype(np.float32))
dense_shape = (5, 2) dense_shape = (5, 2)
net = IndexedSlicesWithControlIf(dense_shape) net = RowTensorWithControlIf(dense_shape)
net(a, b, indices, values) net(a, b, indices, values)
......
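Taken together, the test changes in this file reflect the user-facing migration introduced by the commit: the class is renamed from IndexedSlices to RowTensor and its accessors become properties rather than methods. A minimal sketch of the new usage, mirroring test_row_tensor_attr above (values are illustrative):

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor, RowTensor   # previously: from mindspore import IndexedSlices

    class ReadRowTensor(nn.Cell):
        def construct(self, indices, values):
            x = RowTensor(indices, values, (3, 2))   # previously IndexedSlices(...)
            # Property access replaces the old method calls x.values(), x.indices(), x.dense_shape().
            return x.values, x.indices, x.dense_shape

    ReadRowTensor()(Tensor([0]), Tensor([[1.0, 2.0]], dtype=ms.float32))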
...@@ -52,7 +52,7 @@ def test_sparse_tensor_attr(): ...@@ -52,7 +52,7 @@ def test_sparse_tensor_attr():
self.dense_shape = (3, 4) self.dense_shape = (3, 4)
def construct(self, indices, values): def construct(self, indices, values):
x = SparseTensor(indices, values, self.dense_shape) x = SparseTensor(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape() return x.values, x.indices, x.dense_shape
indices = Tensor([[0, 1], [1, 2]]) indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32) values = Tensor([1, 2], dtype=ms.float32)
......
...@@ -175,7 +175,7 @@ def test_bprop_with_wrong_output_num(): ...@@ -175,7 +175,7 @@ def test_bprop_with_wrong_output_num():
def construct(self, x, y): def construct(self, x, y):
return BpropWithWrongOutputNum()(x, y) return BpropWithWrongOutputNum()(x, y)
with pytest.raises(TypeError): with pytest.raises(ValueError):
C.grad_all(BpropWithWrongOutputNumCell())(1, 2) C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
def test_bprop_with_wrong_output_type(): def test_bprop_with_wrong_output_type():
...@@ -247,7 +247,7 @@ def test_bprop_with_wrong_output_shape(): ...@@ -247,7 +247,7 @@ def test_bprop_with_wrong_output_shape():
def construct(self, x): def construct(self, x):
return BpropWithWrongOutputShape()(x) return BpropWithWrongOutputShape()(x)
with pytest.raises(TypeError): with pytest.raises(ValueError):
net = BpropWithWrongOutputShapeCell() net = BpropWithWrongOutputShapeCell()
net.set_grad() net.set_grad()
C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
""" """
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context, Tensor, IndexedSlices, SparseTensor from mindspore import context, Tensor, RowTensor, SparseTensor
from mindspore.ops import composite as C from mindspore.ops import composite as C
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True)
...@@ -36,18 +36,18 @@ class GradWrap(nn.Cell): ...@@ -36,18 +36,18 @@ class GradWrap(nn.Cell):
return grad return grad
def test_indexed_slices_attr(): def test_row_tensor_attr():
class IndexedSlicesGetAttr(nn.Cell): class RowTensorGetAttr(nn.Cell):
def __init__(self, dense_shape): def __init__(self, dense_shape):
super(IndexedSlicesGetAttr, self).__init__() super(RowTensorGetAttr, self).__init__()
self.dense_shape = dense_shape self.dense_shape = dense_shape
def construct(self, indices, values): def construct(self, indices, values):
x = IndexedSlices(indices, values, self.dense_shape) x = RowTensor(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape() return x.values, x.indices, x.dense_shape
indices = Tensor([0]) indices = Tensor([0])
values = Tensor([[1, 2]], dtype=ms.float32) values = Tensor([[1, 2]], dtype=ms.float32)
IndexedSlicesGetAttr((3, 2))(indices, values) RowTensorGetAttr((3, 2))(indices, values)
GradWrap(IndexedSlicesGetAttr((3, 2)))(indices, values) GradWrap(RowTensorGetAttr((3, 2)))(indices, values)
def test_sparse_tensor_attr(): def test_sparse_tensor_attr():
...@@ -57,7 +57,7 @@ def test_sparse_tensor_attr(): ...@@ -57,7 +57,7 @@ def test_sparse_tensor_attr():
self.dense_shape = (3, 4) self.dense_shape = (3, 4)
def construct(self, indices, values): def construct(self, indices, values):
x = SparseTensor(indices, values, self.dense_shape) x = SparseTensor(indices, values, self.dense_shape)
return x.values(), x.indices(), x.dense_shape() return x.values, x.indices, x.dense_shape
indices = Tensor([[0, 1], [1, 2]]) indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32) values = Tensor([1, 2], dtype=ms.float32)
......