提交 4d4e23fd 编写于 作者: P panyifeng

Add bprop for sparse_tensor

上级 abcee8e5
...@@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap ...@@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
* │ └── MapPy * │ └── MapPy
* ├── Tail * ├── Tail
* ├── MakeTupleGradient * ├── MakeTupleGradient
* ├── MakeListGradient
* ├── GradOperation * ├── GradOperation
* └── TupleAdd * └── TupleAdd
*/ */
...@@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_ ...@@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
// do nothing // do nothing
} else if (meta_func_graph->isa<prim::MakeTupleGradient>()) { } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
// do nothing // do nothing
} else if (meta_func_graph->isa<prim::MakeListGradient>()) {
// do nothing
} else if (meta_func_graph->isa<prim::TupleAdd>()) { } else if (meta_func_graph->isa<prim::TupleAdd>()) {
// do nothing // do nothing
} else if (meta_func_graph->isa<prim::TupleSlice>()) { } else if (meta_func_graph->isa<prim::TupleSlice>()) {
......
...@@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg ...@@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
return fg; return fg;
} }
FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
int list_size = SizeToInt(args_spec_list.size());
std::ostringstream ss;
ss << "▶make_list_" << list_size;
FuncGraphPtr fg = std::make_shared<FuncGraph>();
fg->debug_info()->set_name(ss.str());
std::vector<AnfNodePtr> params;
params.push_back(NewValueNode(prim::kPrimMakeList));
for (int i = 0; i < list_size; ++i) {
params.push_back(fg->add_parameter());
}
// make fprob first result, maketuple's forward result.
AnfNodePtr out = fg->NewCNode(params);
// make fprob second result, maketuple's backward function.
FuncGraphPtr b = std::make_shared<FuncGraph>();
ss.clear();
ss << "◀make_list_" << list_size;
b->debug_info()->set_name(ss.str());
AnfNodePtr dout = b->add_parameter();
std::vector<AnfNodePtr> grads;
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
grads.push_back(NewValueNode(newenv));
for (int i = 0; i < list_size; ++i) {
grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
}
b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
b->set_output(b->NewCNode(grads));
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
return fg;
}
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
if (get_by_list) { if (get_by_list) {
......
...@@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph { ...@@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph {
}; };
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
class MakeListGradient : public MetaFuncGraph {
public:
explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
~MakeListGradient() override = default;
MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
};
using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
class GradOperation : public MetaFuncGraph { class GradOperation : public MetaFuncGraph {
public: public:
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
......
...@@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi ...@@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
auto elem = GetValue<int>(e); auto elem = GetValue<int>(e);
return elem; return elem;
}); });
if (IntToSize(indices_shp[1]) != dense_shape_vec.size()) {
MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal with the second dimension of indices "
<< indices_shp[1] << ", but got " << dense_shape_vec.size();
}
for (auto dense_shape_elem : dense_shape_vec) { for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) { if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got " MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
......
...@@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { ...@@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
return meta; return meta;
} }
if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
bprop_registry_meta_[prim::kPrimMakeList] = meta;
return meta;
}
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
} }
...@@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R ...@@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
return fprop; return fprop;
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
return nullptr; return nullptr;
} else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
return nullptr;
} }
FuncGraphPtr bprop_fg = nullptr; FuncGraphPtr bprop_fg = nullptr;
......
...@@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) { ...@@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractAttribute &item) { return item.second; });
return std::make_shared<AbstractTuple>(baselist); return std::make_shared<AbstractTuple>(baselist);
} }
return nullptr;
}
static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
if (t == nullptr) {
return nullptr;
}
if (t->isa<AbstractList>()) { if (t->isa<AbstractList>()) {
auto abs_list = dyn_cast<AbstractList>(t); auto abs_list = dyn_cast<AbstractList>(t);
return std::make_shared<AbstractTuple>(abs_list->elements()); return std::make_shared<AbstractTuple>(abs_list->elements());
...@@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr ...@@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
new_node = EraseMakeKeywordArgNode(cnode); new_node = EraseMakeKeywordArgNode(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) { } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
new_node = EraseExtractKeywordArg(cnode); new_node = EraseExtractKeywordArg(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) { }
if (new_node != nullptr) {
new_node->set_abstract(node->abstract());
MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
(void)manager->Replace(node, new_node);
changed = true;
}
}
for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract());
if (ret) {
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
<< ret->ToString();
node->set_abstract(ret);
changed = true;
}
}
return changed;
}
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
bool changed = false;
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes();
for (auto &node : all_node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
AnfNodePtr new_node = nullptr;
if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
new_node = ConvertMakeListToMakeTuple(cnode); new_node = ConvertMakeListToMakeTuple(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) { } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
new_node = ConvertListGetItemToTupleGetItem(cnode); new_node = ConvertListGetItemToTupleGetItem(cnode);
...@@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr ...@@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
} }
for (auto &node : manager->all_nodes()) { for (auto &node : manager->all_nodes()) {
auto ret = Reabs(node->abstract()); auto ret = AdaptAbs(node->abstract());
if (ret) { if (ret) {
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with " MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
<< ret->ToString(); << ret->ToString();
......
...@@ -32,6 +32,7 @@ namespace opt { ...@@ -32,6 +32,7 @@ namespace opt {
// Remove the class type from graphs // Remove the class type from graphs
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
// Remove most uses of tuples from the graph // Remove most uses of tuples from the graph
// tuples that are returned will be kept // tuples that are returned will be kept
......
...@@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) { ...@@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
return true; return true;
} }
bool CleanListPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
bool changed = opt::CleanList(func_graph, res->manager());
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
if (changed) {
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
}
res->set_args_spec(args_spec);
return true;
}
namespace { namespace {
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_1 = opt::OptPassConfig({ opt::OptPassConfig a_1 = opt::OptPassConfig({
...@@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { ...@@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Safe inlining // Safe inlining
irpass.inline_, irpass.inline_,
irpass.sparse_tensor_eliminate_,
}); });
opt::OptPassConfig a_2 = opt::OptPassConfig({ opt::OptPassConfig a_2 = opt::OptPassConfig({
irpass.merge_addn_, irpass.merge_addn_,
...@@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { ...@@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,
irpass.get_ref_param_eliminate_, irpass.get_ref_param_eliminate_,
irpass.indexed_slices_eliminate_, irpass.indexed_slices_eliminate_,
irpass.sparse_tensor_eliminate_,
}); });
OptPassGroupMap map({ OptPassGroupMap map({
{"b_1", b_1}, {"b_1", b_1},
...@@ -322,18 +340,22 @@ bool InferenceOptPreparePass(const ResourcePtr &res) { ...@@ -322,18 +340,22 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
return true; return true;
} }
std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup}, std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup},
{"clean_list", CleanListPass},
{"opt_b", OptPassBGroup}, {"opt_b", OptPassBGroup},
{"cconv", CconvPass}, {"cconv", CconvPass},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}}; {"add_control_depend", AddControlDependPass}};
std::vector<PassItem> kGePasses = { std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass}, {"clean_list", CleanListPass},
{"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, {"opt_b", OptPassBGroup},
{"add_control_depend", AddControlDependPass},
{"opt_control", ControlGroup},
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}}; {"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}}; std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "abstract/utils.h" #include "abstract/utils.h"
#include "debug/trace.h" #include "debug/trace.h"
#include "utils/context/ms_context.h"
namespace mindspore { namespace mindspore {
namespace abstract { namespace abstract {
...@@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg ...@@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y) // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList bparams; AbstractBasePtrList bparams;
bparams.push_back(SensitivityTransform(orig_func_)); bparams.push_back(SensitivityTransform(orig_func_));
(void)std::transform( auto context = MsContext::GetInstance();
args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams), MS_EXCEPTION_IF_NULL(context);
[](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); }); bool enable_sparse = context->enable_sparse();
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
return std::make_shared<AbstractUndetermined>();
}
return SensitivityTransform(arg_spec);
});
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams); AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
AbstractFunctionPtr bprop = AbstractFunctionPtr bprop =
std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final); std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
......
...@@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout): ...@@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout):
"""Backpropagator for primitive `tuple_getitem`.""" """Backpropagator for primitive `tuple_getitem`."""
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx) return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
"""Backpropagator for primitive `list_getitem`."""
return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
@bprops.register("identity") @bprops.register("identity")
def bprop_identity(x, out, dout): def bprop_identity(x, out, dout):
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from functools import reduce from functools import reduce
import numpy as np import numpy as np
import mindspore as ms
from mindspore.ops import _selected_grad_ops as SG from mindspore.ops import _selected_grad_ops as SG
from .. import functional as F from .. import functional as F
from .. import operations as P from .. import operations as P
...@@ -33,6 +34,7 @@ shape_op = P.Shape() ...@@ -33,6 +34,7 @@ shape_op = P.Shape()
reduce_sum = P.ReduceSum() reduce_sum = P.ReduceSum()
reshape = P.Reshape() reshape = P.Reshape()
tile = P.Tile() tile = P.Tile()
is_sub_class = P.IsSubClass()
def binop_grad_common(x, y, dx, dy): def binop_grad_common(x, y, dx, dy):
...@@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self): ...@@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self):
"""Generate bprop for AddN""" """Generate bprop for AddN"""
def bprop(x, out, dout): def bprop(x, out, dout):
if is_sub_class(F.typeof(x), ms.list_):
dx = []
for _ in range(len(x)):
dx.append(dout)
return (dx,)
dx = () dx = ()
for _ in range(len(x)): for _ in range(len(x)):
dx = dx + (dout,) dx = dx + (dout,)
......
...@@ -16,6 +16,7 @@ import numpy as np ...@@ -16,6 +16,7 @@ import numpy as np
import mindspore.context as context import mindspore.context as context
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
...@@ -45,3 +46,17 @@ def test_net(): ...@@ -45,3 +46,17 @@ def test_net():
add = Net() add = Net()
output = add(x, y) output = add(x, y)
assert output == expect assert output == expect
def test_grad_addn_with_list():
grad_op = C.GradOperation('get_all', get_all=True)
class AddN(nn.Cell):
def __init__(self):
super().__init__()
self.add_n = P.AddN()
def construct(self, a, b):
return self.add_n([a, b])
inp = Tensor(np.ones([128, 96]).astype(np.float32))
grad_op(AddN())(inp, inp)
...@@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all(): ...@@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
self.network = network self.network = network
def construct(self, x, y): def construct(self, x, y):
grad = grad_all(self.network)(x, y) grad = grad_all(self.network)(x, y)
return grad, grad[0], grad[1] return grad[0].indices(), grad[0].values(), grad[0].dense_shape()
class SparseGatherV2(nn.Cell): class SparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(SparseGatherV2, self).__init__() super(SparseGatherV2, self).__init__()
...@@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram(): ...@@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
weights = self.weights weights = self.weights
grad = grad_by_list(self.network, weights)(x) grad = grad_by_list(self.network, weights)(x)
x = grad[0] x = grad[0]
return x, x.values(), x.indices(), x.dense_shape() return x.values(), x.indices(), x.dense_shape()
class SparseGatherV2(nn.Cell): class SparseGatherV2(nn.Cell):
def __init__(self): def __init__(self):
super(SparseGatherV2, self).__init__() super(SparseGatherV2, self).__init__()
......
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
@Date : 2020-07-16 @Date : 2020-07-16
@Desc : test mindspore sparse_tensor's operation @Desc : test mindspore sparse_tensor's operation
""" """
import numpy as np
import pytest
import mindspore as ms import mindspore as ms
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import composite as C from mindspore.ops import composite as C
...@@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context ...@@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
def test_sparse_tensor_make_sparse_tensor():
class MakeSparseTensor(nn.Cell): class MakeSparseTensor(nn.Cell):
def __init__(self): def __init__(self, dense_shape):
super(MakeSparseTensor, self).__init__() super(MakeSparseTensor, self).__init__()
self.dense_shape = (3, 4) self.dense_shape = dense_shape
def construct(self, indices, values): def construct(self, indices, values):
ret = (SparseTensor(indices, values, self.dense_shape),) ret = (SparseTensor(indices, values, self.dense_shape),)
return ret[0] return ret[0]
def test_sparse_tensor_make_sparse_tensor():
indices = Tensor([[0, 1], [1, 2]]) indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32) values = Tensor([1, 2], dtype=ms.float32)
MakeSparseTensor()(indices, values) MakeSparseTensor((3, 4))(indices, values)
def test_sparse_tensor_attr(): def test_sparse_tensor_attr():
...@@ -59,3 +65,20 @@ def test_sparse_tensor_attr(): ...@@ -59,3 +65,20 @@ def test_sparse_tensor_attr():
indices = Tensor([[0, 1], [1, 2]]) indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32) values = Tensor([1, 2], dtype=ms.float32)
SparseTensorGetAttr()(indices, values) SparseTensorGetAttr()(indices, values)
grad_op(SparseTensorGetAttr())(indices, values)
def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim():
indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32))
values = Tensor(np.array([100, 200], dtype=np.float32))
dense_shape = (2, 2)
with pytest.raises(TypeError):
MakeSparseTensor(dense_shape)(indices, values)
def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
indices = Tensor(np.array([[0, 0], [0, 1]], dtype=np.int32))
values = Tensor(np.array([100, 200], dtype=np.float32))
dense_shape = (2, 2, 2)
with pytest.raises(TypeError):
MakeSparseTensor(dense_shape)(indices, values)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册