Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2891ce64
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2891ce64
编写于
7月 19, 2021
作者:
Y
yaoxuefeng
提交者:
GitHub
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix pull&push dim (#34217)
上级
2c945737
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
16 addition
and
3 deletion
+16
-3
paddle/fluid/framework/fleet/heter_ps/feature_value.h
paddle/fluid/framework/fleet/heter_ps/feature_value.h
+13
-0
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
+3
-3
未找到文件。
paddle/fluid/framework/fleet/heter_ps/feature_value.h
浏览文件 @
2891ce64
...
...
@@ -51,6 +51,19 @@ struct FeaturePushValue {
int
slot
;
float
lr_g
;
float
mf_g
[
MF_DIM
];
__device__
__forceinline__
FeaturePushValue
operator
+
(
const
FeaturePushValue
&
a
)
const
{
FeaturePushValue
out
;
out
.
slot
=
a
.
slot
;
out
.
show
=
a
.
show
+
show
;
out
.
clk
=
a
.
clk
+
clk
;
out
.
lr_g
=
a
.
lr_g
+
lr_g
;
for
(
int
i
=
0
;
i
<
MF_DIM
;
++
i
)
{
out
.
mf_g
[
i
]
=
a
.
mf_g
[
i
]
+
mf_g
[
i
];
}
return
out
;
}
};
}
// end namespace framework
...
...
paddle/fluid/framework/fleet/ps_gpu_wrapper.cu
浏览文件 @
2891ce64
...
...
@@ -50,11 +50,11 @@ __global__ void PullCopy(float** dest, const FeatureValue* src,
*
(
dest
[
x
]
+
y
*
hidden
+
2
)
=
(
src
+
i
)
->
lr
;
}
if
((
src
+
i
)
->
mf_size
==
0
||
*
(
keys
[
x
]
+
y
)
==
0
)
{
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
hidden
-
3
;
j
++
)
{
*
(
dest
[
x
]
+
y
*
hidden
+
3
+
j
)
=
0
;
}
}
else
{
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
hidden
-
3
;
j
++
)
{
*
(
dest
[
x
]
+
y
*
hidden
+
3
+
j
)
=
(
src
+
i
)
->
mf
[
1
+
j
];
}
}
...
...
@@ -99,7 +99,7 @@ __global__ void PushCopy(FeaturePushValue* dest, float** src, int64_t* len,
(
dest
+
i
)
->
show
=
*
(
src
[
x
]
+
y
*
hidden
);
(
dest
+
i
)
->
clk
=
*
(
src
[
x
]
+
y
*
hidden
+
1
);
(
dest
+
i
)
->
lr_g
=
*
(
src
[
x
]
+
y
*
hidden
+
2
)
*
-
1.
*
bs
;
for
(
int
j
=
0
;
j
<
8
;
j
++
)
{
for
(
int
j
=
0
;
j
<
hidden
-
3
;
j
++
)
{
(
dest
+
i
)
->
mf_g
[
j
]
=
*
(
src
[
x
]
+
y
*
hidden
+
3
+
j
)
*
-
1.
*
bs
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录