Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4d4e23fd
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4d4e23fd
编写于
7月 20, 2020
作者:
P
panyifeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add bprop for sparse_tensor
上级
abcee8e5
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
214 addition
and
23 deletion
+214
-23
mindspore/ccsrc/debug/anf_ir_utils.cc
mindspore/ccsrc/debug/anf_ir_utils.cc
+3
-0
mindspore/ccsrc/frontend/operator/composite/composite.cc
mindspore/ccsrc/frontend/operator/composite/composite.cc
+41
-0
mindspore/ccsrc/frontend/operator/composite/composite.h
mindspore/ccsrc/frontend/operator/composite/composite.h
+10
-0
mindspore/ccsrc/frontend/operator/prim_others.cc
mindspore/ccsrc/frontend/operator/prim_others.cc
+4
-0
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+8
-0
mindspore/ccsrc/frontend/optimizer/clean.cc
mindspore/ccsrc/frontend/optimizer/clean.cc
+45
-2
mindspore/ccsrc/frontend/optimizer/clean.h
mindspore/ccsrc/frontend/optimizer/clean.h
+1
-0
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+30
-8
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+11
-3
mindspore/ops/_grad/grad_implementations.py
mindspore/ops/_grad/grad_implementations.py
+5
-0
mindspore/ops/_grad/grad_math_ops.py
mindspore/ops/_grad/grad_math_ops.py
+8
-0
tests/st/ops/ascend/test_addn.py
tests/st/ops/ascend/test_addn.py
+15
-0
tests/ut/python/ir/test_indexed_slices.py
tests/ut/python/ir/test_indexed_slices.py
+2
-2
tests/ut/python/ir/test_sparse_tensor.py
tests/ut/python/ir/test_sparse_tensor.py
+31
-8
未找到文件。
mindspore/ccsrc/debug/anf_ir_utils.cc
浏览文件 @
4d4e23fd
...
...
@@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
* │ └── MapPy
* ├── Tail
* ├── MakeTupleGradient
* ├── MakeListGradient
* ├── GradOperation
* └── TupleAdd
*/
...
...
@@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
// do nothing
}
else
if
(
meta_func_graph
->
isa
<
prim
::
MakeTupleGradient
>
())
{
// do nothing
}
else
if
(
meta_func_graph
->
isa
<
prim
::
MakeListGradient
>
())
{
// do nothing
}
else
if
(
meta_func_graph
->
isa
<
prim
::
TupleAdd
>
())
{
// do nothing
}
else
if
(
meta_func_graph
->
isa
<
prim
::
TupleSlice
>
())
{
...
...
mindspore/ccsrc/frontend/operator/composite/composite.cc
浏览文件 @
4d4e23fd
...
...
@@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
return
fg
;
}
FuncGraphPtr
MakeListGradient
::
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
{
int
list_size
=
SizeToInt
(
args_spec_list
.
size
());
std
::
ostringstream
ss
;
ss
<<
"▶make_list_"
<<
list_size
;
FuncGraphPtr
fg
=
std
::
make_shared
<
FuncGraph
>
();
fg
->
debug_info
()
->
set_name
(
ss
.
str
());
std
::
vector
<
AnfNodePtr
>
params
;
params
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeList
));
for
(
int
i
=
0
;
i
<
list_size
;
++
i
)
{
params
.
push_back
(
fg
->
add_parameter
());
}
// make fprob first result, maketuple's forward result.
AnfNodePtr
out
=
fg
->
NewCNode
(
params
);
// make fprob second result, maketuple's backward function.
FuncGraphPtr
b
=
std
::
make_shared
<
FuncGraph
>
();
ss
.
clear
();
ss
<<
"◀make_list_"
<<
list_size
;
b
->
debug_info
()
->
set_name
(
ss
.
str
());
AnfNodePtr
dout
=
b
->
add_parameter
();
std
::
vector
<
AnfNodePtr
>
grads
;
grads
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
grads
.
push_back
(
NewValueNode
(
newenv
));
for
(
int
i
=
0
;
i
<
list_size
;
++
i
)
{
grads
.
push_back
(
b
->
NewCNode
({
NewValueNode
(
prim
::
kPrimListGetItem
),
dout
,
NewValueNode
(
i
)}));
}
b
->
set_flag
(
FUNC_GRAPH_FLAG_CORE
,
true
);
b
->
set_output
(
b
->
NewCNode
(
grads
));
fg
->
set_flag
(
FUNC_GRAPH_FLAG_CORE
,
true
);
fg
->
set_output
(
fg
->
NewCNode
({
NewValueNode
(
prim
::
kPrimMakeTuple
),
out
,
NewValueNode
(
b
)}));
(
void
)
fg
->
transforms
().
emplace
(
"primal"
,
FuncGraphTransform
(
prim
::
kPrimMakeList
));
return
fg
;
}
GradOperation
::
GradOperation
(
const
std
::
string
&
name
,
bool
get_all
,
bool
get_by_list
,
bool
sens_param
)
:
MetaFuncGraph
(
name
),
get_all_
(
get_all
),
get_by_list_
(
get_by_list
),
sens_param_
(
sens_param
)
{
if
(
get_by_list
)
{
...
...
mindspore/ccsrc/frontend/operator/composite/composite.h
浏览文件 @
4d4e23fd
...
...
@@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph {
};
using
MakeTupleGradientPtr
=
std
::
shared_ptr
<
MakeTupleGradient
>
;
class
MakeListGradient
:
public
MetaFuncGraph
{
public:
explicit
MakeListGradient
(
const
std
::
string
&
name
)
:
MetaFuncGraph
(
name
)
{}
~
MakeListGradient
()
override
=
default
;
MS_DECLARE_PARENT
(
MakeListGradient
,
MetaFuncGraph
)
FuncGraphPtr
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
override
;
friend
bool
operator
==
(
const
MakeListGradient
&
lhs
,
const
MakeListGradient
&
rhs
)
{
return
lhs
.
name_
==
rhs
.
name_
;
}
};
using
MakeListGradientPtr
=
std
::
shared_ptr
<
MakeListGradient
>
;
class
GradOperation
:
public
MetaFuncGraph
{
public:
explicit
GradOperation
(
const
std
::
string
&
name
,
bool
get_all
=
false
,
bool
get_by_list
=
false
,
...
...
mindspore/ccsrc/frontend/operator/prim_others.cc
浏览文件 @
4d4e23fd
...
...
@@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
auto
elem
=
GetValue
<
int
>
(
e
);
return
elem
;
});
if
(
IntToSize
(
indices_shp
[
1
])
!=
dense_shape_vec
.
size
())
{
MS_EXCEPTION
(
TypeError
)
<<
"The size of dense_shape must be equal with the second dimension of indices "
<<
indices_shp
[
1
]
<<
", but got "
<<
dense_shape_vec
.
size
();
}
for
(
auto
dense_shape_elem
:
dense_shape_vec
)
{
if
(
dense_shape_elem
<
0
)
{
MS_EXCEPTION
(
TypeError
)
<<
"The element of dense_shape must be positive, but got "
...
...
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
浏览文件 @
4d4e23fd
...
...
@@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
return
meta
;
}
if
(
prim
->
Hash
()
==
prim
::
kPrimMakeList
->
Hash
()
&&
prim
->
name
()
==
prim
::
kPrimMakeList
->
name
())
{
MetaFuncGraphPtr
meta
=
std
::
make_shared
<
prim
::
MakeListGradient
>
(
"make_list_gradient"
);
bprop_registry_meta_
[
prim
::
kPrimMakeList
]
=
meta
;
return
meta
;
}
MS_LOG
(
EXCEPTION
)
<<
"Fail to find bprop function for "
<<
prim
->
name
()
<<
"."
;
}
...
...
@@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
return
fprop
;
}
else
if
(
prim
->
Hash
()
==
prim
::
kPrimMakeTuple
->
Hash
()
&&
prim
->
name
()
==
prim
::
kPrimMakeTuple
->
name
())
{
return
nullptr
;
}
else
if
(
prim
->
Hash
()
==
prim
::
kPrimMakeList
->
Hash
()
&&
prim
->
name
()
==
prim
::
kPrimMakeList
->
name
())
{
return
nullptr
;
}
FuncGraphPtr
bprop_fg
=
nullptr
;
...
...
mindspore/ccsrc/frontend/optimizer/clean.cc
浏览文件 @
4d4e23fd
...
...
@@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
[](
const
AbstractAttribute
&
item
)
{
return
item
.
second
;
});
return
std
::
make_shared
<
AbstractTuple
>
(
baselist
);
}
return
nullptr
;
}
static
AbstractBasePtr
AdaptAbs
(
const
AbstractBasePtr
&
t
)
{
if
(
t
==
nullptr
)
{
return
nullptr
;
}
if
(
t
->
isa
<
AbstractList
>
())
{
auto
abs_list
=
dyn_cast
<
AbstractList
>
(
t
);
return
std
::
make_shared
<
AbstractTuple
>
(
abs_list
->
elements
());
...
...
@@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
new_node
=
EraseMakeKeywordArgNode
(
cnode
);
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimExtractKeywordArg
))
{
new_node
=
EraseExtractKeywordArg
(
cnode
);
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimMakeList
))
{
}
if
(
new_node
!=
nullptr
)
{
new_node
->
set_abstract
(
node
->
abstract
());
MS_LOG
(
DEBUG
)
<<
"Replace node: "
<<
node
->
DebugString
()
<<
" with new_node: "
<<
new_node
->
DebugString
();
(
void
)
manager
->
Replace
(
node
,
new_node
);
changed
=
true
;
}
}
for
(
auto
&
node
:
manager
->
all_nodes
())
{
auto
ret
=
Reabs
(
node
->
abstract
());
if
(
ret
)
{
MS_LOG
(
DEBUG
)
<<
"Replace "
<<
node
->
DebugString
()
<<
"'s abstract "
<<
node
->
abstract
()
->
ToString
()
<<
" with "
<<
ret
->
ToString
();
node
->
set_abstract
(
ret
);
changed
=
true
;
}
}
return
changed
;
}
bool
CleanList
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
)
{
MS_EXCEPTION_IF_NULL
(
manager
);
manager
->
AddFuncGraph
(
root
);
bool
changed
=
false
;
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet
all_node
=
manager
->
all_nodes
();
for
(
auto
&
node
:
all_node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
AnfNodePtr
new_node
=
nullptr
;
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimMakeList
))
{
new_node
=
ConvertMakeListToMakeTuple
(
cnode
);
}
else
if
(
IsPrimitiveCNode
(
node
,
prim
::
kPrimListGetItem
))
{
new_node
=
ConvertListGetItemToTupleGetItem
(
cnode
);
...
...
@@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
}
for
(
auto
&
node
:
manager
->
all_nodes
())
{
auto
ret
=
Rea
bs
(
node
->
abstract
());
auto
ret
=
AdaptA
bs
(
node
->
abstract
());
if
(
ret
)
{
MS_LOG
(
DEBUG
)
<<
"Replace "
<<
node
->
DebugString
()
<<
"'s abstract "
<<
node
->
abstract
()
->
ToString
()
<<
" with "
<<
ret
->
ToString
();
...
...
mindspore/ccsrc/frontend/optimizer/clean.h
浏览文件 @
4d4e23fd
...
...
@@ -32,6 +32,7 @@ namespace opt {
// Remove the class type from graphs
bool
SimplifyDataStructures
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
);
bool
CleanList
(
const
FuncGraphPtr
&
root
,
const
FuncGraphManagerPtr
&
manager
);
// Remove most uses of tuples from the graph
// tuples that are returned will be kept
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
4d4e23fd
...
...
@@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
return
true
;
}
bool
CleanListPass
(
const
ResourcePtr
&
res
)
{
MS_EXCEPTION_IF_NULL
(
res
->
func_graph
());
FuncGraphPtr
func_graph
=
res
->
func_graph
();
bool
changed
=
opt
::
CleanList
(
func_graph
,
res
->
manager
());
abstract
::
AbstractBasePtrList
args_spec
;
auto
parameters
=
func_graph
->
parameters
();
(
void
)
std
::
transform
(
parameters
.
begin
(),
parameters
.
end
(),
std
::
back_inserter
(
args_spec
),
[](
const
AnfNodePtr
&
p
)
->
AbstractBasePtr
{
return
p
->
abstract
();
});
if
(
changed
)
{
FuncGraphPtr
new_fg
=
Renormalize
(
res
,
func_graph
,
args_spec
);
res
->
set_func_graph
(
new_fg
);
}
res
->
set_args_spec
(
args_spec
);
return
true
;
}
namespace
{
OptPassGroupMap
GetOptPassesA
(
const
opt
::
irpass
::
OptimizeIRPassLib
&
irpass
)
{
opt
::
OptPassConfig
a_1
=
opt
::
OptPassConfig
({
...
...
@@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
// Safe inlining
irpass
.
inline_
,
irpass
.
sparse_tensor_eliminate_
,
});
opt
::
OptPassConfig
a_2
=
opt
::
OptPassConfig
({
irpass
.
merge_addn_
,
...
...
@@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass
.
make_ref_eliminate_
,
irpass
.
get_ref_param_eliminate_
,
irpass
.
indexed_slices_eliminate_
,
irpass
.
sparse_tensor_eliminate_
,
});
OptPassGroupMap
map
({
{
"b_1"
,
b_1
},
...
...
@@ -322,18 +340,22 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
return
true
;
}
std
::
vector
<
PassItem
>
kVmPasses
=
{{
"opt_a"
,
OptPassAGroup
},
{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
std
::
vector
<
PassItem
>
kVmPasses
=
{{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"opt_a"
,
OptPassAGroup
},
{
"clean_list"
,
CleanListPass
},
{
"opt_b"
,
OptPassBGroup
},
{
"cconv"
,
CconvPass
},
{
"opt_graph_kernel_a"
,
OptPassGraphKernelGroupA
},
{
"opt_graph_kernel_b"
,
OptPassGraphKernelGroupB
},
{
"add_control_depend"
,
AddControlDependPass
}};
std
::
vector
<
PassItem
>
kGePasses
=
{
{
"opt_a"
,
OptPassAGroup
},
{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"opt_b"
,
OptPassBGroup
},
{
"add_control_depend"
,
AddControlDependPass
},
{
"opt_control"
,
ControlGroup
},
{
"opt_prepare"
,
PrepareGroup
},
std
::
vector
<
PassItem
>
kGePasses
=
{{
"simplify_data_structures"
,
SimplifyDataStructuresPass
},
{
"opt_a"
,
OptPassAGroup
},
{
"clean_list"
,
CleanListPass
},
{
"opt_b"
,
OptPassBGroup
},
{
"add_control_depend"
,
AddControlDependPass
},
{
"opt_control"
,
ControlGroup
},
{
"opt_prepare"
,
PrepareGroup
},
{
"cconv"
,
CconvPass
}};
std
::
vector
<
PassItem
>
kPynativePasses
=
{{
"opt_a"
,
OptPassAGroup
},
{
"opt_b"
,
OptPassBGroup
},
{
"cconv"
,
CconvPass
}};
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
浏览文件 @
4d4e23fd
...
...
@@ -22,6 +22,7 @@
#include "ir/func_graph_cloner.h"
#include "abstract/utils.h"
#include "debug/trace.h"
#include "utils/context/ms_context.h"
namespace
mindspore
{
namespace
abstract
{
...
...
@@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
AbstractBasePtrList
bparams
;
bparams
.
push_back
(
SensitivityTransform
(
orig_func_
));
(
void
)
std
::
transform
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
std
::
back_inserter
(
bparams
),
[](
const
AbstractBasePtr
&
arg_spec
)
->
AbstractBasePtr
{
return
SensitivityTransform
(
arg_spec
);
});
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
enable_sparse
=
context
->
enable_sparse
();
(
void
)
std
::
transform
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
std
::
back_inserter
(
bparams
),
[
&
enable_sparse
](
const
AbstractBasePtr
&
arg_spec
)
->
AbstractBasePtr
{
if
(
enable_sparse
&&
arg_spec
->
isa
<
AbstractTensor
>
())
{
return
std
::
make_shared
<
AbstractUndetermined
>
();
}
return
SensitivityTransform
(
arg_spec
);
});
AbstractBasePtr
bparams_final
=
std
::
make_shared
<
AbstractTuple
>
(
bparams
);
AbstractFunctionPtr
bprop
=
std
::
make_shared
<
VirtualAbstractClosure
>
(
SensitivityTransform
(
result
->
abstract
()),
bparams_final
);
...
...
mindspore/ops/_grad/grad_implementations.py
浏览文件 @
4d4e23fd
...
...
@@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout):
"""Backpropagator for primitive `tuple_getitem`."""
return
F
.
tuple_setitem
(
C
.
zeros_like
(
data
),
idx
,
dout
),
C
.
zeros_like
(
idx
)
@
bprops
.
register
(
"list_getitem"
)
def
bprop_list_getitem
(
data
,
idx
,
out
,
dout
):
"""Backpropagator for primitive `list_getitem`."""
return
F
.
list_setitem
(
C
.
zeros_like
(
data
),
idx
,
dout
),
C
.
zeros_like
(
idx
)
@
bprops
.
register
(
"identity"
)
def
bprop_identity
(
x
,
out
,
dout
):
...
...
mindspore/ops/_grad/grad_math_ops.py
浏览文件 @
4d4e23fd
...
...
@@ -17,6 +17,7 @@
from
functools
import
reduce
import
numpy
as
np
import
mindspore
as
ms
from
mindspore.ops
import
_selected_grad_ops
as
SG
from
..
import
functional
as
F
from
..
import
operations
as
P
...
...
@@ -33,6 +34,7 @@ shape_op = P.Shape()
reduce_sum
=
P
.
ReduceSum
()
reshape
=
P
.
Reshape
()
tile
=
P
.
Tile
()
is_sub_class
=
P
.
IsSubClass
()
def
binop_grad_common
(
x
,
y
,
dx
,
dy
):
...
...
@@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self):
"""Generate bprop for AddN"""
def
bprop
(
x
,
out
,
dout
):
if
is_sub_class
(
F
.
typeof
(
x
),
ms
.
list_
):
dx
=
[]
for
_
in
range
(
len
(
x
)):
dx
.
append
(
dout
)
return
(
dx
,)
dx
=
()
for
_
in
range
(
len
(
x
)):
dx
=
dx
+
(
dout
,)
...
...
tests/st/ops/ascend/test_addn.py
浏览文件 @
4d4e23fd
...
...
@@ -16,6 +16,7 @@ import numpy as np
import
mindspore.context
as
context
import
mindspore.nn
as
nn
import
mindspore.ops.composite
as
C
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
...
...
@@ -45,3 +46,17 @@ def test_net():
add
=
Net
()
output
=
add
(
x
,
y
)
assert
output
==
expect
def
test_grad_addn_with_list
():
grad_op
=
C
.
GradOperation
(
'get_all'
,
get_all
=
True
)
class
AddN
(
nn
.
Cell
):
def
__init__
(
self
):
super
().
__init__
()
self
.
add_n
=
P
.
AddN
()
def
construct
(
self
,
a
,
b
):
return
self
.
add_n
([
a
,
b
])
inp
=
Tensor
(
np
.
ones
([
128
,
96
]).
astype
(
np
.
float32
))
grad_op
(
AddN
())(
inp
,
inp
)
tests/ut/python/ir/test_indexed_slices.py
浏览文件 @
4d4e23fd
...
...
@@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
self
.
network
=
network
def
construct
(
self
,
x
,
y
):
grad
=
grad_all
(
self
.
network
)(
x
,
y
)
return
grad
,
grad
[
0
],
grad
[
1
]
return
grad
[
0
].
indices
(),
grad
[
0
].
values
(),
grad
[
0
].
dense_shape
()
class
SparseGatherV2
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseGatherV2
,
self
).
__init__
()
...
...
@@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
weights
=
self
.
weights
grad
=
grad_by_list
(
self
.
network
,
weights
)(
x
)
x
=
grad
[
0
]
return
x
,
x
.
values
(),
x
.
indices
(),
x
.
dense_shape
()
return
x
.
values
(),
x
.
indices
(),
x
.
dense_shape
()
class
SparseGatherV2
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseGatherV2
,
self
).
__init__
()
...
...
tests/ut/python/ir/test_sparse_tensor.py
浏览文件 @
4d4e23fd
...
...
@@ -18,6 +18,9 @@
@Date : 2020-07-16
@Desc : test mindspore sparse_tensor's operation
"""
import
numpy
as
np
import
pytest
import
mindspore
as
ms
import
mindspore.nn
as
nn
from
mindspore.ops
import
composite
as
C
...
...
@@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
enable_sparse
=
True
)
def
test_sparse_tensor_make_sparse_tensor
():
class
MakeSparseTensor
(
nn
.
Cell
):
def
__init__
(
self
):
class
MakeSparseTensor
(
nn
.
Cell
):
def
__init__
(
self
,
dense_shape
):
super
(
MakeSparseTensor
,
self
).
__init__
()
self
.
dense_shape
=
(
3
,
4
)
self
.
dense_shape
=
dense_shape
def
construct
(
self
,
indices
,
values
):
ret
=
(
SparseTensor
(
indices
,
values
,
self
.
dense_shape
),)
return
ret
[
0
]
def
test_sparse_tensor_make_sparse_tensor
():
indices
=
Tensor
([[
0
,
1
],
[
1
,
2
]])
values
=
Tensor
([
1
,
2
],
dtype
=
ms
.
float32
)
MakeSparseTensor
()(
indices
,
values
)
MakeSparseTensor
(
(
3
,
4
)
)(
indices
,
values
)
def
test_sparse_tensor_attr
():
...
...
@@ -59,3 +65,20 @@ def test_sparse_tensor_attr():
indices
=
Tensor
([[
0
,
1
],
[
1
,
2
]])
values
=
Tensor
([
1
,
2
],
dtype
=
ms
.
float32
)
SparseTensorGetAttr
()(
indices
,
values
)
grad_op
(
SparseTensorGetAttr
())(
indices
,
values
)
def
test_sparse_tensor_indices_dim_greater_than_dense_shape_dim
():
indices
=
Tensor
(
np
.
array
([[
0
,
0
,
0
],
[
0
,
0
,
1
]],
dtype
=
np
.
int32
))
values
=
Tensor
(
np
.
array
([
100
,
200
],
dtype
=
np
.
float32
))
dense_shape
=
(
2
,
2
)
with
pytest
.
raises
(
TypeError
):
MakeSparseTensor
(
dense_shape
)(
indices
,
values
)
def
test_sparse_tensor_indices_dim_less_than_dense_shape_dim
():
indices
=
Tensor
(
np
.
array
([[
0
,
0
],
[
0
,
1
]],
dtype
=
np
.
int32
))
values
=
Tensor
(
np
.
array
([
100
,
200
],
dtype
=
np
.
float32
))
dense_shape
=
(
2
,
2
,
2
)
with
pytest
.
raises
(
TypeError
):
MakeSparseTensor
(
dense_shape
)(
indices
,
values
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录