magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 86889c59
Authored on July 15, 2020 by wangnan39@huawei.com

optimizer adapt IndexedSlices

Parent: 4dc96564
Showing 10 changed files with 73 additions and 233 deletions (+73, −233).
mindspore/nn/optim/adam.py (+11, −7)
mindspore/nn/optim/ftrl.py (+11, −7)
mindspore/nn/optim/lazyadam.py (+3, −3)
mindspore/nn/optim/optimizer.py (+8, −6)
mindspore/nn/optim/proximal_ada_grad.py (+3, −2)
mindspore/nn/wrap/grad_reducer.py (+16, −15)
mindspore/ops/_grad/grad_array_ops.py (+3, −2)
mindspore/ops/_grad/grad_comm_ops.py (+13, −12)
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py (+0, −174)
tests/ut/python/parallel/test_sparse_feature_bprop.py (+5, −5)
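The common thread across all ten files: a sparse gradient that previously traveled as a bare (indices, values, dense_shape) tuple is now wrapped in an IndexedSlices object, read back through its indices()/values()/dense_shape() accessors, and rebuilt with the IndexedSlices(indices, values, dense_shape) constructor. The sketch below shows that before/after pattern in isolation; ToyIndexedSlices is a hypothetical stand-in for mindspore.common.tensor.IndexedSlices so the example runs without MindSpore installed.

import numpy as np

class ToyIndexedSlices:
    """Hypothetical stand-in mirroring the accessors used in this commit."""
    def __init__(self, indices, values, dense_shape):
        self._indices, self._values, self._dense_shape = indices, values, dense_shape
    def indices(self):
        return self._indices
    def values(self):
        return self._values
    def dense_shape(self):
        return self._dense_shape

def scale_sparse_grad_old(grad_tuple, scale):
    # Old convention: grad is a plain (indices, values, dense_shape) tuple.
    indices, values, dense_shape = grad_tuple
    return indices, values * scale, dense_shape

def scale_sparse_grad_new(grad, scale):
    # New convention: grad is an IndexedSlices-like object.
    return ToyIndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())

indices = np.array([0, 2])
values = np.ones((2, 4), dtype=np.float32)
old = scale_sparse_grad_old((indices, values, (3, 4)), 0.5)
new = scale_sparse_grad_new(ToyIndexedSlices(indices, values, (3, 4)), 0.5)
assert np.allclose(old[1], new.values())   # both conventions scale the same rows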
mindspore/nn/optim/adam.py

@@ -108,24 +108,26 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
 
-@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
+@_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "IndexedSlices",
                     "Tensor", "Tensor", "Tensor", "Bool")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                          moment1, moment2, ps_parameter):
     """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
+    indices = gradient.indices()
+    values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("Adam", [0, 1, 2])
         shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
                   op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-                  op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
         success = F.depend(success, _ps_pull(_ps_push((beta1_power, beta2_power, lr, beta1, beta2,
-                                                       eps, gradient[1], gradient[0]), shapes), params))
+                                                       eps, values, indices), shapes), params))
     else:
         success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                               eps, gradient[1], gradient[0]))
+                                               eps, values, indices))
     return success

@@ -149,17 +151,19 @@ def _run_opt_with_one_number(opt, sparse_opt, beta1_power, beta2_power, beta1, b
 
-@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tuple", "Tensor", "Tensor", "Tensor")
+@_adam_push_pull_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor", "Tensor", "Tensor")
 def _run_push_pull_opt_with_sparse(push, pull, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                                    moment1, moment2):
     """Apply sparse adam optimizer by push and pull to the weight parameter when the gradient is sparse."""
     success = True
     op_shape = P.Shape()
+    values = gradient.values()
+    indices = gradient.indices()
     shapes = (op_shape(params), op_shape(moment1), op_shape(moment2),
               op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
-              op_shape(beta2), op_shape(eps), op_shape(gradient[1]), op_shape(gradient[0]))
+              op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
     success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient[1], gradient[0]), shapes), params))
+                                           eps, values, indices), shapes), params))
     return success
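In each optimizer the only change to the register() call is the type tag for the gradient argument: "Tuple" becomes "IndexedSlices", so the sparse overload is selected whenever the incoming gradient is an IndexedSlices rather than a dense Tensor. The toy dispatcher below is only a sketch of that dispatch-by-declared-type idea; it is not MindSpore's C.MultitypeFuncGraph, and the two stub classes are hypothetical stand-ins.

class Tensor:              # hypothetical dense stand-in
    pass

class IndexedSlices:       # hypothetical sparse stand-in with the accessor used below
    def __init__(self, indices):
        self._indices = indices
    def indices(self):
        return self._indices

class ToyMultitypeFuncGraph:
    """Toy dispatch table keyed on argument type names, loosely mimicking MultitypeFuncGraph."""
    def __init__(self, name):
        self.name = name
        self._overloads = {}
    def register(self, *type_names):
        def deco(fn):
            self._overloads[type_names] = fn
            return fn
        return deco
    def __call__(self, *args):
        key = tuple(type(a).__name__ for a in args)
        return self._overloads[key](*args)

_toy_opt = ToyMultitypeFuncGraph("toy_opt")

@_toy_opt.register("Tensor")
def _run_dense(grad):
    return "dense update"

@_toy_opt.register("IndexedSlices")
def _run_sparse(grad):
    return "sparse update on %d rows" % len(grad.indices())

print(_toy_opt(Tensor()))                    # dense update
print(_toy_opt(IndexedSlices([0, 2, 5])))    # sparse update on 3 rows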
mindspore/nn/optim/ftrl.py

@@ -25,20 +25,22 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
 _ftrl_push_pull_opt = C.MultitypeFuncGraph("ftrl_opt")
 
-@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple", "Tensor",
+@_ftrl_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices", "Tensor",
                     "Tensor", "Bool")
 def _tensor_run_opt_with_sparse(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment,
                                 ps_parameter):
     """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse."""
     success = True
+    indices = gradient.indices()
+    values = gradient.values()
     if ps_parameter:
         op_shape = P.Shape()
         _ps_pull = P.Pull()
         _ps_push = P.Push("Ftrl", [0, 1, 2])
-        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
-        success = F.depend(success, _ps_pull(_ps_push((gradient[1], gradient[0]), shapes), weight))
+        shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
+        success = F.depend(success, _ps_pull(_ps_push((values, indices), shapes), weight))
     else:
-        success = F.depend(success, spars_opt(weight, moment, linear, gradient[1], gradient[0]))
+        success = F.depend(success, spars_opt(weight, moment, linear, values, indices))
     return success

@@ -58,14 +60,16 @@ def _tensor_run_opt(opt, spars_opt, learning_rate, l1, l2, lr_power, linear, gra
     return success
 
-@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tuple",
+@_ftrl_push_pull_opt.register("Function", "Function", "Tensor", "Number", "Number", "Number", "Tensor", "IndexedSlices",
                               "Tensor", "Tensor")
 def _tensor_run_push_pull_opt_with_sparse(push, pull, learning_rate, l1, l2, lr_power, linear, gradient,
                                           weight, moment):
     success = True
     op_shape = P.Shape()
-    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(gradient[1]), op_shape(gradient[0]))
-    success = F.depend(success, pull(push((gradient[1], gradient[0]), shapes), weight))
+    values = gradient.values()
+    indices = gradient.indices()
+    shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices))
+    success = F.depend(success, pull(push((values, indices), shapes), weight))
     return success
mindspore/nn/optim/lazyadam.py

@@ -27,14 +27,14 @@ from .optimizer import Optimizer
 _lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
 
-@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tuple",
-                         "Tensor", "Tensor", "Tensor")
+@_lazy_adam_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor",
+                         "IndexedSlices", "Tensor", "Tensor", "Tensor")
 def _run_opt_with_sparse(opt, sparse_opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params,
                          moment1, moment2):
     """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
     success = True
     success = F.depend(success, sparse_opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
-                                           eps, gradient[1], gradient[0]))
+                                           eps, gradient.values(), gradient.indices()))
     return success
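The LazyAdam path now hands gradient.values() and gradient.indices() straight to the sparse_opt primitive; conceptually only the parameter rows named by the indices (and the matching moment rows) are touched, which is what makes the update lazy. Below is a rough NumPy sketch of that row-wise update under the standard Adam formulas — an illustration only, with bias correction via beta1_power/beta2_power omitted; the real work is done by the fused sparse kernel passed in as sparse_opt.

import numpy as np

def lazy_adam_rows(param, m, v, indices, values, lr=0.01,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    """Update only the rows listed in `indices` (illustrative sketch)."""
    for i, g in zip(indices, values):
        m[i] = beta1 * m[i] + (1 - beta1) * g
        v[i] = beta2 * v[i] + (1 - beta2) * g * g
        param[i] -= lr * m[i] / (np.sqrt(v[i]) + eps)   # bias correction omitted for brevity
    return param, m, v

param = np.ones((4, 2), dtype=np.float32)
m = np.zeros_like(param)
v = np.zeros_like(param)
indices = np.array([0, 3])
values = np.full((2, 2), 0.5, dtype=np.float32)
param, m, v = lazy_adam_rows(param, m, v, indices, values)
# Rows 1 and 2 are untouched; only rows 0 and 3 moved.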
mindspore/nn/optim/optimizer.py

@@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.nn.cell import Cell
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
-from mindspore.common.tensor import Tensor
+from mindspore.common.tensor import Tensor, IndexedSlices
 import mindspore.common.dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel

@@ -490,12 +490,14 @@ op_gather = P.GatherV2()
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
 
-@_apply_decay.register("Number", "Bool", "Tensor", "Tuple")
+@_apply_decay.register("Number", "Bool", "Tensor", "IndexedSlices")
 def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        weight = op_gather(weight, gradient[0], 0)
-        return gradient[0], op_add((weight * weight_decay, gradient[1])), gradient[2]
+        indices = gradient.indices()
+        values = op_add((op_gather(weight, indices, 0) * weight_decay, gradient.values()))
+        shape = gradient.dense_shape()
+        return IndexedSlices(indices, values, shape)
     return gradient

@@ -518,9 +520,9 @@ def tensor_grad_scale(scale, grad):
     return grad * scale
 
-@_grad_scale.register("Number", "Tuple")
+@_grad_scale.register("Number", "IndexedSlices")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
-    return grad[0], grad[1] * scale, grad[2]
+    return IndexedSlices(grad.indices(), grad.values() * scale, grad.dense_shape())
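For weight decay, the new _tensor_apply_decay_with_sparse gathers only the weight rows addressed by the gradient's indices, scales them by weight_decay, adds them to the gradient values, and repacks the result as an IndexedSlices. The NumPy check below (hypothetical toy shapes) confirms that, for the touched rows, this row-wise rule matches the dense rule grad + weight_decay * weight.

import numpy as np

weight = np.arange(12, dtype=np.float32).reshape(4, 3)
indices = np.array([1, 3])
values = np.ones((2, 3), dtype=np.float32)   # sparse gradient rows
weight_decay = 0.1

# Sparse form: decay applied only to the gathered rows (what the new code does).
sparse_values = weight[indices] * weight_decay + values

# Dense reference: densify the gradient, apply decay everywhere, read back the touched rows.
dense_grad = np.zeros_like(weight)
dense_grad[indices] = values
dense_grad += weight_decay * weight
assert np.allclose(sparse_values, dense_grad[indices])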
mindspore/nn/optim/proximal_ada_grad.py

@@ -23,11 +23,12 @@ from .optimizer import Optimizer
 _proximal_ada_grad_opt = C.MultitypeFuncGraph("proximal_ada_grad_opt")
 
-@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "Tuple", "Tensor", "Tensor")
+@_proximal_ada_grad_opt.register("Function", "Function", "Tensor", "Tensor", "Tensor", "IndexedSlices", "Tensor",
+                                 "Tensor")
 def _tensor_run_opt_with_sparse(opt, sparse_opt, learning_rate, l1, l2, gradient, weight, accum):
     """Apply sparse proximal_ada_grad optimizer to the weight parameter."""
     success = True
-    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient[1], gradient[0]))
+    success = F.depend(success, sparse_opt(weight, accum, learning_rate, l1, l2, gradient.values(), gradient.indices()))
     return success
mindspore/nn/wrap/grad_reducer.py

@@ -16,6 +16,7 @@
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.communication.management import GlobalComm, get_group_size
+from mindspore.common.tensor import IndexedSlices
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather
 from mindspore.parallel._auto_parallel_context import auto_parallel_context

@@ -77,7 +78,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduc
     return grad
 
-@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
+@reduce_opt.register("Number", "Bool", "Function", "Bool", "IndexedSlices", "Function")
 def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
     """
     Apply allgather on gradient instead of allreduce for sparse feature.

@@ -88,21 +89,21 @@ def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, gr
         mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
         allgather (Primitive): The communication operator for sparse gradients.
         allreduce_filter (bool): When it is true, allgather would apply.
-        grad (tuple): The indices, gradient tensor and tensor_shape before operation.
+        grad (IndexedSlices): The gradient before operation.
         allreduce (Primitive): The communication operator for gradients.
 
     Returns:
-        Tuple, include indices, the gradient tensor and tensor_shape after operation.
+        IndexedSlices, the gradient after operation.
     """
     if allreduce_filter:
-        indices = allgather(grad[0])
-        dout = allgather(grad[1])
+        indices = allgather(grad.indices())
+        dout = allgather(grad.values())
         if mean:
-            degree = F.scalar_cast(degree, F.dtype(grad[1]))
+            degree = F.scalar_cast(degree, F.dtype(grad.values()))
             cast_op = P.Cast()
             mul_op = P.Mul()
             dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
-        grad = (indices, dout, grad[2])
+        grad = IndexedSlices(indices, dout, grad.dense_shape())
     return grad

@@ -123,18 +124,18 @@ def _tensors_get_datatype(grad):
     return F.dtype(grad)
 
-@_get_datatype.register("Tuple")
+@_get_datatype.register("IndexedSlices")
 def _tensors_get_datatype_with_sparse(grad):
     """
     Acquire gradient datatype.
 
     Args:
-        grad (Tuple): The gradient tensor before operation.
+        grad (IndexedSlices): The gradient before operation.
 
     Returns:
         mstype, the datatype of gradient.
     """
-    return F.dtype(grad[1])
+    return F.dtype(grad.values())
 
 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")

@@ -155,20 +156,20 @@ def _tensors_cast_datatype(datatype, grad):
     return F.cast(grad, datatype)
 
-@_cast_datatype.register("TypeType", "Tuple")
+@_cast_datatype.register("TypeType", "IndexedSlices")
 def _tensors_cast_datatype_with_sparse(datatype, grad):
     """
     Cast gradient to datatype.
 
     Args:
         datatype (mstype): the destination datatype of gradient.
-        grad (Tuple): The gradient tensor before operation.
+        grad (IndexedSlices): The gradient before operation.
 
     Returns:
-        Tuple, the gradient tuple after operation.
+        IndexedSlices, the gradient after operation.
     """
-    dout = F.cast(grad[1], datatype)
-    return (grad[0], dout, grad[2])
+    dout = F.cast(grad.values(), datatype)
+    return IndexedSlices(grad.indices(), dout, grad.dense_shape())
 
 class DistributedGradReducer(Cell):
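On the reducer side, a sparse gradient is no longer allreduced: its indices() and values() are each allgathered, the gathered values are scaled by 1/degree when mean is enabled, and the result is rewrapped with the original dense_shape(). The NumPy sketch below mimics that flow for two workers, with allgather simulated by concatenation (hypothetical helper and data; repeated indices are left as-is, exactly as concatenated per-device rows would be).

import numpy as np

def allgather_sparse(per_device_grads, dense_shape, mean=True):
    """Simulate the sparse branch: concatenate indices/values across devices."""
    degree = len(per_device_grads)
    indices = np.concatenate([g[0] for g in per_device_grads])
    values = np.concatenate([g[1] for g in per_device_grads])
    if mean:
        values = values * (1.0 / degree)
    return indices, values, dense_shape   # stands in for IndexedSlices(indices, values, dense_shape)

dev0 = (np.array([0, 2]), np.ones((2, 3), dtype=np.float32))
dev1 = (np.array([2, 5]), np.full((2, 3), 2.0, dtype=np.float32))
indices, values, shape = allgather_sparse([dev0, dev1], (8, 3))
# indices -> [0 2 2 5]; row 2 appears twice, once per contributing device.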
mindspore/ops/_grad/grad_array_ops.py

@@ -25,6 +25,7 @@ from .grad_base import bprop_getters
 from ..primitive import constexpr
 from ... import context
 from ...common import dtype as mstype
+from ...common.tensor import IndexedSlices
 
 reduce_sum = P.ReduceSum()
 unsorted_segment_sum = P.UnsortedSegmentSum()

@@ -206,7 +207,7 @@ def get_bprop_embedding_lookup(self):
             actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
             # Reshape the 'actual_dout' on device
             actual_dout = reshape_op(dout, actual_dout_shape_changed)
-        return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
+        return IndexedSlices(new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
     return bprop_sparse

@@ -335,7 +336,7 @@ def get_bprop_sparse_gather_v2(self):
            values_shape = indices_size + x_tail_shp
            values = reshape(dout, values_shape)
            indices = reshape(indices, indices_size)
-           return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
+           return IndexedSlices(indices, values, x_shp), zeros_like(indices), zeros_like(axis)
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
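Both bprops touched here return the gradient of a row gather, which is naturally row-sparse: the incoming dout rows belong at the gathered indices within the input's shape, so wrapping (indices, dout, x_shp) as an IndexedSlices loses nothing. The NumPy sketch below (toy shapes, axis 0) checks that densifying such a sparse gradient by scatter-add reproduces what a dense gather gradient would accumulate, including repeated indices.

import numpy as np

x_shape = (5, 3)
indices = np.array([4, 1, 4])                          # rows that were gathered (with a repeat)
dout = np.arange(9, dtype=np.float32).reshape(3, 3)    # gradient w.r.t. the gathered output

# The bprop's sparse result would be IndexedSlices(indices, dout, x_shape); kept as a tuple here.
sparse_grad = (indices, dout, x_shape)

# Densify by scatter-add: repeated indices accumulate, as a dense gather gradient would.
dense_grad = np.zeros(x_shape, dtype=np.float32)
np.add.at(dense_grad, indices, dout)

assert np.allclose(dense_grad[1], dout[1])             # row gathered once
assert np.allclose(dense_grad[4], dout[0] + dout[2])   # row gathered twice accumulates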
mindspore/ops/_grad/grad_comm_ops.py

@@ -17,6 +17,7 @@
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
+from ...common.tensor import IndexedSlices
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp,

@@ -46,9 +47,9 @@ def get_bprop_all_reduce(self):
             if F.issubclass_(F.typeof(dout), mstype.tensor):
                 dx = all_reduce_grad(dout)
             else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
-                dx = (indices, grad, dout[2])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
+                dx = IndexedSlices(indices, grad, dout.dense_shape())
             return (dx,)
     else:

@@ -59,12 +60,12 @@ def get_bprop_all_reduce(self):
                 z = cast(z, dtype(dx))
                 dx = mul(dx, z)
             else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
                 z = equal(x, out)
                 z = cast(z, dtype(grad))
                 grad = mul(grad, z)
-                dx = (indices, grad, dout[2])
+                dx = IndexedSlices(indices, grad, dout.dense_shape())
             return (dx,)
     return bprop

@@ -194,19 +195,19 @@ def get_bprop_mirror_operator(self):
                 num = F.scalar_cast(dev_num, F.dtype(dx))
                 dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
             else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
                 float_one = F.scalar_cast(1.0, F.dtype(grad))
                 num = F.scalar_cast(dev_num, F.dtype(grad))
                 grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
-                dx = (indices, grad, dout[2])
+                dx = (indices, grad, dout.dense_shape())
         else:
             if F.issubclass_(F.typeof(dout), mstype.tensor):
                 dx = all_reduce(dout)
             else:
-                indices = all_gather(dout[0])
-                grad = all_gather(dout[1])
-                dx = (indices, grad, dout[2])
+                indices = all_gather(dout.indices())
+                grad = all_gather(dout.values())
+                dx = (indices, grad, dout.dense_shape())
         return (dx,)
     return bprop
tests/ut/python/nn/optim/test_adam_with_tuple_grad.py (deleted, 100644 → 0)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test adam """
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Optimizer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel

context.set_context(enable_sparse=True)

adam_opt_for_map = C.MultitypeFuncGraph("adam_opt_for_map")


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()

    param_fp32 = op_cast(param, mstype.float32)
    m_fp32 = op_cast(m, mstype.float32)
    v_fp32 = op_cast(v, mstype.float32)
    gradient_fp32 = op_cast(gradient, mstype.float32)

    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                                            op_square(gradient_fp32))

    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        update = update + op_mul(weight_decay_tensor, param_fp32)

    update_with_lr = op_mul(lr, update)
    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v


@adam_opt_for_map.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                           "Tensor", "Tensor", "Tuple", "Bool")
def _update_run_op_sparse_for_map(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
    return gradient[2][2]


def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)


class AdamWeightDecaySparse(Optimizer):
    """
    Implements Adam algorithm weight decay fix.

    Args:
        params (list[Parameter]): A list of parameter, which will be updated. The element in `params`
            should be class mindspore.Parameter.
        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
                                                        Iterable or a Tensor and the dims of the Tensor is 1,
                                                        use dynamic learning rate, then the i-th step will
                                                        take the i-th value as the learning rate.
                                                        When the learning_rate is float or learning_rate is a Tensor
                                                        but the dims of the Tensor is 0, use fixed learning rate.
                                                        Other cases are not supported. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`,
          and might be in sparse format.

    Outputs:
        tuple[Parameter], the updated velocity value, the shape is the same as `params`.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecaySparse, self).__init__(learning_rate, params)
        if self.is_group:
            raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.")
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        self.decay_flag = tuple(decay_filter(x) for x in self.params)
        self.map = C.Map()

    def construct(self, gradients):
        lr = self.get_lr()
        updated_velocity = self.map(F.partial(adam_opt_for_map, self.beta1, self.beta2, self.eps, lr,
                                              self.weight_decay_tensor),
                                    self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity


def test_AdamWeightDecaySparse():
    """ test_AdamWeightDecaySparse """
    context.set_context(mode=context.GRAPH_MODE)

    class Loss(nn.Cell):
        def __init__(self):
            super(Loss, self).__init__()

        def construct(self, base, target):
            return base

    class NetWithSparseGatherV2(nn.Cell):
        def __init__(self):
            super(NetWithSparseGatherV2, self).__init__()
            self.w1 = Parameter(Tensor(np.ones([3, 1, 2]).astype(np.float32)), name="w1")
            self.w2 = Parameter(Tensor(np.ones([2, 1, 2]).astype(np.float32)), name="w2")
            self.gatherv2 = P.SparseGatherV2()
            self.axis = 0

        def construct(self, indices):
            return self.gatherv2(self.w1, indices, self.axis) * self.w2

    inputs = Tensor(np.array([0, 1]).astype(np.int32))
    label = Tensor(np.zeros([2, 1, 2]).astype(np.float32))
    net = NetWithSparseGatherV2()
    net.set_train()
    loss = Loss()
    optimizer = AdamWeightDecaySparse(net.trainable_params())

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
tests/ut/python/parallel/test_sparse_feature_bprop.py

@@ -19,8 +19,8 @@ import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
-from mindspore.common.tensor import Tensor
-from mindspore.ops import composite as C
+from mindspore.common.tensor import Tensor, IndexedSlices
+from mindspore.ops import composite as C, operations as P
 from mindspore.ops.operations.comm_ops import AllReduce, _MirrorOperator
 from mindspore.ops._grad.grad_base import bprop_getters
 from mindspore._checkparam import Validator as validator

@@ -65,7 +65,7 @@ def get_bprop_gather_v2(self):
     """Generate bprop for GatherV2"""
 
     def bprop(x, indices, axis, out, dout):
-        return (indices, dout, x), axis, out
+        return IndexedSlices(indices, dout, x), axis, out
 
     return bprop

@@ -78,7 +78,7 @@ def test_bprop_with_sparse_feature_allreduce():
             if shape is None:
                 shape = [8, 8]
             self.all_reduce = AllReduce()
-            self.gatherv2 = VirtualGatherV2()
+            self.gatherv2 = P.GatherV2()
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
             self.axis = axis

@@ -102,7 +102,7 @@ def test_bprop_with_sparse_feature_mirror():
             if shape is None:
                 shape = [8, 8]
             self.mirror = _MirrorOperator(group=HCCL_WORLD_COMM_GROUP)
-            self.gatherv2 = VirtualGatherV2()
+            self.gatherv2 = P.GatherV2()
             self.index = Tensor(np.ones(shape), dtype=ms.int32)
             self.axis = axis