Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
845d630d
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
10 个月 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
845d630d
编写于
4月 18, 2022
作者:
J
Javier Rodriguez Zaurin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
#89 improved the solution for device consistency
上级
2339c3a0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
5 addition
and
7 deletion
+5
-7
pytorch_widedeep/models/wide_deep.py
pytorch_widedeep/models/wide_deep.py
+3
-6
pytorch_widedeep/training/trainer.py
pytorch_widedeep/training/trainer.py
+2
-1
未找到文件。
pytorch_widedeep/models/wide_deep.py
浏览文件 @
845d630d
...
...
@@ -136,7 +136,6 @@ class WideDeep(nn.Module):
enforce_positive_activation
:
str
=
"softplus"
,
pred_dim
:
int
=
1
,
with_fds
:
bool
=
False
,
device
:
Optional
[
str
]
=
None
,
**
fds_config
,
):
super
(
WideDeep
,
self
).
__init__
()
...
...
@@ -152,11 +151,9 @@ class WideDeep(nn.Module):
with_fds
,
)
self
.
wd_device
=
(
device
if
device
is
not
None
else
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
)
# this attribute will be eventually over-written by the Trainer's
# device. Acts here as a 'placeholder'.
self
.
wd_device
=
"cpu"
# required as attribute just in case we pass a deephead
self
.
pred_dim
=
pred_dim
...
...
pytorch_widedeep/training/trainer.py
浏览文件 @
845d630d
...
...
@@ -238,6 +238,7 @@ class Trainer:
self
.
lambda_sparse
=
kwargs
.
get
(
"lambda_sparse"
,
1e-3
)
self
.
reducing_matrix
=
create_explain_matrix
(
self
.
model
)
self
.
model
.
to
(
self
.
device
)
self
.
model
.
wd_device
=
self
.
device
self
.
objective
=
objective
self
.
method
=
_ObjectiveToMethod
.
get
(
objective
)
...
...
@@ -1474,7 +1475,7 @@ class Trainer:
if
sys
.
platform
==
"darwin"
and
sys
.
version_info
.
minor
>
7
else
os
.
cpu_count
()
)
default_device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
default_device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
device
=
kwargs
.
get
(
"device"
,
default_device
)
num_workers
=
kwargs
.
get
(
"num_workers"
,
default_num_workers
)
return
device
,
num_workers
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录