Oneflow-Inc / OneFlow-Benchmark
Commit d1275e63
Authored June 15, 2021 by ShawnXuan
rm usless comments
Parent: 3aec0ec1
Showing 1 changed file with 2 additions and 8 deletions (+2 -8)
ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py
...
...
@@ -155,19 +155,17 @@ def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size):
             dtype=flow.float,
             initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
             )
-    hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)#, no_duplicates_in_indices=True)
+    hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)
     lf_ids = lf_ids - hf_vocab_size_constant
     with flow.scope.placement('cpu', '0:0'):
         lf_embedding_table = flow.get_variable(
             name=f'lf_{name}',
             shape=(vocab_size - hf_vocab_size, embedding_size),
-            #shape=(vocab_size, embedding_size),
             dtype=flow.float,
             initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05),
             )
-        lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)#, no_duplicates_in_indices=True)
+        lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)
     unique_embedding = flow.reshape(flow.zeros_like(unique_ids, dtype=flow.float), (-1, 1)) * flow.constant(0.0, dtype=flow.float, shape=(1, embedding_size))
-    # unique_embedding = flow.constant(0.0, dtype=flow.float, shape=(b*s, embedding_size))
     unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=hf_embedding, indices=hf_indices)
     unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=lf_embedding, indices=lf_indices)
     unique_embedding = flow.gather(params=unique_embedding, indices=unique_ids_idx)
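Note on the hunk above: the gather, scatter, gather sequence can be hard to follow in graph form, so here is a minimal plain-Python sketch of what it appears to compute. It is not the OneFlow implementation. It assumes that `hf_` and `lf_` stand for high-frequency and low-frequency ids, that ids below the size of the high-frequency table go to that table and the rest are re-based by subtracting it, and that `unique_ids` and `unique_ids_idx` come from deduplicating the input ids. The part of `_hybrid_embedding` that computes them is outside this diff. The function name `hybrid_embedding_lookup` and the list-based tables are hypothetical.

```python
def hybrid_embedding_lookup(ids, hf_table, lf_table):
    """Look up one embedding row per id from two tables split by id range.

    Assumption: ids < len(hf_table) live in hf_table; all other ids live in
    lf_table at position id - len(hf_table).
    """
    hf_vocab_size = len(hf_table)
    embedding_size = len(hf_table[0])

    # Deduplicate ids and remember which unique slot each original id uses.
    unique_ids = sorted(set(ids))
    position = {uid: k for k, uid in enumerate(unique_ids)}
    unique_ids_idx = [position[i] for i in ids]

    # Slots (within unique_ids) that belong to each table.
    hf_indices = [k for k, uid in enumerate(unique_ids) if uid < hf_vocab_size]
    lf_indices = [k for k, uid in enumerate(unique_ids) if uid >= hf_vocab_size]

    # Gather from each table; low-frequency ids are re-based to start at 0.
    hf_embedding = [hf_table[unique_ids[k]] for k in hf_indices]
    lf_embedding = [lf_table[unique_ids[k] - hf_vocab_size] for k in lf_indices]

    # Zero buffer with one row per unique id, then scatter both results in.
    unique_embedding = [[0.0] * embedding_size for _ in unique_ids]
    for k, row in zip(hf_indices, hf_embedding):
        unique_embedding[k] = list(row)
    for k, row in zip(lf_indices, lf_embedding):
        unique_embedding[k] = list(row)

    # Expand back to one row per original (possibly repeated) id.
    return [unique_embedding[k] for k in unique_ids_idx]


hf_table = [[0.1, 0.1], [0.2, 0.2]]              # rows for ids 0..1
lf_table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]  # rows for ids 2..4
result = hybrid_embedding_lookup([4, 0, 4, 2], hf_table, lf_table)
```

In this sketch the repeated id 4 is fetched from the low-frequency table once and then fanned out by the final gather, which is presumably why the function deduplicates the ids before the lookups.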
...
...
@@ -309,8 +307,6 @@ def print_args(args):
     for arg in vars(args):
         print("{} = {}".format(arg, getattr(args, arg)))
     print("-".ljust(66, "-"))
-    #print("Time stamp: {}".format(
-    # str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S"))))
 
 
 def main():
...
...
@@ -320,8 +316,6 @@ def main():
     flow.config.enable_model_io_v2(True)
     flow.config.enable_debug_mode(True)
     flow.config.collective_boxing.nccl_enable_all_to_all(True)
-    #flow.config.enable_numa_aware_cuda_malloc_host(True)
-    #flow.config.collective_boxing.enable_fusion(False)
     check_point = flow.train.CheckPoint()
     check_point.init()
     for i in range(FLAGS.max_iter):
...
...