Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b052ecf4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b052ecf4
编写于
7月 02, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 02, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2804 Add TBE operators SparseApplyFtrlV2\SparseApplyAdagradV2 for VM.
Merge pull request !2804 from liuxiao93/ops-for-VM
上级
68731921
ddedc416
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
316 addition
and
2 deletion
+316
-2
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
+2
-0
mindspore/ops/_op_impl/tbe/__init__.py
mindspore/ops/_op_impl/tbe/__init__.py
+2
-0
mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py
mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py
+48
-0
mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py
mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py
+52
-0
mindspore/ops/operations/__init__.py
mindspore/ops/operations/__init__.py
+3
-1
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+176
-1
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+33
-0
未找到文件。
mindspore/ccsrc/kernel/tbe/tbe_adapter.cc
浏览文件 @
b052ecf4
...
...
@@ -70,11 +70,13 @@ static std::map<string, string> tbe_func_adapter_map = {
{
"strided_slice"
,
"strided_slice_d"
},
{
"strided_slice_grad"
,
"strided_slice_grad_d"
},
{
"sparse_apply_ftrl"
,
"sparse_apply_ftrl_d"
},
{
"sparse_apply_ftrl_v2"
,
"sparse_apply_ftrl_v2_d"
},
{
"apply_ada_max"
,
"apply_ada_max_d"
},
{
"apply_adadelta"
,
"apply_adadelta_d"
},
{
"apply_adagrad"
,
"apply_adagrad_d"
},
{
"apply_adagrad_v2"
,
"apply_adagradv2_d"
},
{
"sparse_apply_adagrad"
,
"sparse_apply_adagrad_d"
},
{
"sparse_apply_adagrad_v2"
,
"sparse_apply_adagrad_v2_d"
},
{
"apply_proximal_adagrad"
,
"apply_proximal_adagrad_d"
},
{
"sparse_apply_proximal_adagrad"
,
"sparse_apply_proximal_adagrad_d"
},
{
"apply_add_sign"
,
"apply_add_sign_d"
},
...
...
mindspore/ops/_op_impl/tbe/__init__.py
浏览文件 @
b052ecf4
...
...
@@ -38,6 +38,8 @@ from .apply_add_sign import _apply_add_sign_tbe
from
.apply_power_sign
import
_apply_power_sign_tbe
from
.apply_gradient_descent
import
_apply_gradient_descent_tbe
from
.apply_proximal_gradient_descent
import
_apply_proximal_gradient_descent_tbe
from
.sparse_apply_ftrl_v2
import
_sparse_apply_ftrl_v2_tbe
from
.sparse_apply_adagrad_v2
import
_sparse_apply_adagrad_v2_tbe
from
.approximate_equal
import
_approximate_equal_tbe
from
.adam_apply_one
import
_adam_apply_one_tbe
from
.assign
import
_assign_tbe
...
...
mindspore/ops/_op_impl/tbe/sparse_apply_adagrad_v2.py
0 → 100644
浏览文件 @
b052ecf4
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyAdagradV2D op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
sparse_apply_adagrad_v2_d_op_info
=
TBERegOp
(
"SparseApplyAdagradV2"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"sparse_apply_adagrad_v2_d.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"sparse_apply_adagrad_v2_d"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"lr"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"epsilon"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"use_locking"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"update_slots"
,
"optional"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"accum"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"grad"
,
False
,
"required"
,
"all"
)
\
.
input
(
3
,
"indices"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"accum"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
I32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
)
\
.
dtype_format
(
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
I32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
@
op_info_register
(
sparse_apply_adagrad_v2_d_op_info
)
def
_sparse_apply_adagrad_v2_tbe
():
"""SparseApplyAdagradV2D TBE register"""
return
mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_v2.py
0 → 100644
浏览文件 @
b052ecf4
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SparseApplyFtrlV2D op"""
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
sparse_apply_ftrl_v2_d_op_info
=
TBERegOp
(
"SparseApplyFtrlV2"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"sparse_apply_ftrl_v2_d.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"sparse_apply_ftrl_v2_d"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"lr"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"l1"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"l2"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"l2_shrinkage"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"lr_power"
,
"required"
,
"float"
,
"all"
)
\
.
attr
(
"use_locking"
,
"optional"
,
"bool"
,
"true,false"
,
"false"
)
\
.
input
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
input
(
1
,
"accum"
,
False
,
"required"
,
"all"
)
\
.
input
(
2
,
"linear"
,
False
,
"required"
,
"all"
)
\
.
input
(
3
,
"grad"
,
False
,
"required"
,
"all"
)
\
.
input
(
4
,
"indices"
,
False
,
"required"
,
"all"
)
\
.
output
(
0
,
"var"
,
False
,
"required"
,
"all"
)
\
.
output
(
1
,
"accum"
,
False
,
"required"
,
"all"
)
\
.
output
(
2
,
"linear"
,
False
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
I32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
,
DataType
.
F32_NCHW
)
\
.
dtype_format
(
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
I32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
,
DataType
.
F32_NHWC
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
I32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
get_op_info
()
@
op_info_register
(
sparse_apply_ftrl_v2_d_op_info
)
def
_sparse_apply_ftrl_v2_tbe
():
"""SparseApplyFtrlV2D TBE register"""
return
mindspore/ops/operations/__init__.py
浏览文件 @
b052ecf4
...
...
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
SoftmaxCrossEntropyWithLogits
,
ROIAlign
,
SparseSoftmaxCrossEntropyWithLogits
,
Tanh
,
TopK
,
BinaryCrossEntropy
,
SparseApplyAdagrad
,
LARSUpdate
,
ApplyFtrl
,
SparseApplyFtrl
,
ApplyProximalAdagrad
,
SparseApplyProximalAdagrad
,
ApplyProximalAdagrad
,
SparseApplyProximalAdagrad
,
SparseApplyAdagradV2
,
SparseApplyFtrlV2
,
ApplyAdaMax
,
ApplyAdadelta
,
ApplyAdagrad
,
ApplyAdagradV2
,
ApplyAddSign
,
ApplyPowerSign
,
ApplyGradientDescent
,
ApplyProximalGradientDescent
,
ApplyRMSProp
,
ApplyCenteredRMSProp
,
BasicLSTMCell
,
InTopK
)
...
...
@@ -284,6 +284,7 @@ __all__ = [
"Abs"
,
"BinaryCrossEntropy"
,
"SparseApplyAdagrad"
,
"SparseApplyAdagradV2"
,
"SpaceToDepth"
,
"DepthToSpace"
,
"Conv2DBackpropInput"
,
...
...
@@ -294,6 +295,7 @@ __all__ = [
"ApplyFtrl"
,
"SpaceToBatch"
,
"SparseApplyFtrl"
,
"SparseApplyFtrlV2"
,
"ApplyProximalAdagrad"
,
"SparseApplyProximalAdagrad"
,
"ApplyAdaMax"
,
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
b052ecf4
...
...
@@ -3600,6 +3600,88 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
return
var_type
,
accum_type
class
SparseApplyAdagradV2
(
PrimitiveWithInfer
):
r
"""
Update relevant entries according to the adagrad scheme.
.. math::
accum += grad * grad
.. math::
var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
Args:
lr (float): Learning rate.
epsilon (float): A small value added for numerical stability.
use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False.
update_slots (bool): If `True`, the computation logic will be different to `False`. Default: True.
Inputs:
- **var** (Parameter) - Variable to be updated. The type must be float32.
- **accum** (Parameter) - Accum to be updated. The shape must be the same as `var`'s shape,
the type must be float32.
- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except first dimension,
the type must be float32.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
Outputs:
Tuple of 2 Tensor, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
Examples:
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, Parameter
>>> from mindspore.ops import operations as P
>>> import mindspore.common.dtype as mstype
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6)
>>> self.var = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.ones([3, 3, 3]).astype(np.float32)), name="accum")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
>>> return out
>>> net = Net()
>>> grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
>>> indices = Tensor([0, 1, 2], mstype.int32)
>>> result = net(grad, indices)
"""
__mindspore_signature__
=
(
(
'var'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'accum'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'grad'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T1
)
)
@
prim_attr_register
def
__init__
(
self
,
lr
,
epsilon
,
use_locking
=
False
,
update_slots
=
True
):
self
.
lr
=
validator
.
check_value_type
(
"lr"
,
lr
,
[
float
],
self
.
name
)
self
.
epsilon
=
validator
.
check_value_type
(
"epsilon"
,
epsilon
,
[
float
],
self
.
name
)
self
.
use_locking
=
validator
.
check_value_type
(
"update_slots"
,
update_slots
,
[
bool
],
self
.
name
)
self
.
update_slots
=
validator
.
check_value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
name
)
def
infer_shape
(
self
,
var_shape
,
accum_shape
,
grad_shape
,
indices_shape
):
validator
.
check
(
'var shape'
,
var_shape
,
'accum shape'
,
accum_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'len of var shape'
,
len
(
var_shape
),
'len of grad shape'
,
len
(
grad_shape
),
Rel
.
EQ
,
self
.
name
)
if
len
(
var_shape
)
>
1
:
validator
.
check
(
'var_shape[1:]'
,
var_shape
[
1
:],
'grad_shape[1:]'
,
grad_shape
[
1
:],
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"indices rank"
,
len
(
indices_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'grad_shape[0]'
,
grad_shape
[
0
],
'indices_shape[0]'
,
indices_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
return
var_shape
,
accum_shape
def
infer_dtype
(
self
,
var_type
,
accum_type
,
grad_type
,
indices_type
):
args
=
{
'var'
:
var_type
,
'accum'
:
accum_type
,
'grad'
:
grad_type
}
validator
.
check_tensor_type_same
(
args
,
[
mstype
.
float32
],
self
.
name
)
validator
.
check_tensor_type_same
({
'indices'
:
indices_type
},
[
mstype
.
int32
],
self
.
name
)
return
var_type
,
accum_type
class
ApplyProximalAdagrad
(
PrimitiveWithInfer
):
r
"""
Update relevant entries according to the proximal adagrad algorithm.
...
...
@@ -3664,7 +3746,8 @@ class ApplyProximalAdagrad(PrimitiveWithInfer):
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
'var'
,
'accum'
,
'lr'
,
'l1'
,
'l2'
,
'grad'
],
outputs
=
[
'output'
])
self
.
init_prim_io_names
(
inputs
=
[
'var'
,
'accum'
,
'lr'
,
'l1'
,
'l2'
,
'grad'
],
outputs
=
[
'var'
,
'accum'
])
self
.
use_locking
=
validator
.
check_value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
name
)
def
infer_shape
(
self
,
var_shape
,
accum_shape
,
lr_shape
,
l1_shape
,
l2_shape
,
grad_shape
):
...
...
@@ -4371,6 +4454,98 @@ class SparseApplyFtrl(PrimitiveWithInfer):
return
var_dtype
,
accum_dtype
,
linear_dtype
class
SparseApplyFtrlV2
(
PrimitiveWithInfer
):
"""
Update relevant entries according to the FTRL-proximal scheme.
Args:
lr (float): The learning rate value, must be positive.
l1 (float): l1 regularization strength, must be greater than or equal to zero.
l2 (float): l2 regularization strength, must be greater than or equal to zero.
l2_shrinkage (float): L2 shrinkage regularization.
lr_power (float): Learning rate power controls how the learning rate decreases during training,
must be less than or equal to zero. Use fixed learning rate if `lr_power` is zero.
use_locking (bool): If `True`, updating of the var and accum tensors will be protected. Default: False.
Inputs:
- **var** (Parameter) - The variable to be updated. The data type must be float32.
- **accum** (Parameter) - The accum to be updated, must be same type and shape as `var`.
- **linear** (Parameter) - The linear to be updated, must be same type and shape as `var`.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.
Outputs:
Tuple of 3 Tensor, the updated parameters.
- **var** (Tensor): Tensor, has the same shape and type as `var`.
- **accum** (Tensor): Tensor, has the same shape and type as `accum`.
- **linear** (Tensor): Tensor, has the same shape and type as `linear`.
Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> class SparseApplyFtrlV2Net(nn.Cell):
>>> def __init__(self):
>>> super(SparseApplyFtrlV2Net, self).__init__()
>>> self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2(lr=0.01, l1=0.0, l2=0.0,
l2_shrinkage=0.0, lr_power=-0.5)
>>> self.var = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="var")
>>> self.accum = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="accum")
>>> self.linear = Parameter(Tensor(np.random.rand(3, 3).astype(np.float32)), name="linear")
>>>
>>> def construct(self, grad, indices):
>>> out = self.sparse_apply_ftrl_v2(self.var, self.accum, self.linear, grad, indices)
>>> return out
>>>
>>> net = SparseApplyFtrlV2Net()
>>> grad = Tensor(np.random.rand(3, 3).astype(np.float32))
>>> indices = Tensor(np.ones([3]), mindspore.int32)
>>> output = net(grad, indices)
"""
__mindspore_signature__
=
(
(
'var'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'accum'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'linear'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'grad'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T1
)
)
@
prim_attr_register
def
__init__
(
self
,
lr
,
l1
,
l2
,
l2_shrinkage
,
lr_power
,
use_locking
=
False
):
validator
.
check_value_type
(
"lr"
,
lr
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"l1"
,
l1
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"l2"
,
l2
,
[
float
],
self
.
name
)
validator
.
check_value_type
(
"lr_power"
,
lr_power
,
[
float
],
self
.
name
)
self
.
lr
=
validator
.
check_number_range
(
"lr"
,
lr
,
0.0
,
float
(
"inf"
),
Rel
.
INC_NEITHER
,
self
.
name
)
self
.
l1
=
validator
.
check_number_range
(
"l1"
,
l1
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
self
.
name
)
self
.
l2
=
validator
.
check_number_range
(
"l2"
,
l2
,
0.0
,
float
(
"inf"
),
Rel
.
INC_LEFT
,
self
.
name
)
self
.
lr_power
=
validator
.
check_number
(
"lr_power"
,
lr_power
,
0
,
Rel
.
LE
,
self
.
name
)
self
.
l2_shrinkage
=
validator
.
check_value_type
(
"l2_shrinkage"
,
l2_shrinkage
,
[
float
],
self
.
name
)
self
.
use_locking
=
validator
.
check_value_type
(
"use_locking"
,
use_locking
,
[
bool
],
self
.
name
)
def
infer_shape
(
self
,
var_shape
,
accum_shape
,
linear_shape
,
grad_shape
,
indices_shape
):
validator
.
check
(
'var shape'
,
var_shape
,
'accum shape'
,
accum_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'var shape'
,
var_shape
,
'linear shape'
,
linear_shape
,
Rel
.
EQ
,
self
.
name
)
if
len
(
var_shape
)
>
1
:
validator
.
check
(
'var_shape[1:]'
,
var_shape
[
1
:],
'grad_shape[1:]'
,
grad_shape
[
1
:],
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"indices rank"
,
len
(
indices_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
'grad_shape[0]'
,
grad_shape
[
0
],
'indices_shape[0]'
,
indices_shape
[
0
],
Rel
.
EQ
,
self
.
name
)
return
var_shape
,
accum_shape
,
linear_shape
def
infer_dtype
(
self
,
var_dtype
,
accum_dtype
,
linear_dtype
,
grad_dtype
,
indices_dtype
):
args
=
{
"var_dtype"
:
var_dtype
,
"accum_dtype"
:
accum_dtype
,
"linear_dtype"
:
linear_dtype
,
"grad_dtype"
:
grad_dtype
}
validator
.
check_tensor_type_same
(
args
,
[
mstype
.
float32
],
self
.
name
)
validator
.
check_tensor_type_same
({
"indices_dtype"
:
indices_dtype
},
[
mstype
.
int32
],
self
.
name
)
return
var_dtype
,
accum_dtype
,
linear_dtype
class
ConfusionMulGrad
(
PrimitiveWithInfer
):
"""
`output0` is the result of which input0 dot multily input1.
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
b052ecf4
...
...
@@ -306,6 +306,19 @@ class SparseApplyFtrlNet(nn.Cell):
return
out
class
SparseApplyFtrlV2Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseApplyFtrlV2Net
,
self
).
__init__
()
self
.
sparse_apply_ftrl_v2
=
P
.
SparseApplyFtrlV2
(
lr
=
0.001
,
l1
=
0.0
,
l2
=
0.0
,
l2_shrinkage
=
0.0
,
lr_power
=-
0.5
)
self
.
var
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
3
,
3
).
astype
(
np
.
float32
)),
name
=
"var"
)
self
.
accum
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
3
,
3
).
astype
(
np
.
float32
)),
name
=
"accum"
)
self
.
linear
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
3
,
3
).
astype
(
np
.
float32
)),
name
=
"linear"
)
def
construct
(
self
,
grad
,
indices
):
out
=
self
.
sparse_apply_ftrl_v2
(
self
.
var
,
self
.
accum
,
self
.
linear
,
grad
,
indices
)
return
out
class
SparseApplyProximalAdagradNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseApplyProximalAdagradNet
,
self
).
__init__
()
...
...
@@ -467,6 +480,18 @@ class SparseApplyAdagradNet(nn.Cell):
return
out
class
SparseApplyAdagradV2Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SparseApplyAdagradV2Net
,
self
).
__init__
()
self
.
sparse_apply_adagrad_v2
=
P
.
SparseApplyAdagradV2
(
lr
=
0.01
,
epsilon
=
0.001
)
self
.
var
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
3
,
3
).
astype
(
np
.
float32
)),
name
=
"var"
)
self
.
accum
=
Parameter
(
Tensor
(
np
.
random
.
rand
(
3
,
3
).
astype
(
np
.
float32
)),
name
=
"accum"
)
def
construct
(
self
,
grad
,
indices
):
out
=
self
.
sparse_apply_adagrad_v2
(
self
.
var
,
self
.
accum
,
grad
,
indices
)
return
out
class
ApplyRMSNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
ApplyRMSNet
,
self
).
__init__
()
...
...
@@ -1376,10 +1401,18 @@ test_case_nn_ops = [
'desc_inputs'
:
[[
3
,
3
],
Tensor
(
np
.
ones
((
3
,),
np
.
int32
))],
'desc_bprop'
:
[[
3
,
3
],
[
3
,
3
]],
'skip'
:
[
'backward'
]}),
(
'SparseApplyAdagradV2'
,
{
'block'
:
SparseApplyAdagradV2Net
(),
'desc_inputs'
:
[[
3
,
3
],
Tensor
(
np
.
ones
((
3
,),
np
.
int32
))],
'skip'
:
[
'backward'
]}),
(
'SparseApplyFtrl'
,
{
'block'
:
SparseApplyFtrlNet
(),
'desc_inputs'
:
[[
3
,
3
],
Tensor
(
np
.
ones
((
3
,),
np
.
int32
))],
'skip'
:
[
'backward'
]}),
(
'SparseApplyFtrlV2'
,
{
'block'
:
SparseApplyFtrlV2Net
(),
'desc_inputs'
:
[[
3
,
3
],
Tensor
(
np
.
ones
((
3
,),
np
.
int32
))],
'skip'
:
[
'backward'
]}),
(
'ApplyProximalAdagrad'
,
{
'block'
:
ApplyProximalAdagradNet
(),
'desc_inputs'
:
[[
3
,
3
]],
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录