Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
b1b7af40
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1b7af40
编写于
12月 19, 2017
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support multi node
上级
7be79231
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
17 addition
and
20 deletion
+17
-20
paddle/operators/detail/recv_impl.cc
paddle/operators/detail/recv_impl.cc
+9
-5
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+5
-12
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+2
-1
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+1
-2
未找到文件。
paddle/operators/detail/recv_impl.cc
浏览文件 @
b1b7af40
...
...
@@ -51,19 +51,23 @@ Status SendRecvServerImpl::GetVariable(ServerContext *context,
Status
SendRecvServerImpl
::
Wait
(
ServerContext
*
context
,
const
VoidMessage
*
in_var
,
VoidMessage
*
out_var
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
}
return
Status
::
OK
;
}
void
SendRecvServerImpl
::
Start
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
}
void
SendRecvServerImpl
::
Done
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
}
condition_
.
notify_all
();
}
...
...
paddle/operators/recv_op.cc
浏览文件 @
b1b7af40
...
...
@@ -14,7 +14,6 @@
#include <stdint.h>
#include <sys/stat.h>
#include <iostream>
#include <ostream>
#include <thread>
...
...
@@ -81,9 +80,9 @@ class RecvOp : public framework::OperatorBase {
auto
grad_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"GradList"
);
auto
trainer_count
=
Attr
<
int
>
(
"Trainers"
);
size_t
param_count
=
param_list
.
size
();
rpc_service_
->
Start
();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while
(
true
)
{
rpc_service_
->
Start
();
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
for
(
size_t
i
=
0
;
i
<
param_count
*
trainer_count
;
++
i
)
{
...
...
@@ -95,8 +94,8 @@ class RecvOp : public framework::OperatorBase {
if
(
it
!=
grad_list
.
end
())
{
param_var_name
=
param_list
[
it
-
grad_list
.
begin
()];
}
VLOG
(
10
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
auto
*
merged_grad
=
recv_scope
.
FindVar
(
grad_var_name
);
if
(
merged_grad
==
nullptr
)
{
// create output of merged var.
...
...
@@ -113,6 +112,7 @@ class RecvOp : public framework::OperatorBase {
// FIXME(typhoonzero): do not copy
framework
::
CopyFrom
(
v
.
second
,
dev_ctx
.
GetPlace
(),
dev_ctx
,
tensor
);
}
rpc_service_
->
Start
();
std
::
string
program_str
=
Attr
<
std
::
string
>
(
"OptimizeProgram"
);
framework
::
ProgramDesc
program_desc
;
...
...
@@ -127,14 +127,7 @@ class RecvOp : public framework::OperatorBase {
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
rpc_service_
->
Done
();
// for (size_t i = 0; i < param_count; ++i) {
// auto *out_var = recv_scope.FindVar(param_list[i]);
// detail::TensorWithName out;
// out.first = param_list[i];
// out.second = out_var->Get<framework::LoDTensor>();
// rpc_service_->Push(out);
// }
grads_counter_
.
clear
();
}
// while(true)
}
...
...
paddle/operators/send_op.cc
浏览文件 @
b1b7af40
...
...
@@ -52,7 +52,8 @@ class SendOp : public framework::OperatorBase {
LOG
(
ERROR
)
<<
"send variable error: "
<<
ins
[
i
];
}
}
client_map_
[
0
]
->
Wait
();
// TODO(typhoonzero): support async optimization
// TODO(typhoonzero): support async optimization
client_map_
[
epmap
[
0
]]
->
Wait
();
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
bool
ret
=
client_map_
[
epmap
[
i
]]
->
GetVariable
(
scope
,
ins
[
i
]);
if
(
!
ret
)
{
...
...
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
b1b7af40
...
...
@@ -149,9 +149,8 @@ class DistributeTranspiler:
epmap
=
[]
for
ep
,
v
in
self
.
param_grad_map
.
iteritems
():
send_op_ordered_inputs
.
extend
(
v
[
"grads"
])
for
i
in
v
:
for
i
in
v
[
"grads"
]
:
epmap
.
append
(
ep
)
send_op
=
program
.
global_block
().
append_op
(
type
=
"send"
,
inputs
=
{
"X"
:
send_op_ordered_inputs
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录