Greenplum / Pytorch Widedeep
Commit ce0c3e3d
Authored Oct 18, 2019 by jrzaurin

adapted the code to include the possibility of using different learning rates for different groups

Parent: 195bc274
Showing 1 changed file with 21 additions and 9 deletions (+21 −9)

pytorch_widedeep/optimizers.py
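The commit message describes the feature: each child module of the wide-deep model can already get its own optimizer, and after this change it can also get its own parameter groups (e.g. a different learning rate per group). A minimal, self-contained sketch of the dispatch logic that `MultipleOptimizers.apply` implements after this commit — the names, mock factories, and the added `None` guard here are illustrative, not the library's actual API:

```python
import warnings

def build_optimizers(children, optimizers, param_group=None):
    """Instantiate one optimizer per named child module.

    `optimizers` maps child names to optimizer factories; `param_group`
    optionally maps a child name to the parameter groups that factory
    should receive (mirroring MultipleOptimizers.apply in this commit,
    plus a `None` guard so the sketch runs when no groups are given).
    """
    built = {}
    for name, child in children.items():
        if name in optimizers and param_group is not None and name in param_group:
            built[name] = optimizers[name](child, param_group[name])
        elif name in optimizers:
            built[name] = optimizers[name](child)
        else:
            warnings.warn("No optimizer found for {}. Adam optimizer with "
                          "default settings will be used".format(name))
    return built

# e.g. a custom learning rate for the "wide" part, defaults for "deepdense"
opts = build_optimizers(
    children={"wide": "wide-module", "deepdense": "deep-module"},
    optimizers={"wide": lambda m, g=None: ("SGD", g),
                "deepdense": lambda m, g=None: ("Adam", g)},
    param_group={"wide": [{"params": "wide-params", "lr": 0.1}]},
)
```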
@@ -17,15 +17,20 @@ class MultipleOptimizers(object):
         else:
             instantiated_optimizers[model_name] = optimizer
         self._optimizers = instantiated_optimizers
 
-    def apply(self, model: TorchModel):
+    def apply(self, model: TorchModel, param_group=None):
         children = list(model.children())
         children_names = [child.__class__.__name__.lower() for child in children]
         if not all([cn in children_names for cn in self._optimizers.keys()]):
             raise ValueError('Model name has to be one of: {}'.format(children_names))
         for child, name in zip(children, children_names):
-            try:
+            if name in self._optimizers and name in param_group:
+                self._optimizers[name] = self._optimizers[name](child, param_group[name])
+            elif name in self._optimizers:
                 self._optimizers[name] = self._optimizers[name](child)
-            except:
+            else:
                 warnings.warn("No optimizer found for {}. Adam optimizer with default "
                     "settings will be used".format(name))
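One caveat in the hunk above: `param_group` defaults to `None`, but the new `if` evaluates `name in param_group` unguarded, so calling `apply(model)` with no groups raises `TypeError` instead of falling through to the `elif`. A tiny demonstration — the guarded variant at the end is a hypothetical fix, not part of this commit:

```python
param_group = None          # the default value of apply()'s new argument
name, optimizers = "wide", {"wide": object()}

# the committed condition: `name in param_group` fails on None
try:
    _ = name in optimizers and name in param_group
    unguarded_ok = True
except TypeError:
    unguarded_ok = False    # "argument of type 'NoneType' is not iterable"

# a guarded version short-circuits before the membership test
guarded = name in optimizers and param_group is not None and name in param_group
```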
@@ -51,8 +56,10 @@ class Adam:
         self.weight_decay = weight_decay
         self.amsgrad = amsgrad
 
-    def __call__(self, submodel: TorchModel) -> Optimizer:
-        self.opt = torch.optim.Adam(submodel.parameters(), lr=self.lr, betas=self.betas, eps=self.eps,
+    def __call__(self, submodel: TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
+        self.opt = torch.optim.Adam(params, lr=self.lr, betas=self.betas, eps=self.eps,
             weight_decay=self.weight_decay, amsgrad=self.amsgrad)
         return self.opt
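Because the new `__call__` forwards `param_group` verbatim as the first argument to `torch.optim.Adam`, the caller can pass anything the PyTorch optimizer constructor accepts there: an iterable of tensors, or a list of per-group dicts whose options override the constructor-wide defaults for that group. A sketch of the dict shape — the placeholder strings stand in for real parameter tensors:

```python
# Per-group options override the optimizer-wide defaults (here, lr);
# groups without an "lr" key fall back to the lr passed to the constructor.
param_group = [
    {"params": "<embedding parameters>", "lr": 1e-2},  # custom learning rate
    {"params": "<dense-layer parameters>"},            # uses the default lr
]
```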
@@ -66,7 +73,9 @@ class RAdam:
         self.eps = eps
         self.weight_decay = weight_decay
 
-    def __call__(self, submodel: TorchModel) -> Optimizer:
+    def __call__(self, submodel: TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
         self.opt = orgRAdam(submodel.parameters(), lr=self.lr, betas=self.betas, eps=self.eps,
             weight_decay=self.weight_decay)
         return self.opt
@@ -82,7 +91,9 @@ class SGD:
         self.weight_decay = weight_decay
         self.nesterov = nesterov
 
-    def __call__(self, submodel: TorchModel) -> Optimizer:
+    def __call__(self, submodel: TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
         self.opt = torch.optim.SGD(submodel.parameters(), lr=self.lr, momentum=self.momentum,
             dampening=self.dampening, weight_decay=self.weight_decay, nesterov=self.nesterov)
         return self.opt
@@ -99,8 +110,9 @@ class RMSprop:
         self.momentum = momentum
         self.centered = centered
 
-    def __call__(self, submodel: TorchModel) -> Optimizer:
+    def __call__(self, submodel: TorchModel, param_group=None) -> Optimizer:
+        if param_group is not None: params = param_group
+        else: params = submodel.parameters()
         self.opt = torch.optim.RMSprop(submodel.parameters(), lr=self.lr, alpha=self.alpha,
             eps=self.eps, weight_decay=self.weight_decay, momentum=self.momentum,
             centered=self.centered)
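Note that only the Adam hunk actually forwards the selected parameters: in the RAdam, SGD, and RMSprop hunks, `params` is computed but `submodel.parameters()` is still what gets passed to the underlying optimizer, so `param_group` is silently ignored for those three classes (the fix, presumably a follow-up commit, is to pass `params` as Adam does). The selection each `__call__` now performs reduces to:

```python
def choose_params(param_group, default_params):
    # mirrors the two-line if/else added to every __call__ in this commit
    return param_group if param_group is not None else default_params

chosen = choose_params([{"params": "wide", "lr": 0.1}], "submodel.parameters()")
fallback = choose_params(None, "submodel.parameters()")
```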