magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 68bd5cf6
Authored on Jun 29, 2020 by wangnan39@huawei.com

    add cpu sparse optimizer ops with no return

Parent: 7604b66f

Showing 6 changed files with 221 additions and 12 deletions (+221 -12)
mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h                +12   -0
mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h    +13   -0
mindspore/nn/optim/ftrl.py                                               +2    -1
mindspore/nn/optim/proximal_ada_grad.py                                  +2    -1
mindspore/ops/operations/_inner_ops.py                                   +180  -0
mindspore/ops/operations/nn_ops.py                                       +12   -10
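The commit routes the sparse-gradient paths of the CPU FTRL and ProximalAdagrad optimizers through new internal "NoReturn" primitives that update their parameters in place and return (1,)-shaped placeholders instead of full updated tensors. For orientation, a minimal editor's sketch of how the updated nn.FTRL would be used; the network, the context call, and the constructor defaults are illustrative assumptions, not part of this commit:

# Editor's sketch, not part of the commit. Assumes the mid-2020 MindSpore Python API.
import mindspore.nn as nn
from mindspore import context

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

net = MyEmbeddingNet()  # hypothetical model whose embedding parameters receive sparse gradients
opt = nn.FTRL(net.trainable_params(), learning_rate=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
# Inside FTRL.construct, sparse gradients are now applied with inner.SparseApplyFtrlNoReturn,
# which updates var/accum/linear in place on the CPU (see the diffs below).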
mindspore/ccsrc/kernel/cpu/sparse_apply_ftrl_cpu_kernel.h

@@ -53,6 +53,18 @@ MS_REG_CPU_KERNEL(SparseApplyFtrl,
                     .AddOutputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   SparseApplyFtrlCPUKernel);
+
+MS_REG_CPU_KERNEL(SparseApplyFtrlNoReturn,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyFtrlCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
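The new registration only adds an operator name; both entries dispatch to the same SparseApplyFtrlCPUKernel. For reference, the per-row update that this kernel applies to each row selected by `indices` is the standard FTRL-proximal update (the kernel body is not part of this diff, so treat the exact form as an assumption following the usual ApplyFtrl definition):

\begin{aligned}
accum_{new} &= accum + grad^{2} \\
linear &\leftarrow linear + grad - \frac{accum_{new}^{-lr\_power} - accum^{-lr\_power}}{lr}\, var \\
var &\leftarrow \begin{cases}
  \dfrac{sign(linear)\, l1 - linear}{accum_{new}^{-lr\_power}/lr + 2\, l2} & \text{if } |linear| > l1 \\
  0 & \text{otherwise}
\end{cases} \\
accum &\leftarrow accum_{new}
\end{aligned}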
mindspore/ccsrc/kernel/cpu/sparse_apply_proximal_adagrad_cpu_kernel.h

@@ -51,6 +51,19 @@ MS_REG_CPU_KERNEL(SparseApplyProximalAdagrad,
                     .AddOutputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   SparseApplyProximalAdagradCPUKernel);
+
+MS_REG_CPU_KERNEL(SparseApplyProximalAdagradNoReturn,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeFloat32)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddOutputAttr(kNumberTypeFloat32)
+                    .AddOutputAttr(kNumberTypeFloat32),
+                  SparseApplyProximalAdagradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
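The two registrations encode different calling conventions: for FTRL, lr/l1/l2/lr_power are primitive attributes, so the kernel sees five tensor inputs (var, accum, linear, grad, indices), while for ProximalAdagrad the lr/l1/l2 values arrive as float32 tensors at call time, giving seven inputs. An editor's sketch of direct invocation matching the registered signatures; the shapes, values, and PyNative-style direct calls are illustrative assumptions, and the primitives themselves are the NoReturn classes added in _inner_ops.py further down:

# Editor's sketch, not part of the commit. Shapes and values are illustrative.
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter, context
from mindspore.ops.operations import _inner_ops as inner

context.set_context(device_target="CPU")

var = Parameter(Tensor(np.ones((3, 2), np.float32)), name="var")
accum = Parameter(Tensor(np.ones((3, 2), np.float32)), name="accum")
linear = Parameter(Tensor(np.zeros((3, 2), np.float32)), name="linear")
grad = Tensor(np.ones((2, 2), np.float32))
indices = Tensor(np.array([0, 2], np.int32))

# FTRL variant: hyper-parameters are constructor attributes -> 5 tensor inputs.
ftrl = inner.SparseApplyFtrlNoReturn(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
ftrl(var, accum, linear, grad, indices)   # returns three (1,)-shaped placeholders

# ProximalAdagrad variant: lr/l1/l2 are float32 Tensors at call time -> 7 tensor inputs.
prox = inner.SparseApplyProximalAdagradNoReturn()
prox(var, accum, Tensor(0.01, mstype.float32), Tensor(0.0, mstype.float32),
     Tensor(0.0, mstype.float32), grad, indices)   # returns two (1,)-shaped placeholders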
mindspore/nn/optim/ftrl.py

@@ -16,6 +16,7 @@
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common import Tensor
 import mindspore.common.dtype as mstype
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer, _apply_decay, _grad_scale

@@ -116,7 +117,7 @@ class FTRL(Optimizer):
         self.decay_tf = tuple((lambda: True)() for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
-        self.sparse_opt = P.SparseApplyFtrl(learning_rate, l1, l2, lr_power, use_locking=use_locking)
+        self.sparse_opt = inner.SparseApplyFtrlNoReturn(learning_rate, l1, l2, lr_power, use_locking=use_locking)
 
     def construct(self, grads):
         params = self.parameters
mindspore/nn/optim/proximal_ada_grad.py

@@ -16,6 +16,7 @@
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore.common import Tensor
 import mindspore.common.dtype as mstype
+from mindspore.ops.operations import _inner_ops as inner
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from .optimizer import Optimizer

@@ -99,7 +100,7 @@ class ProximalAdagrad(Optimizer):
         self.weight_decay = weight_decay
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyProximalAdagrad(use_locking=use_locking)
-        self.sparse_opt = P.SparseApplyProximalAdagrad(use_locking=use_locking)
+        self.sparse_opt = inner.SparseApplyProximalAdagradNoReturn(use_locking=use_locking)
 
     def construct(self, grads):
         params = self.parameters
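As with FTRL above, only the internal sparse_opt changes; the public optimizer interface stays the same. A short editor's sketch; nn.ProximalAdagrad's constructor arguments are an assumption based on the surrounding file (not restated in this hunk) and may differ:

# Editor's sketch, not part of the commit; `net` is the hypothetical model from the FTRL
# sketch earlier, and the keyword arguments below are assumed defaults.
import mindspore.nn as nn
from mindspore import context

context.set_context(device_target="CPU")
opt = nn.ProximalAdagrad(net.trainable_params(), learning_rate=0.01, l1=0.0, l2=0.0)
# Dense gradients still go through P.ApplyProximalAdagrad; sparse gradients now go through
# inner.SparseApplyProximalAdagradNoReturn and update var/accum in place.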
mindspore/ops/operations/_inner_ops.py

@@ -18,6 +18,9 @@
 from ..._checkparam import Rel
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
+from ..._c_expression import signature_dtype as sig_dtype
 from ..primitive import PrimitiveWithInfer, prim_attr_register

@@ -330,6 +333,183 @@ class EmbeddingLookup(PrimitiveWithInfer):
         return out
 
 
+class SparseApplyFtrlNoReturn(PrimitiveWithInfer):
+    """
+    Update relevant entries according to the FTRL-proximal scheme.
+
+    Args:
+        lr (float): The learning rate value, must be positive.
+        l1 (float): l1 regularization strength, must be greater than or equal to zero.
+        l2 (float): l2 regularization strength, must be greater than or equal to zero.
+        lr_power (float): Learning rate power controls how the learning rate decreases during training,
+            must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
+        use_locking (bool): Use locks for update operation if True. Default: False.
+
+    Inputs:
+        - **var** (Parameter): The variable to be updated. The data type must be float32.
+        - **accum** (Parameter): The accum to be updated, must be same type and shape as `var`.
+        - **linear** (Parameter): The linear to be updated, must be same type and shape as `var`.
+        - **grad** (Tensor): A tensor of the same type as `var`, for the gradient.
+        - **indices** (Tensor): A vector of indices into the first dimension of `var` and `accum`. The shape
+          of `indices` must be the same as `grad` in first dimension. The type must be int32.
+
+    Outputs:
+        Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
+
+        - **var** (Tensor) - A Tensor with shape (1,).
+        - **accum** (Tensor) - A Tensor with shape (1,).
+        - **linear** (Tensor) - A Tensor with shape (1,).
+
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Parameter
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class SparseApplyFtrlNet(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(SparseApplyFtrlNet, self).__init__()
+        >>>         self.sparse_apply_ftrl = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
+        >>>         self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
+        >>>         self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")
+        >>>
+        >>>     def construct(self, grad, indices):
+        >>>         out = self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)
+        >>>         return out
+        >>>
+        >>> net = SparseApplyFtrlNet()
+        >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
+        >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
+        >>> output = net(grad, indices)
+    """
+    __mindspore_signature__ = (
+        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
+    )
+
+    @prim_attr_register
+    def __init__(self, lr, l1, l2, lr_power, use_locking=False):
+        self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'],
+                                outputs=['output'])
+        validator.check_value_type("lr", lr, [float], self.name)
+        validator.check_value_type("l1", l1, [float], self.name)
+        validator.check_value_type("l2", l2, [float], self.name)
+        validator.check_value_type("lr_power", lr_power, [float], self.name)
+        self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name)
+        self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name)
+        self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name)
+        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.add_prim_attr('primitive_target', 'CPU')
+
+    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
+        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
+        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
+        if len(var_shape) > 1:
+            validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
+        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
+        validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
+        return [1], [1], [1]
+
+    def infer_dtype(self, var_dtype, accum_dtype, linear_dtype, grad_dtype, indices_dtype):
+        args = {"var_dtype": var_dtype, "accum_dtype": accum_dtype, "linear_dtype": linear_dtype,
+                "grad_dtype": grad_dtype}
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_tensor_type_same({"indices_dtype": indices_dtype}, [mstype.int32], self.name)
+        return var_dtype, accum_dtype, linear_dtype
+
+
+class SparseApplyProximalAdagradNoReturn(PrimitiveWithInfer):
+    r"""
+    Updates relevant entries according to the proximal adagrad algorithm.
+
+    .. math::
+            accum += grad * grad
+    .. math::
+            \text{prox_v} = var - lr * grad * \frac{1}{\sqrt{accum}}
+    .. math::
+            var = \frac{sign(\text{prox_v})}{1 + lr * l2} * \max(\left| \text{prox_v} \right| - lr * l1, 0)
+
+    Args:
+        use_locking (bool): If True, updating of the var and accum tensors will be protected. Default: False.
+
+    Inputs:
+        - **var** (Parameter) - Variable tensor to be updated. The data type must be float32.
+        - **accum** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
+        - **lr** (Tensor): The learning rate value. The data type must be float32.
+        - **l1** (Tensor): l1 regularization strength. The data type must be float32.
+        - **l2** (Tensor): l2 regularization strength. The data type must be float32.
+        - **grad** (Tensor) - A tensor of the same type as `var`, for the gradient. The data type must be float32.
+        - **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`. The data type
+          must be int32.
+
+    Outputs:
+        Tuple of 2 Tensor, this operator will update the input parameters directly, the outputs are useless.
+
+        - **var** (Tensor) - A Tensor with shape (1,).
+        - **accum** (Tensor) - A Tensor with shape (1,).
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor, Parameter
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagradV2()
+        >>>         self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
+        >>>         self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
+        >>>         self.lr = Tensor(0.01, mstype.float32)
+        >>>         self.l1 = Tensor(0.0, mstype.float32)
+        >>>         self.l2 = Tensor(0.0, mstype.float32)
+        >>>     def construct(self, grad, indices):
+        >>>         out = self.sparse_apply_proximal_adagrad(self.var, self.accum, self.lr, self.l1,
+        >>>                                                  self.l2, grad, indices)
+        >>>         return out
+        >>> net = Net()
+        >>> grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
+        >>> indices = Tensor(np.array([0, 1]).astype(np.int32))
+        >>> output = net(grad, indices)
+    """
+    __mindspore_signature__ = (
+        ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1)
+    )
+
+    @prim_attr_register
+    def __init__(self, use_locking=False):
+        self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'],
+                                outputs=['output'])
+        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)
+        self.add_prim_attr('primitive_target', 'CPU')
+
+    def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape):
+        validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
+        return [1], [1]
+
+    def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype):
+        args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype}
+        validator.check_tensor_type_same(args, [mstype.float32], self.name)
+        validator.check_scalar_or_tensor_type_same({"lr": lr_dtype}, [mstype.float32], self.name)
+        validator.check_scalar_or_tensor_type_same({"l1": l1_dtype}, [mstype.float32], self.name)
+        validator.check_scalar_or_tensor_type_same({"l2": l2_dtype}, [mstype.float32], self.name)
+        valid_types = [mstype.int16, mstype.int32, mstype.int64,
+                       mstype.uint16, mstype.uint32, mstype.uint64]
+        validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name)
+        return var_dtype, accum_dtype
+
+
 class LinSpace(PrimitiveWithInfer):
     r"""
     Generates values in an interval. And return the corresponding interpolation accroding to assist.
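Note that the Examples blocks above instantiate P.SparseApplyFtrlV2 and P.SparseApplyProximalAdagradV2 rather than the NoReturn classes they document. An editor's illustration of the FTRL example constructed from the class actually added here (otherwise identical to the docstring example; not part of the commit):

# Editor's illustration, not part of the commit.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops.operations import _inner_ops as inner

class SparseApplyFtrlNet(nn.Cell):
    def __init__(self):
        super(SparseApplyFtrlNet, self).__init__()
        self.sparse_apply_ftrl = inner.SparseApplyFtrlNoReturn(lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5)
        self.var = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.random.rand(3, 1, 2).astype(np.float32)), name="linear")

    def construct(self, grad, indices):
        # The returned tuple holds three (1,)-shaped placeholders; the real effect is the
        # in-place update of var / accum / linear.
        return self.sparse_apply_ftrl(self.var, self.accum, self.linear, grad, indices)

net = SparseApplyFtrlNet()
grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
indices = Tensor(np.array([0, 1]).astype(np.int32))
output = net(grad, indices)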
mindspore/ops/operations/nn_ops.py

@@ -2835,11 +2835,11 @@ class SparseApplyAdam(PrimitiveWithInfer):
         - **indices** (Tensor) - Gradient indices. With int32 data type.
 
     Outputs:
-        Tuple of 3 Tensor, the updated parameters.
+        Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
 
-        - **var** (Tensor) - The same shape and data type as `var`.
-        - **m** (Tensor) - The same shape and data type as `m`.
-        - **v** (Tensor) - The same shape and data type as `v`.
+        - **var** (Tensor) - A Tensor with shape (1,).
+        - **m** (Tensor) - A Tensor with shape (1,).
+        - **v** (Tensor) - A Tensor with shape (1,).
 
     Examples:
         >>> import numpy as np

@@ -2896,6 +2896,7 @@ class SparseApplyAdam(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
                                         'epsilon', 'grad', 'indices'], outputs=['var', 'm', 'v'])
+        self.add_prim_attr('primitive_target', 'CPU')
 
     def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
                     beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):

@@ -2907,7 +2908,7 @@ class SparseApplyAdam(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
                              f"grad_shape = indices_shape + var_shape[1:], but got var_shape: {var_shape}, "
                              f"indices_shape: {indices_shape}, grad_shape: {grad_shape}.")
-        return var_shape, m_shape, v_shape
+        return [1], [1], [1]
 
     def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                     beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):

@@ -2969,11 +2970,11 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
         - **indices** (Tensor) - Gradient indices. With int32 data type.
 
     Outputs:
-        Tuple of 3 Tensor, the updated parameters.
+        Tuple of 3 Tensor, this operator will update the input parameters directly, the outputs are useless.
 
-        - **var** (Tensor) - The same shape and data type as `var`.
-        - **m** (Tensor) - The same shape and data type as `m`.
-        - **v** (Tensor) - The same shape and data type as `v`.
+        - **var** (Tensor) - A Tensor with shape (1,).
+        - **m** (Tensor) - A Tensor with shape (1,).
+        - **v** (Tensor) - A Tensor with shape (1,).
 
     Examples:
         >>> import numpy as np

@@ -3030,6 +3031,7 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2',
                                         'epsilon', 'grad', 'indices'], outputs=['var', 'm', 'v'])
+        self.add_prim_attr('primitive_target', 'CPU')
 
     def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape,
                     beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape):

@@ -3041,7 +3043,7 @@ class SparseApplyLazyAdam(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}', the shape of updates should be [] or "
                              f"grad_shape = indices_shape + var_shape[1:], but got var_shape: {var_shape}, "
                              f"indices_shape: {indices_shape}, grad_shape: {grad_shape}.")
-        return var_shape, m_shape, v_shape
+        return [1], [1], [1]
 
     def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype,
                     beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype, indices_dtype):