Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
c4a52b83
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c4a52b83
编写于
6月 27, 2022
作者:
Z
zmxdream
提交者:
GitHub
6月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[GPUPS]fix merge_grad&push_sparse (#43840)
上级
40a77319
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
8 additions
and
26 deletions
+8
-26
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
+2
-15
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
+3
-8
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
+3
-3
未找到文件。
paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.cu
浏览文件 @
c4a52b83
...
...
@@ -112,20 +112,7 @@ __global__ void dy_mf_search_kernel(Table* table,
}
}
else
{
if
(
keys
[
i
]
!=
0
)
{
printf
(
"warning::pull miss key: %d"
,
keys
[
i
]);
}
FeatureValue
*
cur
=
(
FeatureValue
*
)(
vals
+
i
*
pull_feature_value_size
);
cur
->
delta_score
=
0
;
cur
->
show
=
0
;
cur
->
clk
=
0
;
cur
->
slot
=
-
1
;
cur
->
lr
=
0
;
cur
->
lr_g2sum
=
0
;
cur
->
mf_size
=
0
;
cur
->
mf_dim
=
8
;
cur
->
cpu_ptr
;
for
(
int
j
=
0
;
j
<
cur
->
mf_dim
+
1
;
j
++
)
{
cur
->
mf
[
j
]
=
0
;
printf
(
"warning::pull miss key: %llu"
,
keys
[
i
]);
}
}
}
...
...
@@ -163,7 +150,7 @@ __global__ void dy_mf_update_kernel(Table* table,
sgd
.
dy_mf_update_value
(
optimizer_config
,
(
it
.
getter
())
->
second
,
*
cur
);
}
else
{
if
(
keys
[
i
]
!=
0
)
{
printf
(
"warning::push miss key: %
d
"
,
keys
[
i
]);
printf
(
"warning::push miss key: %
llu
"
,
keys
[
i
]);
}
}
}
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
浏览文件 @
c4a52b83
...
...
@@ -1026,14 +1026,9 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
auto
d_shard_keys
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
KeyType
));
KeyType
*
d_shard_keys_ptr
=
reinterpret_cast
<
KeyType
*>
(
d_shard_keys
->
ptr
());
GradType
*
d_shard_grads_ptr
;
if
(
!
multi_mf_dim_
)
{
auto
d_shard_grads
=
memory
::
Alloc
(
place
,
len
*
sizeof
(
GradType
));
d_shard_grads_ptr
=
reinterpret_cast
<
GradType
*>
(
d_shard_grads
->
ptr
());
}
else
{
auto
d_shard_grads
=
memory
::
Alloc
(
place
,
len
*
grad_value_size
);
d_shard_grads_ptr
=
reinterpret_cast
<
GradType
*>
(
d_shard_grads
->
ptr
());
}
auto
d_shard_grads
=
memory
::
Alloc
(
place
,
len
*
grad_value_size
);
GradType
*
d_shard_grads_ptr
=
reinterpret_cast
<
GradType
*>
(
d_shard_grads
->
ptr
());
int
uniq_len
=
len
;
dynamic_merge_grad
(
dev_num
,
d_keys
,
d_grads
,
len
,
uniq_len
);
...
...
paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.cu
浏览文件 @
c4a52b83
...
...
@@ -153,7 +153,6 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
size_t
grad_value_size
,
DynamicGradMerger
&
merger_
)
{
const
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i
<
n
)
{
uint32_t
start
=
offset
[
i
];
uint32_t
num
=
fea_num
[
i
];
...
...
@@ -164,8 +163,9 @@ __global__ void merge_gradients_kernel(const uint32_t* offset,
merger_
.
update_one
(
out
,
in
);
for
(
int
j
=
1
;
j
<
num
;
++
j
)
{
ori_index
=
index
[
start
+
j
];
in
=
*
(
FeaturePushValue
*
)(
input
+
size_t
(
ori_index
)
*
grad_value_size
);
merger_
.
merge_one
(
out
,
in
);
FeaturePushValue
&
rhs
=
*
(
FeaturePushValue
*
)(
input
+
size_t
(
ori_index
)
*
grad_value_size
);
merger_
.
merge_one
(
out
,
rhs
);
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录