Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
66e9406d
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
66e9406d
编写于
7月 07, 2020
作者:
Y
yaoxuefeng
提交者:
GitHub
7月 07, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix AutoInt (#133)
* fix AutoInt * bug fix
上级
5f6ab95e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
9 addition
and
10 deletion
+9
-10
models/rank/AutoInt/config.yaml
models/rank/AutoInt/config.yaml
+2
-1
models/rank/AutoInt/model.py
models/rank/AutoInt/model.py
+7
-9
未找到文件。
models/rank/AutoInt/config.yaml
浏览文件 @
66e9406d
...
...
@@ -36,8 +36,9 @@ hyper_parameters:
class
:
SGD
learning_rate
:
0.0001
sparse_feature_number
:
1086460
sparse_feature_dim
:
9
sparse_feature_dim
:
9
6
num_field
:
39
d_model
:
96
d_key
:
16
d_value
:
16
n_head
:
6
...
...
models/rank/AutoInt/model.py
浏览文件 @
66e9406d
...
...
@@ -31,6 +31,7 @@ class Model(ModelBase):
"hyper_parameters.sparse_feature_dim"
,
None
)
self
.
num_field
=
envs
.
get_global_env
(
"hyper_parameters.num_field"
,
None
)
self
.
d_model
=
envs
.
get_global_env
(
"hyper_parameters.d_model"
,
None
)
self
.
d_key
=
envs
.
get_global_env
(
"hyper_parameters.d_key"
,
None
)
self
.
d_value
=
envs
.
get_global_env
(
"hyper_parameters.d_value"
,
None
)
self
.
n_head
=
envs
.
get_global_env
(
"hyper_parameters.n_head"
,
None
)
...
...
@@ -40,7 +41,7 @@ class Model(ModelBase):
"hyper_parameters.n_interacting_layers"
,
1
)
def
multi_head_attention
(
self
,
queries
,
keys
,
values
,
d_key
,
d_value
,
n_head
,
dropout_rate
):
d_model
,
n_head
,
dropout_rate
):
keys
=
queries
if
keys
is
None
else
keys
values
=
keys
if
values
is
None
else
values
if
not
(
len
(
queries
.
shape
)
==
len
(
keys
.
shape
)
==
len
(
values
.
shape
)
==
3
...
...
@@ -126,9 +127,8 @@ class Model(ModelBase):
q
,
k
,
v
=
__compute_qkv
(
queries
,
keys
,
values
,
n_head
,
d_key
,
d_value
)
q
,
k
,
v
=
__split_heads_qkv
(
q
,
k
,
v
,
n_head
,
d_key
,
d_value
)
d_model
=
d_key
*
n_head
ctx_multiheads
=
scaled_dot_product_attention
(
q
,
k
,
v
,
d_model
,
ctx_multiheads
=
scaled_dot_product_attention
(
q
,
k
,
v
,
self
.
d_model
,
dropout_rate
)
out
=
__combine_heads
(
ctx_multiheads
)
...
...
@@ -136,16 +136,14 @@ class Model(ModelBase):
return
out
def
interacting_layer
(
self
,
x
):
attention_out
=
self
.
multi_head_attention
(
x
,
None
,
None
,
self
.
d_key
,
self
.
d_value
,
self
.
n_head
,
self
.
dropout_rate
)
attention_out
=
self
.
multi_head_attention
(
x
,
None
,
None
,
self
.
d_key
,
self
.
d_value
,
self
.
d_model
,
self
.
n_head
,
self
.
dropout_rate
)
W_0_x
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
self
.
d_
key
*
self
.
n_head
,
size
=
self
.
d_
model
,
bias_attr
=
False
,
num_flatten_dims
=
2
)
res_out
=
fluid
.
layers
.
relu
(
attention_out
+
W_0_x
)
self
.
d_key
=
self
.
d_key
*
self
.
n_head
self
.
d_value
=
self
.
d_value
*
self
.
n_head
return
res_out
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录