Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0f2fc082
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0f2fc082
编写于
5月 25, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 25, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1405 add lars parameter check
Merge pull request !1405 from gziyan/add_lars_paramter_check
上级
61daa654
e7560214
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
8 addition
and
1 deletion
+8
-1
mindspore/nn/optim/lars.py
mindspore/nn/optim/lars.py
+8
-1
未找到文件。
mindspore/nn/optim/lars.py
浏览文件 @
0f2fc082
...
...
@@ -21,6 +21,7 @@ from mindspore.common.parameter import Parameter
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
functional
as
F
from
mindspore._checkparam
import
Validator
as
validator
from
.optimizer
import
grad_scale
,
Optimizer
lars_opt
=
C
.
MultitypeFuncGraph
(
"lars_opt"
)
...
...
@@ -41,6 +42,11 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f
return
gradient
def
_check_param_value
(
optimizer
,
epsilon
,
hyperpara
,
use_clip
,
prim_name
):
validator
.
check_value_type
(
"optimizer"
,
optimizer
,
Optimizer
,
prim_name
)
validator
.
check_value_type
(
"epsilon"
,
epsilon
,
[
float
],
prim_name
)
validator
.
check_value_type
(
"hyperpara"
,
hyperpara
,
[
float
],
prim_name
)
validator
.
check_value_type
(
"use_clip"
,
use_clip
,
[
bool
],
prim_name
)
class
LARS
(
Optimizer
):
"""
...
...
@@ -79,9 +85,10 @@ class LARS(Optimizer):
def
__init__
(
self
,
optimizer
,
epsilon
=
1e-05
,
hyperpara
=
0.001
,
weight_decay
=
0.0
,
use_clip
=
False
,
decay_filter
=
lambda
x
:
'LayerNorm'
not
in
x
.
name
and
'bias'
not
in
x
.
name
,
lars_filter
=
lambda
x
:
'LayerNorm'
not
in
x
.
name
and
'bias'
not
in
x
.
name
,
loss_scale
=
1.0
):
super
(
LARS
,
self
).
__init__
(
0.0
,
[
Parameter
(
Tensor
(
0.0
),
name
=
"trivial"
)])
super
(
LARS
,
self
).
__init__
(
0.0
,
[
Parameter
(
Tensor
(
0.0
),
name
=
"trivial"
)]
,
weight_decay
,
loss_scale
)
if
optimizer
.
is_group
:
raise
RuntimeError
(
f
"The
{
self
.
cls_name
}
optimizer cannot support group setting."
)
_check_param_value
(
optimizer
,
epsilon
,
hyperpara
,
use_clip
,
self
.
cls_name
)
self
.
opt
=
optimizer
self
.
parameters
=
optimizer
.
parameters
self
.
learning_rate
=
optimizer
.
learning_rate
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录