Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
598bfa02
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
598bfa02
编写于
7月 27, 2020
作者:
P
panyifeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sparse operators
上级
61867c73
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
291 addition
and
43 deletion
+291
-43
mindspore/ccsrc/frontend/optimizer/clean.cc
mindspore/ccsrc/frontend/optimizer/clean.cc
+59
-2
mindspore/ccsrc/frontend/optimizer/clean.h
mindspore/ccsrc/frontend/optimizer/clean.h
+1
-1
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+4
-4
mindspore/nn/__init__.py
mindspore/nn/__init__.py
+3
-2
mindspore/nn/sparse/__init__.py
mindspore/nn/sparse/__init__.py
+22
-0
mindspore/nn/sparse/sparse.py
mindspore/nn/sparse/sparse.py
+54
-0
mindspore/ops/_grad/__init__.py
mindspore/ops/_grad/__init__.py
+1
-1
mindspore/ops/_grad/grad_implementations.py
mindspore/ops/_grad/grad_implementations.py
+1
-0
mindspore/ops/_grad/grad_sparse.py
mindspore/ops/_grad/grad_sparse.py
+58
-0
mindspore/ops/composite/multitype_ops/ones_like_impl.py
mindspore/ops/composite/multitype_ops/ones_like_impl.py
+10
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+3
-1
mindspore/ops/operations/sparse_ops.py
mindspore/ops/operations/sparse_ops.py
+55
-0
tests/ut/python/ir/test_sparse_tensor.py
tests/ut/python/ir/test_sparse_tensor.py
+18
-9
tests/ut/python/pipeline/parse/test_cell_bprop.py
tests/ut/python/pipeline/parse/test_cell_bprop.py
+2
-23
未找到文件。
mindspore/ccsrc/frontend/optimizer/clean.cc
浏览文件 @
598bfa02
...
...
@@ -32,9 +32,11 @@ namespace opt {
using
mindspore
::
abstract
::
AbstractAttribute
;
using
mindspore
::
abstract
::
AbstractClass
;
using
mindspore
::
abstract
::
AbstractDictionary
;
using
mindspore
::
abstract
::
AbstractIndexedSlices
;
using
mindspore
::
abstract
::
AbstractJTagged
;
using
mindspore
::
abstract
::
AbstractList
;
using
mindspore
::
abstract
::
AbstractScalar
;
using
mindspore
::
abstract
::
AbstractSparseTensor
;
using
mindspore
::
abstract
::
AbstractTuple
;
using
mindspore
::
abstract
::
AbstractUndetermined
;
...
...
@@ -73,6 +75,19 @@ static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
return
std
::
make_shared
<
AbstractTuple
>
(
abs_list
->
elements
());
}
if
(
t
->
isa
<
AbstractSparseTensor
>
())
{
auto
abs_sparse
=
dyn_cast
<
AbstractSparseTensor
>
(
t
);
std
::
vector
<
AbstractBasePtr
>
abstract_list
{
abs_sparse
->
indices
(),
abs_sparse
->
values
(),
abs_sparse
->
dense_shape
()};
return
std
::
make_shared
<
AbstractTuple
>
(
abstract_list
);
}
if
(
t
->
isa
<
AbstractIndexedSlices
>
())
{
auto
abs_indexed_slices
=
dyn_cast
<
AbstractIndexedSlices
>
(
t
);
std
::
vector
<
AbstractBasePtr
>
abstract_list
{
abs_indexed_slices
->
indices
(),
abs_indexed_slices
->
values
(),
abs_indexed_slices
->
dense_shape
()};
return
std
::
make_shared
<
AbstractTuple
>
(
abstract_list
);
}
return
nullptr
;
}
...
...
@@ -389,14 +404,44 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
return
changed
;
}
bool
CleanList
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
)
{
AnfNodePtr
ConvertMakeSparseToMakeTuple
(
const
CNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
->
func_graph
());
std
::
vector
<
AnfNodePtr
>
inputs
;
inputs
.
emplace_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
// Inputs of node should be [make_sparse, indices, values, dense_shape], so offset by 1 to get items;
(
void
)
inputs
.
insert
(
inputs
.
end
(),
node
->
inputs
().
begin
()
+
1
,
node
->
inputs
().
end
());
return
node
->
func_graph
()
->
NewCNode
(
inputs
);
}
AnfNodePtr
ConvertSparseGetAttrToTupleGetItem
(
const
CNodePtr
&
node
,
const
int
&
index
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
node
->
func_graph
());
const
auto
&
inputs
=
node
->
inputs
();
// Inputs should be [spase_getattr, sparse]
if
(
inputs
.
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"Node's input number < 2."
;
}
AnfNodePtr
sparse
=
inputs
[
1
];
MS_EXCEPTION_IF_NULL
(
sparse
);
auto
cons_node
=
NewValueNode
(
index
);
AbstractBasePtr
aptr
=
std
::
make_shared
<
AbstractScalar
>
(
std
::
make_shared
<
Int32Imm
>
(
index
));
cons_node
->
set_abstract
(
aptr
);
return
node
->
func_graph
()
->
NewCNode
({
NewValueNode
(
prim
::
kPrimTupleGetItem
),
sparse
,
cons_node
});
}
bool
CleanAfterOptA
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
)
{
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
AddFuncGraph
(
root
);
bool
changed
=
false
;
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet
all_node
=
manager
->
all_nodes
();
auto
all_node
=
manager
->
all_nodes
();
for
(
auto
&
node
:
all_node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
...
...
@@ -409,6 +454,18 @@ bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
new_node
=
ConvertListSetItemToTupleSetItem
(
cnode
);
}
else
if
(
IsValueNode
<
ValueList
>
(
node
))
{
new_node
=
ConvertValueListNodeToValueTupleNode
(
node
->
cast
<
ValueNodePtr
>
());
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimMakeSparseTensor
)
||
IsPrimitiveCNode
(
node
,
prim
::
kPrimMakeIndexedSlices
))
{
new_node
=
ConvertMakeSparseToMakeTuple
(
cnode
);
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimSparseTensorGetIndices
)
||
IsPrimitiveCNode
(
node
,
prim
::
kPrimIndexedSlicesGetIndices
))
{
new_node
=
ConvertSparseGetAttrToTupleGetItem
(
cnode
,
0
);
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimSparseTensorGetValues
)
||
IsPrimitiveCNode
(
node
,
prim
::
kPrimIndexedSlicesGetValues
))
{
new_node
=
ConvertSparseGetAttrToTupleGetItem
(
cnode
,
1
);
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimSparseTensorGetDenseShape
)
||
IsPrimitiveCNode
(
node
,
prim
::
kPrimIndexedSlicesGetDenseShape
))
{
new_node
=
ConvertSparseGetAttrToTupleGetItem
(
cnode
,
2
);
}
if
(
new_node
!=
nullptr
)
{
...
...
mindspore/ccsrc/frontend/optimizer/clean.h
浏览文件 @
598bfa02
...
...
@@ -32,7 +32,7 @@ namespace opt {
// Remove the class type from graphs
bool
SimplifyDataStructures
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
);
bool
Clean
List
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
);
bool
Clean
AfterOptA
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
);
// Remove most uses of tuples from the graph
// tuples that are returned will be kept
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
598bfa02
...
...
@@ -69,11 +69,11 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
return
true
;
}
bool
Clean
List
Pass
(
const
ResourcePtr
&
res
)
{
bool
Clean
AfterOptA
Pass
(
const
ResourcePtr
&
res
)
{
MS_EXCEPTION_IF_NULL
(
res
->
func_graph
());
FuncGraphPtr
func_graph
=
res
->
func_graph
();
bool
changed
=
opt
::
Clean
List
(
func_graph
,
res
->
manager
());
bool
changed
=
opt
::
Clean
AfterOptA
(
func_graph
,
res
->
manager
());
abstract
::
AbstractBasePtrList
args_spec
;
auto
parameters
=
func_graph
->
parameters
();
...
...
@@ -337,7 +337,7 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
std
::
vector
<
PassItem
>
kVmPasses
=
{{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"opt_a"
,
OptPassAGroup
},
{
"clean_
list"
,
CleanList
Pass
},
{
"clean_
after_opta"
,
CleanAfterOptA
Pass
},
{
"opt_b"
,
OptPassBGroup
},
{
"cconv"
,
CconvPass
},
{
"opt_graph_kernel_a"
,
OptPassGraphKernelGroupA
},
...
...
@@ -346,7 +346,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
std
::
vector
<
PassItem
>
kGePasses
=
{{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"opt_a"
,
OptPassAGroup
},
{
"clean_
list"
,
CleanList
Pass
},
{
"clean_
after_opta"
,
CleanAfterOptA
Pass
},
{
"opt_b"
,
OptPassBGroup
},
{
"add_control_depend"
,
AddControlDependPass
},
{
"opt_control"
,
ControlGroup
},
...
...
mindspore/nn/__init__.py
浏览文件 @
598bfa02
...
...
@@ -17,13 +17,14 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks.
"""
from
.
import
layer
,
loss
,
optim
,
metrics
,
wrap
,
probability
from
.
import
layer
,
loss
,
optim
,
metrics
,
wrap
,
probability
,
sparse
from
.cell
import
Cell
,
GraphKernel
from
.layer
import
*
from
.loss
import
*
from
.optim
import
*
from
.metrics
import
*
from
.wrap
import
*
from
.sparse
import
*
__all__
=
[
"Cell"
,
"GraphKernel"
]
...
...
@@ -32,7 +33,7 @@ __all__.extend(loss.__all__)
__all__
.
extend
(
optim
.
__all__
)
__all__
.
extend
(
metrics
.
__all__
)
__all__
.
extend
(
wrap
.
__all__
)
__all__
.
extend
(
sparse
.
__all__
)
__all__
.
sort
()
mindspore/nn/sparse/__init__.py
0 → 100644
浏览文件 @
598bfa02
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Sparse related transformation.
"""
from
.sparse
import
SparseToDense
__all__
=
[
"SparseToDense"
,
]
mindspore/nn/sparse/sparse.py
0 → 100644
浏览文件 @
598bfa02
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sparse related tools."""
from
mindspore.ops
import
operations
as
P
from
..cell
import
Cell
class
SparseToDense
(
Cell
):
"""
Convert a sparse tensor into dense.
Not yet supported by any backend at the moment.
Args:
sparse_tensor (SparseTensor): the sparse tensor to convert.
Returns:
Tensor, the tensor converted.
Examples:
>>> class SparseToDenseCell(nn.Cell):
>>> def __init__(self, dense_shape):
>>> super(SparseToDenseCell, self).__init__()
>>> self.dense_shape = dense_shape
>>> self.sparse_to_dense = nn.SparseToDense()
>>> def construct(self, indices, values):
>>> sparse = SparseTensor(indices, values, self.dense_shape)
>>> return self.sparse_to_dense(sparse)
>>>
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> dense_shape = (3, 4)
>>> SparseToDenseCell(dense_shape)(indices, values)
"""
def
__init__
(
self
):
super
(
SparseToDense
,
self
).
__init__
()
self
.
sparse_to_dense
=
P
.
SparseToDense
()
def
construct
(
self
,
sparse_tensor
):
return
self
.
sparse_to_dense
(
sparse_tensor
.
indices
(),
sparse_tensor
.
values
(),
sparse_tensor
.
dense_shape
())
mindspore/ops/_grad/__init__.py
浏览文件 @
598bfa02
...
...
@@ -15,7 +15,7 @@
"""grad impl."""
from
.
import
grad_array_ops
,
grad_comm_ops
,
grad_debug_ops
,
grad_implementations
,
\
grad_inner_ops
,
grad_math_ops
,
grad_nn_ops
,
grad_other_ops
,
grad_quant_ops
grad_inner_ops
,
grad_math_ops
,
grad_nn_ops
,
grad_other_ops
,
grad_quant_ops
,
grad_sparse
from
.grad_base
import
get_bprop_fn
__all__
=
[
'get_bprop_fn'
]
mindspore/ops/_grad/grad_implementations.py
浏览文件 @
598bfa02
...
...
@@ -116,6 +116,7 @@ def bprop_tuple_getitem(data, idx, out, dout):
"""Backpropagator for primitive `tuple_getitem`."""
return
F
.
tuple_setitem
(
C
.
zeros_like
(
data
),
idx
,
dout
),
C
.
zeros_like
(
idx
)
@
bprops
.
register
(
"list_getitem"
)
def
bprop_list_getitem
(
data
,
idx
,
out
,
dout
):
"""Backpropagator for primitive `list_getitem`."""
...
...
mindspore/ops/_grad/grad_sparse.py
0 → 100644
浏览文件 @
598bfa02
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""bprop primitives"""
from
..
import
functional
as
F
from
..
import
operations
as
P
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
from
.grad_base
import
bprops
,
bprop_getters
# Unused parameters are placeholders.
@
bprops
.
register
(
"MakeSparseTensor"
)
def
bprop_make_sparse_tensor
(
indices
,
values
,
dense_shape
,
out
,
dout
):
"""Backpropagator for primitive `MakeSparseTensor`."""
return
zeros_like
(
indices
),
F
.
sparse_tensor_get_values
(
dout
),
()
@
bprops
.
register
(
"SparseTensorGetIndices"
)
def
bprop_sparse_tensor_get_indices
(
sparse_tensor
,
out
,
dout
):
"""Backpropagator for primitive `SparseTensorGetIndices`."""
return
(
zeros_like
(
sparse_tensor
),)
@
bprops
.
register
(
"SparseTensorGetValues"
)
def
bprop_sparse_tensor_get_values
(
sparse_tensor
,
out
,
dout
):
"""Backpropagator for primitive `SparseTensorGetValues`."""
return
F
.
make_sparse_tensor
(
F
.
sparse_tensor_get_indices
(
sparse_tensor
),
dout
,
F
.
sparse_tensor_get_dense_shape
(
sparse_tensor
))
@
bprops
.
register
(
"SparseTensorGetDenseShape"
)
def
bprop_sparse_tensor_get_dense_shape
(
sparse_tensor
,
out
,
dout
):
"""Backpropagator for primitive `SparseTensorGetDenseShape`."""
return
(
zeros_like
(
sparse_tensor
),)
@
bprop_getters
.
register
(
P
.
SparseToDense
)
def
get_bprop_sparse_to_dense
(
self
):
"""Generate bprop for SparseToDense"""
def
bprop
(
indices
,
values
,
dense_shape
,
out
,
dout
):
return
zeros_like
(
indices
),
dout
,
zeros_like
(
dense_shape
)
return
bprop
mindspore/ops/composite/multitype_ops/ones_like_impl.py
浏览文件 @
598bfa02
...
...
@@ -42,6 +42,16 @@ def _ones_like_tensor(x):
return
P
.
Fill
()(
P
.
DType
()(
x
),
P
.
Shape
()(
x
),
1.0
)
@
ones_like_leaf
.
register
(
"SparseTensor"
)
def
_ones_like_sparse_tensor
(
x
):
"""Returns a tensor with the same shape and dtype as x and all elements are 1."""
values_
=
F
.
sparse_tensor_get_values
(
x
)
values
=
P
.
Fill
()(
P
.
DType
()(
values_
),
P
.
Shape
()(
values_
),
1.0
)
return
F
.
make_sparse_tensor
(
F
.
sparse_tensor_get_indices
(
x
),
values
,
F
.
sparse_tensor_get_dense_shape
(
x
))
ones_like
=
base
.
HyperMap
(
ones_like_leaf
)
"""
`ones_like` is a function which can generate a graph of `ones_like` operation according to input tensor dtype.
...
...
mindspore/ops/operations/__init__.py
浏览文件 @
598bfa02
...
...
@@ -84,6 +84,7 @@ from ._quant_ops import *
from
.other_ops
import
(
Assign
,
IOU
,
BoundingBoxDecode
,
BoundingBoxEncode
,
PopulationCount
,
CheckValid
,
MakeRefKey
,
Partial
,
Depend
,
CheckBprop
,
Push
,
Pull
)
from
.thor_ops
import
*
from
.sparse_ops
import
SparseToDense
__all__
=
[
'ReverseSequence'
,
...
...
@@ -357,7 +358,8 @@ __all__ = [
"PopulationCount"
,
"ParallelConcat"
,
"Push"
,
"Pull"
"Pull"
,
'SparseToDense'
,
]
__all__
.
sort
()
mindspore/ops/operations/sparse_ops.py
0 → 100644
浏览文件 @
598bfa02
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for sparse operators."""
from
..._checkparam
import
Validator
as
validator
from
...common
import
dtype
as
mstype
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
class
SparseToDense
(
PrimitiveWithInfer
):
"""
Convert a sparse representation into a dense tensor.
Inputs:
- **indices** (Tensor) - The indices of sparse representation.
- **values** (Tensor) - Values corresponding to each row of indices.
- **dense_shape** (tuple) - A int tuple which specifies the shape of dense tensor.
Returns:
Tensor, the shape of tensor is dense_shape.
Examples:
>>> indices = Tensor([[0, 1], [1, 2]])
>>> values = Tensor([1, 2], dtype=ms.float32)
>>> dense_shape = (3, 4)
>>> out = P.SparseToDense()(indices, values, dense_shape)
"""
@
prim_attr_register
def
__init__
(
self
):
"""init index_select"""
self
.
init_prim_io_names
(
inputs
=
[
'indices'
,
'values'
,
'dense_shape'
],
outputs
=
[
'output'
])
def
__infer__
(
self
,
indices
,
values
,
dense_shape
):
validator
.
check_subclass
(
"indices"
,
indices
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
validator
.
check_subclass
(
"values"
,
values
[
'dtype'
],
mstype
.
tensor
,
self
.
name
)
out
=
{
'shape'
:
dense_shape
[
'value'
],
'dtype'
:
values
[
'dtype'
],
'value'
:
None
}
return
out
tests/ut/python/ir/test_sparse_tensor.py
浏览文件 @
598bfa02
...
...
@@ -28,6 +28,7 @@ from mindspore import Tensor, SparseTensor, context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
enable_sparse
=
True
)
grad_op
=
C
.
GradOperation
(
'get_all'
,
get_all
=
True
)
class
MakeSparseTensor
(
nn
.
Cell
):
def
__init__
(
self
,
dense_shape
):
...
...
@@ -45,15 +46,6 @@ def test_sparse_tensor_make_sparse_tensor():
def
test_sparse_tensor_attr
():
grad_op
=
C
.
GradOperation
(
'get_all'
,
get_all
=
True
)
class
GradWrap
(
nn
.
Cell
):
def
__init__
(
self
,
network
):
super
(
GradWrap
,
self
).
__init__
()
self
.
network
=
network
def
construct
(
self
,
input1
,
input2
):
gout
=
grad_op
(
self
.
network
)(
input1
,
input2
)
return
gout
class
SparseTensorGetAttr
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseTensorGetAttr
,
self
).
__init__
()
...
...
@@ -82,3 +74,20 @@ def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
dense_shape
=
(
2
,
2
,
2
)
with
pytest
.
raises
(
TypeError
):
MakeSparseTensor
(
dense_shape
)(
indices
,
values
)
def
test_sparse_tensor_to_tensor
():
class
SparseToDenseCell
(
nn
.
Cell
):
def
__init__
(
self
,
dense_shape
):
super
(
SparseToDenseCell
,
self
).
__init__
()
self
.
dense_shape
=
dense_shape
self
.
sparse_to_dense
=
nn
.
SparseToDense
()
def
construct
(
self
,
indices
,
values
):
sparse
=
SparseTensor
(
indices
,
values
,
self
.
dense_shape
)
return
self
.
sparse_to_dense
(
sparse
)
indices
=
Tensor
([[
0
,
1
],
[
1
,
2
]])
values
=
Tensor
([
1
,
2
],
dtype
=
ms
.
float32
)
dense_shape
=
(
3
,
4
)
SparseToDenseCell
(
dense_shape
)(
indices
,
values
)
grad_op
(
SparseToDenseCell
(
dense_shape
))(
indices
,
values
)
tests/ut/python/pipeline/parse/test_cell_bprop.py
浏览文件 @
598bfa02
...
...
@@ -102,7 +102,7 @@ def test_with_no_bprop():
with_no_bprop
=
WithNoBprop
()
x
=
Tensor
(
1
,
dtype
=
ms
.
int32
)
y
=
Tensor
(
2
,
dtype
=
ms
.
int32
)
assert
C
.
grad_all
(
with_no_bprop
)(
x
,
y
)
==
(
2
,
1
)
C
.
grad_all
(
with_no_bprop
)(
x
,
y
)
def
test_grad_in_bprop_1
():
...
...
@@ -263,10 +263,7 @@ def test_grad_inline_bprop_two_input():
net
=
InlineBpropTwoInput
()
input1
=
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
))
input2
=
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
))
grads
=
C
.
grad_all
(
net
)(
input1
,
input2
)
assert
(
grads
[
0
].
asnumpy
()
==
np
.
array
([
2
,
2
]).
astype
(
np
.
float32
)).
all
()
assert
(
grads
[
1
].
asnumpy
()
==
np
.
array
([
2
,
2
]).
astype
(
np
.
float32
)).
all
()
assert
len
(
grads
)
==
2
C
.
grad_all
(
net
)(
input1
,
input2
)
class
TwoInputBprop
(
nn
.
Cell
):
...
...
@@ -350,24 +347,6 @@ def test_refkey_bprop():
assert
(
grads
[
1
][
0
].
asnumpy
()
==
np
.
array
([
2
,
2
]).
astype
(
np
.
float32
)).
all
()
class
MulAddWithWrongOutputNum
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
MulAddWithWrongOutputNum
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
):
return
2
*
x
+
y
def
bprop
(
self
,
x
,
y
,
out
,
dout
):
return
(
2
*
dout
,)
def
test_grad_mul_add_with_wrong_output_num
():
context
.
set_context
(
check_bprop
=
True
)
mul_add
=
MulAddWithWrongOutputNum
()
with
pytest
.
raises
(
TypeError
):
C
.
grad_all
(
mul_add
)(
1
,
2
)
class
MulAddWithWrongOutputType
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
MulAddWithWrongOutputType
,
self
).
__init__
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录