PaddlePaddle / PaddleDetection, commit 82cff5ec (unverified)

Authored on Apr 14, 2019 by 乔龙飞 Qiao Longfei; committed via GitHub on Apr 14, 2019.

Merge pull request #16762 from jacquesqiao/add-async_sparse_param_update_recorder

Add async sparse param update recorder

Parents: 4267a81a, 1526a3e4
Showing 25 changed files with 570 additions and 85 deletions (+570 -85)
paddle/fluid/framework/details/async_ssa_graph_executor.cc (+7 -2)
paddle/fluid/operators/distributed/CMakeLists.txt (+4 -1)
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc (+27 -0)
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h (+183 -0)
paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc (+99 -0)
paddle/fluid/operators/distributed/brpc/brpc_client.cc (+1 -0)
paddle/fluid/operators/distributed/brpc/brpc_client.h (+8 -7)
paddle/fluid/operators/distributed/communicator.cc (+26 -9)
paddle/fluid/operators/distributed/communicator.h (+1 -1)
paddle/fluid/operators/distributed/grpc/grpc_client.cc (+31 -25)
paddle/fluid/operators/distributed/grpc/grpc_client.h (+5 -1)
paddle/fluid/operators/distributed/grpc/grpc_server.cc (+7 -2)
paddle/fluid/operators/distributed/parameter_recv.cc (+51 -17)
paddle/fluid/operators/distributed/parameter_send.cc (+1 -1)
paddle/fluid/operators/distributed/request_handler.h (+7 -0)
paddle/fluid/operators/distributed/request_handler_impl.cc (+46 -3)
paddle/fluid/operators/distributed/rpc_client.h (+3 -0)
paddle/fluid/operators/distributed/rpc_common.h (+5 -2)
paddle/fluid/operators/distributed_ops/CMakeLists.txt (+2 -2)
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc (+29 -4)
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h (+3 -0)
paddle/fluid/operators/distributed_ops/recv_op.cc (+4 -3)
paddle/fluid/operators/distributed_ops/send_op.cc (+3 -3)
python/paddle/fluid/__init__.py (+1 -0)
python/paddle/fluid/transpiler/distribute_transpiler.py (+16 -2)
paddle/fluid/framework/details/async_ssa_graph_executor.cc (view file @ 82cff5ec)

@@ -64,9 +64,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
           node->Op()->GetNullableAttr("epmap"));
       auto height_section = boost::get<std::vector<int64_t>>(
           node->Op()->GetNullableAttr("sections"));
+      auto trainer_id =
+          boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
       send_varname_to_ctx[send_var_name] =
           operators::distributed::RpcContext(send_var_name, send_varnames,
-                                             epmap, height_section);
+                                             epmap, height_section,
+                                             trainer_id);
       VLOG(3) << "find and init an send op: "
               << send_varname_to_ctx[send_var_name];
     } else if (node->Name() == "recv") {

@@ -75,9 +78,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
           node->Op()->GetNullableAttr("recv_varnames"));
       auto epmap = boost::get<std::vector<std::string>>(
           node->Op()->GetNullableAttr("epmap"));
+      auto trainer_id =
+          boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
       recv_varname_to_ctx[recv_var_name] =
           operators::distributed::RpcContext(recv_var_name, recv_varnames,
-                                             epmap, {});
+                                             epmap, {}, trainer_id);
       nodes_to_delete.push_back(node);
       VLOG(3) << "find and remove an recv op: "
               << recv_varname_to_ctx[recv_var_name];
paddle/fluid/operators/distributed/CMakeLists.txt (view file @ 82cff5ec)

@@ -9,6 +9,9 @@ else()
 endif()

 configure_file(send_recv.proto.in ${CMAKE_CURRENT_SOURCE_DIR}/send_recv.proto @ONLY)

+cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool)
+cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
+
 # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
 if(WITH_GRPC)

@@ -20,7 +23,7 @@ if(WITH_GRPC)
       collective_client.cc collective_server.cc
       ${GRPC_SRCS}
       PROTO send_recv.proto
-      DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS})
+      DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder)
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc (new file, mode 100644, view file @ 82cff5ec)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"

namespace paddle {
namespace operators {
namespace distributed {

std::once_flag AsyncSparseParamUpdateRecorder::init_flag_;
std::unique_ptr<AsyncSparseParamUpdateRecorder>
    AsyncSparseParamUpdateRecorder::recorder_(nullptr);

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h (new file, mode 100644, view file @ 82cff5ec)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <future>  // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ThreadPool.h>

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace distributed {

class ConcurrentSet {
 public:
  ConcurrentSet() : pool_(new ::ThreadPool(1)) {}
  ~ConcurrentSet() {}

  std::future<void> Update(const std::vector<int64_t>& rows) {
    auto task = [this, rows] {
      if (VLOG_IS_ON(3)) {
        std::ostringstream sstream;
        sstream << "[";
        for (auto& id : rows) {
          sstream << id << ", ";
        }
        sstream << "]";
        VLOG(3) << "update ids -> " << sstream.str();
      }
      for (auto row : rows) {
        set_.insert(row);
      }
    };
    return pool_->enqueue(std::move(task));
  }

  std::future<void> GetAndClear(std::vector<int64_t>* result) {
    auto task = [this, &result] {
      result->clear();
      for (auto& id : set_) {
        result->push_back(id);
      }
      if (VLOG_IS_ON(3)) {
        std::ostringstream sstream;
        sstream << "[";
        for (auto& id : *result) {
          sstream << id << ", ";
        }
        sstream << "]";
        VLOG(3) << "result ids size: " << result->size() << " "
                << sstream.str();
      }
      set_.clear();
    };
    return pool_->enqueue(std::move(task));
  }

 private:
  std::unordered_set<int64_t> set_;
  std::unique_ptr<::ThreadPool> pool_{nullptr};
};

class AsyncSparseParamUpdateRecorder {
  using TrainerToRows = std::vector<std::unique_ptr<ConcurrentSet>>;

 public:
  AsyncSparseParamUpdateRecorder(
      int trainer_num,
      const std::unordered_map<std::string, std::string>& grad_to_param)
      : trainer_num_(trainer_num), grad_to_param_(grad_to_param) {
    if (VLOG_IS_ON(3)) {
      std::ostringstream sstream;
      sstream << "[";
      for (auto& item : grad_to_param) {
        sstream << item.first << ":" << item.second << ", ";
      }
      sstream << "]";
      VLOG(3) << "trainer_num: " << trainer_num
              << " grad_to_param_: " << sstream.str();
    }
    for (auto& iter : grad_to_param) {
      param_to_grad_[iter.second] = iter.first;
      auto& param_name = iter.second;
      param_to_updated_rows_[param_name] = TrainerToRows();
      auto& trainer_to_rows = param_to_updated_rows_[param_name];
      for (auto i = 0; i < trainer_num; ++i) {
        trainer_to_rows.emplace_back(new ConcurrentSet());
      }
    }
  }

  ~AsyncSparseParamUpdateRecorder() = default;

  void Update(const std::string& grad_name,
              const std::vector<int64_t>& update_rows) {
    VLOG(3) << "update grad: " << grad_name
            << " row size: " << update_rows.size();
    auto& param_name = grad_to_param_.at(grad_name);
    auto& trainer_to_rows = param_to_updated_rows_.at(param_name);

    std::vector<std::future<void>> fs;
    for (auto& set : trainer_to_rows) {
      fs.push_back(set->Update(update_rows));
    }
    for (auto& f : fs) {
      f.wait();
    }
  }

  void GetAndClear(const std::string& param_name, int trainer_id,
                   std::vector<int64_t>* result) {
    VLOG(3) << "GetAndClear param: " << param_name
            << " for trainer: " << trainer_id;
    PADDLE_ENFORCE_LT(trainer_id, trainer_num_);
    param_to_updated_rows_.at(param_name)[trainer_id]
        ->GetAndClear(result)
        .wait();
  }

  bool HasParam(const std::string& param_name) {
    return param_to_grad_.find(param_name) != param_to_grad_.end();
  }

  bool HasGrad(const std::string& grad_name) {
    return grad_to_param_.find(grad_name) != grad_to_param_.end();
  }

 private:
  const int trainer_num_;
  std::unordered_map<std::string, std::string> grad_to_param_;
  std::unordered_map<std::string, std::string> param_to_grad_;
  std::unordered_map<std::string, TrainerToRows> param_to_updated_rows_;

  // init recorder
 public:
  static void Init(
      int trainer_num,
      const std::unordered_map<std::string, std::string>& grad_to_param) {
    InitImpl(trainer_num, grad_to_param);
  }

  static AsyncSparseParamUpdateRecorder* GetInstance() {
    return recorder_.get();
  }

 private:
  // Init is called by GetInstance.
  static void InitImpl(
      int trainer_num,
      const std::unordered_map<std::string, std::string>& grad_to_param) {
    if (recorder_ == nullptr) {
      recorder_.reset(
          new AsyncSparseParamUpdateRecorder(trainer_num, grad_to_param));
    }
  }

  static std::once_flag init_flag_;
  static std::unique_ptr<AsyncSparseParamUpdateRecorder> recorder_;
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
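Read as a whole, the recorder keeps one ConcurrentSet of touched row ids per (parameter, trainer) pair: Update fans a gradient's rows out to every trainer's set, and GetAndClear drains one trainer's set when that trainer fetches the parameter. A minimal usage sketch of that lifecycle on the parameter server follows; the parameter/gradient names and row ids are illustrative, not taken from the diff.

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"

using paddle::operators::distributed::AsyncSparseParamUpdateRecorder;

void RecorderLifecycleSketch() {
  // One-time setup when the server enters async mode: two trainers, and the
  // sparse gradient "emb@GRAD" belongs to parameter "emb".
  AsyncSparseParamUpdateRecorder::Init(/*trainer_num=*/2, {{"emb@GRAD", "emb"}});

  // On every incoming sparse gradient, record the touched rows for all trainers.
  std::vector<int64_t> touched_rows = {3, 7, 42};
  AsyncSparseParamUpdateRecorder::GetInstance()->Update("emb@GRAD", touched_rows);

  // When trainer 0 later fetches "emb", drain the rows recorded for it; only
  // these rows need to be sent back as a SelectedRows slice.
  std::vector<int64_t> rows_for_trainer0;
  if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam("emb")) {
    AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
        "emb", /*trainer_id=*/0, &rows_for_trainer0);
  }
}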
paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc (new file, mode 100644, view file @ 82cff5ec)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"

#include <algorithm>

#include "gtest/gtest.h"

namespace paddle {
namespace operators {
namespace distributed {

TEST(ConcurrentSet, All) {
  ConcurrentSet concurrent_set;
  std::vector<int64_t> in1 = {1, 2, 3, 4};
  std::vector<int64_t> in2 = {2, 3, 5, 6};

  std::vector<std::future<void>> futures;
  futures.push_back(concurrent_set.Update(in1));
  futures.push_back(concurrent_set.Update(in2));

  for (auto& f : futures) {
    f.wait();
  }

  std::unordered_set<int64_t> in;
  std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
  std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));

  std::vector<int64_t> ret;
  concurrent_set.GetAndClear(&ret).wait();

  std::unordered_set<int64_t> out;
  std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));

  EXPECT_EQ(in, out);

  concurrent_set.GetAndClear(&ret).wait();
  EXPECT_EQ(ret.size(), 0);
}

TEST(AsyncSparseParamUpdateRecorder, All) {
  std::unordered_map<std::string, std::string> grad_to_param;
  grad_to_param["grad1"] = "param1";
  grad_to_param["grad2"] = "param2";

  int trainer_num = 10;

  AsyncSparseParamUpdateRecorder recorder(trainer_num, grad_to_param);
  std::vector<int64_t> in1 = {1, 2, 3, 4};
  std::vector<int64_t> in2 = {2, 3, 5, 6};

  std::unordered_set<int64_t> in;
  std::copy(in1.begin(), in1.end(), std::inserter(in, in.begin()));
  std::copy(in2.begin(), in2.end(), std::inserter(in, in.begin()));

  recorder.Update("grad1", in1);
  recorder.Update("grad1", in2);

  EXPECT_TRUE(recorder.HasParam("param1"));
  EXPECT_TRUE(recorder.HasParam("param2"));
  EXPECT_FALSE(recorder.HasParam("param3"));

  EXPECT_TRUE(recorder.HasGrad("grad1"));
  EXPECT_TRUE(recorder.HasGrad("grad2"));
  EXPECT_FALSE(recorder.HasGrad("grad3"));

  std::vector<int64_t> ret;
  EXPECT_ANY_THROW(recorder.GetAndClear("param1", trainer_num, &ret));

  for (int i = 0; i < trainer_num; ++i) {
    std::vector<int64_t> ret;
    std::unordered_set<int64_t> out;

    recorder.GetAndClear("param1", i, &ret);
    std::copy(ret.begin(), ret.end(), std::inserter(out, out.begin()));

    EXPECT_EQ(in, out);

    recorder.GetAndClear("param1", i, &ret);
    EXPECT_EQ(ret.size(), 0);
  }
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/distributed/brpc/brpc_client.cc (view file @ 82cff5ec)

@@ -234,6 +234,7 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& out_var_name,
+                                     const std::string& table_name,
                                      int64_t time_out) {
   return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
                       time_out);
paddle/fluid/operators/distributed/brpc/brpc_client.h (view file @ 82cff5ec)

@@ -21,8 +21,10 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <memory>
 #include <mutex>  // NOLINT
 #include <string>
 #include <unordered_map>
+#include <vector>

 #include "brpc/channel.h"

@@ -66,6 +68,7 @@ class BRPCClient : public RPCClient {
                            const framework::Scope& scope,
                            const std::string& var_name,
                            const std::string& out_var_name,
+                           const std::string& table_name = "",
                            int64_t time_out = FLAGS_rpc_deadline) override;

   VarHandlePtr AsyncGetMonomerBarrier(

@@ -107,13 +110,11 @@ class BRPCClient : public RPCClient {
   void SendComplete() override;

 private:
-  VarHandlePtr _AsyncGetVar(const std::string& ep,
-                            const platform::DeviceContext& ctx,
-                            const framework::Scope& scope,
-                            const std::string& var_name,
-                            const std::string& out_var_name,
-                            const std::string& method_name,
-                            int64_t time_out = FLAGS_rpc_deadline);
+  VarHandlePtr _AsyncGetVar(
+      const std::string& ep, const platform::DeviceContext& ctx,
+      const framework::Scope& scope, const std::string& var_name,
+      const std::string& out_var_name, const std::string& method_name,
+      const std::string& table_name, int64_t time_out = FLAGS_rpc_deadline);

   void Proceed();
   ChannelQueuePtr GetChannel(const std::string& ep);
paddle/fluid/operators/distributed/communicator.cc (view file @ 82cff5ec)

@@ -32,6 +32,9 @@ DEFINE_int32(communicator_send_queue_size, 20,
 DEFINE_int32(communicator_max_send_grad_num_before_recv, 20,
              "max grad num to send before recv parameters");
 DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
+DEFINE_int32(communicator_send_wait_times, 5,
+             "times that send thread will wait if merge num does not reach "
+             "max_merge_var_num");
 DEFINE_int32(communicator_max_merge_var_num, 20, "max var num to merge and send");
 DEFINE_bool(communicator_fake_rpc, false,

@@ -65,6 +68,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
           << FLAGS_communicator_max_send_grad_num_before_recv;
   VLOG(0) << "communicator_thread_pool_size: "
           << FLAGS_communicator_thread_pool_size;
+  VLOG(0) << "communicator_send_wait_times: "
+          << FLAGS_communicator_send_wait_times;
   VLOG(0) << "communicator_max_merge_var_num: "
           << FLAGS_communicator_max_merge_var_num;
   VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;

@@ -101,20 +106,32 @@ void Communicator::SendThread() {
         VLOG(3) << var_name << " merge and send";
         std::vector<std::shared_ptr<Variable>> vars;
         size_t merged_var_num = 0;
-        while (var_queue->Size() > 0 &&
-               merged_var_num < FLAGS_communicator_max_merge_var_num) {
-          vars.push_back(var_queue->Pop());
-          // only count the send number of the first var
-          if (var_name == send_varname_to_queue_.begin()->first) {
-            grad_num_.fetch_add(1, std::memory_order_relaxed);
-          }
-          merged_var_num++;
-        }
+        size_t wait_times = 0;
+        while (merged_var_num < FLAGS_communicator_max_merge_var_num) {
+          if (var_queue->Size() == 0) {
+            VLOG(3) << "wait_times -> " << wait_times;
+            if (wait_times >= FLAGS_communicator_send_wait_times) {
+              break;
+            }
+            std::this_thread::sleep_for(std::chrono::milliseconds(10));
+            wait_times++;
+            continue;
+          } else {
+            wait_times = 0;
+
+            vars.push_back(var_queue->Pop());
+            // only count the send number of the first var
+            if (var_name == send_varname_to_queue_.begin()->first) {
+              grad_num_.fetch_add(1, std::memory_order_relaxed);
+            }
+            merged_var_num++;
+          }
+        }
         auto before_merge = GetCurrentUS();
         MergeVars(var_name, vars, send_scope_.get());
         auto after_merge = GetCurrentUS();
-        VLOG(3) << "merge " << var_name << " use time "
-                << after_merge - before_merge;
+        VLOG(3) << "merge " << merged_var_num << " " << var_name
+                << " use time " << after_merge - before_merge;
         auto send_functor = distributed::ParameterSend<float>();
         auto &ctx = send_varname_to_ctx_.at(var_name);
         if (!FLAGS_communicator_fake_rpc) {
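The reworked loop above no longer stops merging the moment the queue is empty: it backs off for up to FLAGS_communicator_send_wait_times short sleeps, and only then sends whatever it has merged. A small standalone sketch of that bounded-wait pattern follows, in plain C++ with a toy deque standing in for the communicator's variable queue; the constants mirror the two flags and are illustrative values.

#include <chrono>
#include <cstddef>
#include <deque>
#include <iostream>
#include <thread>

int main() {
  std::deque<int> queue = {1, 2};   // toy stand-in for the per-variable queue
  const size_t max_merge_num = 4;   // like FLAGS_communicator_max_merge_var_num
  const size_t max_wait_times = 5;  // like FLAGS_communicator_send_wait_times

  size_t merged = 0;
  size_t wait_times = 0;
  while (merged < max_merge_num) {
    if (queue.empty()) {
      // Queue is empty: wait a little, but give up after max_wait_times polls.
      if (wait_times >= max_wait_times) break;
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
      ++wait_times;
      continue;
    }
    wait_times = 0;  // reset the back-off whenever progress is made
    queue.pop_front();
    ++merged;
  }
  std::cout << "merged " << merged << " vars before sending\n";
  return 0;
}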
paddle/fluid/operators/distributed/communicator.h (view file @ 82cff5ec)

@@ -109,7 +109,7 @@ inline void MergeVars(const std::string& var_name,
   auto* out_var = scope->Var(var_name);
   if (var0->IsType<framework::LoDTensor>()) {
     auto dims = var0->Get<framework::LoDTensor>().dims();
-    VLOG(3) << "merge " << var_name << " LoDTensor " << dims;
+    VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims;

     // init output tensor
     auto* out_t = out_var->GetMutable<framework::LoDTensor>();
paddle/fluid/operators/distributed/grpc/grpc_client.cc (view file @ 82cff5ec)

@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
                                      const framework::Scope& scope,
                                      const std::string& var_name,
                                      const std::string& out_varname,
+                                     const std::string& table_name,
                                      int64_t time_out) {
   return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
-                      "/sendrecv.SendRecvService/GetVariable", time_out);
+                      "/sendrecv.SendRecvService/GetVariable", table_name,
+                      time_out);
 }

 VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(

@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
   return _AsyncGetVar(
       ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
-      "/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
+      "/sendrecv.SendRecvService/GetVariableNoBarrier", "", time_out);
 }

 VarHandlePtr GRPCClient::AsyncGetMonomerVariable(

@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
     const framework::Scope& scope, const std::string& var_name,
     int64_t time_out) {
   return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
-                      "/sendrecv.SendRecvService/GetMonomerVariable", time_out);
+                      "/sendrecv.SendRecvService/GetMonomerVariable", "",
+                      time_out);
 }

 VarHandlePtr GRPCClient::_AsyncGetVar(
     const std::string& ep, const platform::DeviceContext& ctx,
     const framework::Scope& scope, const std::string& method,
     const std::string& var_name, const std::string& out_varname,
-    const std::string& rpc_path, int64_t time_out) {
+    const std::string& rpc_path, const std::string& table_name,
+    int64_t time_out) {
   const platform::DeviceContext* p_ctx = &ctx;
   const std::string ep_val = ep;
   const std::string var_name_val = var_name;
   const std::string out_varname_val = out_varname;
+  const std::string table_name_val = table_name;
   const framework::Scope* p_scope = &scope;
   const auto ch = GetChannel(ep_val);
   GetProcessor* s = new GetProcessor(ch);

@@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
   VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

-  framework::AsyncIO([var_name_val, out_varname_val, s, method, p_ctx, h,
-                      rpc_path, this] {
+  framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
+                      p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
     req.set_out_varname(out_varname_val);
     req.set_trainer_id(trainer_id_);
+    req.set_table_name(table_name_val);
     ::grpc::ByteBuffer buf;
     RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);

     VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

     // stub context
     s->response_call_back_ = ProcGetResponse;

     platform::RecordRPCEvent record_event(method);

     auto call =
         s->stub_g_.PrepareUnaryCall(s->context_.get(), rpc_path, buf, &cq_);
     call->StartCall();
     call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));

     if (UNLIKELY(platform::IsProfileEnabled())) {
       h->Wait();
     }
   });

   req_count_++;
paddle/fluid/operators/distributed/grpc/grpc_client.h (view file @ 82cff5ec)

@@ -23,9 +23,11 @@ limitations under the License. */
 #include <functional>
 #include <iostream>
 #include <map>
+#include <memory>
 #include <mutex>  // NOLINT
 #include <string>
 #include <thread>  // NOLINT
+#include <unordered_map>
 #include <vector>

 #include "grpc++/channel.h"

@@ -187,6 +189,7 @@ class GRPCClient : public RPCClient {
                            const framework::Scope& scope,
                            const std::string& var_name,
                            const std::string& out_varname,
+                           const std::string& table_name = "",
                            int64_t time_out = FLAGS_rpc_deadline) override;

   VarHandlePtr AsyncGetVarNoBarrier(

@@ -239,7 +242,8 @@ class GRPCClient : public RPCClient {
       const std::string& ep, const platform::DeviceContext& ctx,
       const framework::Scope& scope, const std::string& method,
       const std::string& var_name, const std::string& out_varname,
-      const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);
+      const std::string& rpc_path, const std::string& table_name = "",
+      int64_t time_out = FLAGS_rpc_deadline);

  private:
   grpc::CompletionQueue cq_;
paddle/fluid/operators/distributed/grpc/grpc_server.cc (view file @ 82cff5ec)

@@ -137,6 +137,7 @@ class RequestGet final : public RequestBase {
     // proc request.
     std::string varname = request_.varname();
     std::string out_varname = request_.out_varname();
+    std::string table_name = request_.table_name();
     int trainer_id = request_.trainer_id();

     VLOG(4) << "RequestGet " << out_varname << " from " << varname;

@@ -145,19 +146,23 @@ class RequestGet final : public RequestBase {
     framework::Variable* invar = nullptr;
     framework::Variable* outvar = nullptr;

-    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
-                             out_varname);
+    tmp_scope_ = std::move(scope->NewTmpScope());
+    request_handler_->Handle(varname, tmp_scope_.get(), invar, &outvar,
+                             trainer_id, out_varname, table_name);

+    VLOG(1) << "before SerializeToByteBuffer";
     if (outvar) {
       SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
                             &reply_);
     }
+    VLOG(1) << "after SerializeToByteBuffer";
     Finish(reply_, &responder_);
   }

  protected:
   sendrecv::VariableMessage request_;
   ::grpc::ByteBuffer reply_;
+  std::unique_ptr<framework::Scope> tmp_scope_;
   ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
 };
paddle/fluid/operators/distributed/parameter_recv.cc (view file @ 82cff5ec)

@@ -42,27 +42,23 @@ using DDim = framework::DDim;
 template <typename T>
 void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
                                   const framework::Scope &scope) {
-  VLOG(3) << "ParameterRecv in";
+  VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
   std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();

   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &cpu_ctx = *pool.Get(platform::CPUPlace());

   distributed::RPCClient *rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);

   auto *recv_var = scope.FindVar(rpc_ctx.var_name);

-  std::vector<framework::Tensor *> recved_tensors;
-
   // recv all vars to local scope
   if (recv_var->IsType<framework::LoDTensor>()) {
     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
       auto &recv_var_name = rpc_ctx.splited_var_names[i];
-      framework::Tensor *t =
-          local_scope->Var(recv_var_name)->GetMutable<framework::LoDTensor>();
-      recved_tensors.push_back(t);
+      local_scope->Var(recv_var_name);
       VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
       rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
                                              *local_scope.get(), recv_var_name,

@@ -78,23 +74,61 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
   // concat recved tensor into one var
   {
     size_t output_offset = 0;
+    size_t row_offset = 0;
     framework::Tensor *recv_tensor =
         recv_var->GetMutable<framework::LoDTensor>();
     auto dev_ctx = paddle::platform::CPUDeviceContext();
     int64_t recv_numel = 0;
-    for (auto *in : recved_tensors) {
-      recv_numel += in->numel();
-      auto in_stride = framework::stride_numel(in->dims());
-      auto out_stride = framework::stride_numel(recv_tensor->dims());
-      StridedNumelCopyWithAxis<T>(
-          dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
-          in->data<T>(), in_stride, in_stride[0]);
-      output_offset += in_stride[0];
+    for (auto &recv_var_name : rpc_ctx.splited_var_names) {
+      auto *recv_var = local_scope->FindVar(recv_var_name);
+      if (recv_var->IsType<framework::LoDTensor>()) {
+        auto &in = recv_var->Get<framework::LoDTensor>();
+        recv_numel += in.numel();
+        auto in_stride = framework::stride_numel(in.dims());
+        auto out_stride = framework::stride_numel(recv_tensor->dims());
+        StridedNumelCopyWithAxis<T>(
+            dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
+            in.data<T>(), in_stride, in_stride[0]);
+        output_offset += in_stride[0];
+      } else if (recv_var->IsType<framework::SelectedRows>()) {
+        auto &recv_slr = recv_var->Get<framework::SelectedRows>();
+        auto &recv_dims = recv_tensor->dims();
+        int64_t width = recv_dims[1];
+        recv_numel += recv_slr.height() * width;
+        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
+        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
+        VLOG(3) << "recv slr " << recv_var_name << " dims "
+                << recv_slr.value().dims();
+        if (VLOG_IS_ON(3)) {
+          std::ostringstream sstream;
+          sstream << "[";
+          for (auto &row_id : recv_slr.rows()) {
+            sstream << row_id << ", ";
+          }
+          sstream << "]";
+          VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
+                  << sstream.str();
+        }
+        for (auto i = 0; i < recv_slr.rows().size(); ++i) {
+          auto row_id = recv_slr.rows()[i] + row_offset;
+          PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
+          memcpy(recv_tensor->data<T>() + row_id * width,
+                 recv_slr.value().data<T>() + i * width, sizeof(T) * width);
+        }
+        row_offset += recv_slr.height();
+      } else {
+        PADDLE_THROW("unsupported recieved var type");
+      }
     }
-    PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
+    auto numel = recv_tensor->numel();
+    if (recv_numel != numel) {
+      LOG(FATAL) << "recv_numel: " << recv_numel << " acture numel: " << numel;
+    }
+    PADDLE_ENFORCE_EQ(recv_numel, numel);
   }

-  VLOG(3) << "ParameterRecv out";
+  VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
 }

 template struct ParameterRecv<float>;
paddle/fluid/operators/distributed/parameter_send.cc (view file @ 82cff5ec)

@@ -47,7 +47,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
   auto &cpu_ctx = *pool.Get(platform::CPUPlace());

   distributed::RPCClient *rpc_client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);

   auto *send_var = scope.FindVar(rpc_ctx.var_name);
   size_t out_num = rpc_ctx.splited_var_names.size();
paddle/fluid/operators/distributed/request_handler.h (view file @ 82cff5ec)

@@ -18,7 +18,9 @@
 #include <condition_variable>  // NOLINT
 #include <functional>
+#include <memory>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>

@@ -180,6 +182,10 @@ class RequestHandler {
     grad_to_prepared_ctx_ = g;
   }

+  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
+    sparse_grad_to_param_ = g;
+  }
+
   void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

   // Get attributes.

@@ -228,6 +234,7 @@ class RequestHandler {
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>*
       grad_to_prepared_ctx_;
+  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;

   RPCServer* rpc_server_;
 };
paddle/fluid/operators/distributed/request_handler_impl.cc (view file @ 82cff5ec)

@@ -22,6 +22,7 @@
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
 #include "paddle/fluid/operators/distributed/rpc_server.h"
 #include "paddle/fluid/string/piece.h"
 #include "paddle/fluid/string/printf.h"

@@ -59,6 +60,12 @@ bool RequestSendHandler::Handle(const std::string& varname,
           "async mode should not recv BATCH_BARRIER_MESSAGE or "
           "COMPLETE_MESSAGE");
     }
+    if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(varname)) {
+      auto& grad_slr =
+          scope->FindVar(varname)->Get<framework::SelectedRows>();
+      AsyncSparseParamUpdateRecorder::GetInstance()->Update(varname,
+                                                            grad_slr.rows());
+    }
     executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
                                   scope);
     return true;

@@ -82,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname,
                                const int trainer_id,
                                const std::string& out_var_name,
                                const std::string& table_name) {
-  VLOG(4) << "RequestGetHandler:" << varname
-          << " out_var_name: " << out_var_name;
+  VLOG(3) << "RequestGetHandler:" << varname
          << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
          << " table_name: " << table_name;

   if (sync_mode_) {
     if (varname == FETCH_BARRIER_MESSAGE) {

@@ -108,7 +116,42 @@ bool RequestGetHandler::Handle(const std::string& varname,
         VLOG(3) << "copying " << varname << " to " << param_bak_name;
         framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
       }
-      *outvar = scope_->FindVar(varname);
+      if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
+          !table_name.empty()) {
+        std::vector<int64_t> updated_rows;
+        AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
+            varname, trainer_id, &updated_rows);
+        if (VLOG_IS_ON(3)) {
+          std::ostringstream sstream;
+          sstream << "[";
+          for (auto& row_id : updated_rows) {
+            sstream << row_id << ", ";
+          }
+          sstream << "]";
+          VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
+                  << sstream.str();
+        }
+        auto& origin_tensor =
+            scope_->FindVar(varname)->Get<framework::LoDTensor>();
+        auto* origin_tensor_data = origin_tensor.data<float>();
+        auto& dims = origin_tensor.dims();
+        *outvar = scope->Var();
+        auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
+        out_slr->set_rows(updated_rows);
+        out_slr->set_height(dims[0]);
+        auto out_dims = framework::make_ddim(
+            {static_cast<int64_t>(updated_rows.size()), dims[1]});
+        auto* data = out_slr->mutable_value()->mutable_data<float>(
+            out_dims, origin_tensor.place());
+        auto width = dims[1];
+        for (auto i = 0; i < updated_rows.size(); ++i) {
+          PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
+          memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
+                 sizeof(float) * width);
+        }
+      } else {
+        *outvar = scope_->FindVar(varname);
+      }
     }
   }
   return true;
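The new RequestGetHandler branch answers a GET that carries a table name by gathering only the recorded rows of the dense parameter into a SelectedRows output. A standalone sketch of just that row-gather copy follows, using raw buffers instead of LoDTensor/SelectedRows; the shape and row ids are made up for illustration.

#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Dense parameter of shape [4, 3], stored row-major.
  const int64_t width = 3;
  std::vector<float> param = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3};

  // Rows recorded as updated for this trainer.
  std::vector<int64_t> updated_rows = {1, 3};

  // Gather just those rows into a compact [updated_rows.size(), width] buffer,
  // mirroring the memcpy the handler does into out_slr->mutable_value().
  std::vector<float> out(updated_rows.size() * width);
  for (size_t i = 0; i < updated_rows.size(); ++i) {
    std::memcpy(out.data() + i * width,
                param.data() + updated_rows[i] * width,
                sizeof(float) * width);
  }

  for (float v : out) std::cout << v << ' ';  // prints 1 1 1 3 3 3
  std::cout << '\n';
  return 0;
}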
paddle/fluid/operators/distributed/rpc_client.h (view file @ 82cff5ec)

@@ -15,6 +15,7 @@
 #pragma once

 #include <condition_variable>  // NOLINT
+#include <memory>
 #include <string>

 #include "gflags/gflags.h"

@@ -44,6 +45,7 @@ class RPCClient {
                                    const framework::Scope& scope,
                                    const std::string& var_name,
                                    const std::string& out_varname,
+                                   const std::string& table_name = "",
                                    int64_t time_out = FLAGS_rpc_deadline) = 0;

   virtual VarHandlePtr AsyncGetVarNoBarrier(

@@ -96,6 +98,7 @@ class RPCClient {
   // Init is called by GetInstance.
   template <typename T>
   static void Init(int trainer_id) {
+    VLOG(0) << "init rpc client with trainer_id " << trainer_id;
     trainer_id_ = trainer_id;
     if (rpc_client_.get() == nullptr) {
       rpc_client_.reset(new T());
paddle/fluid/operators/distributed/rpc_common.h (view file @ 82cff5ec)

@@ -27,23 +27,26 @@ struct RpcContext {
   RpcContext(const std::string& name, const std::vector<std::string>& names,
              const std::vector<std::string>& emap,
-             const std::vector<int64_t>& sections)
+             const std::vector<int64_t>& sections, int id)
       : var_name(name),
         splited_var_names(names),
         epmap(emap),
-        height_sections(sections) {}
+        height_sections(sections),
+        trainer_id(id) {}

   RpcContext(const RpcContext& ctx) {
     var_name = ctx.var_name;
     splited_var_names = ctx.splited_var_names;
     epmap = ctx.epmap;
     height_sections = ctx.height_sections;
+    trainer_id = ctx.trainer_id;
   }

   std::string var_name;
   std::vector<std::string> splited_var_names;
   std::vector<std::string> epmap;
   std::vector<int64_t> height_sections;
+  int trainer_id;
 };

 inline std::ostream& operator<<(std::ostream& os, const RpcContext& rpc_ctx) {
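With the new field, an RpcContext now carries the trainer id next to the variable split info, which is what send_op, recv_op and the SSA graph pass construct. A minimal construction sketch follows (assuming Paddle's headers are on the include path; the variable name, endpoints and section sizes are illustrative):

#include <string>
#include <vector>

#include "paddle/fluid/operators/distributed/rpc_common.h"

using paddle::operators::distributed::RpcContext;

RpcContext MakeSendCtx() {
  // "emb" is split into two pieces that go to two parameter servers;
  // the last argument is the new trainer_id field.
  return RpcContext("emb",
                    {"emb.block0", "emb.block1"},
                    {"127.0.0.1:6174", "127.0.0.1:6175"},
                    /*sections=*/{8, 8},
                    /*trainer_id=*/0);
}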
paddle/fluid/operators/distributed_ops/CMakeLists.txt (view file @ 82cff5ec)

@@ -2,9 +2,9 @@ include(operators)
 set(DISTRIBUTE_DEPS "")
 if(WITH_GRPC)
-    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
+    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
 else()
-    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
+    set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
     if(WITH_BRPC_RDMA)
         find_library(IBVERBS_LIBRARY NAMES ibverbs)
         ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc (view file @ 82cff5ec)

@@ -24,8 +24,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/operators/math/math_function.h"

+#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
 #include "paddle/fluid/operators/distributed/request_handler_impl.h"
 #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
+
 #include "paddle/fluid/platform/profiler.h"

 DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");

@@ -292,6 +294,8 @@ static void FillRequestCtx(
     std::unordered_map<std::string,
                        std::shared_ptr<framework::ExecutorPrepareContext>>
         *prefetch_ctx,
+    std::unordered_map<std::string, std::string>
+        *sparse_grad_name_to_param_name,
     std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
     distributed::RPCServer *rpc_server) {
   h->SetScope(scope);

@@ -299,6 +303,7 @@ static void FillRequestCtx(
   h->SetExecutor(executor);
   h->SetProgram(program);
   h->SetPrefetchPreparedCtx(prefetch_ctx);
+  h->SetSparseGradToParam(sparse_grad_name_to_param_name);
   h->SetRPCServer(rpc_server);
   h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
 }

@@ -414,10 +419,24 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
     prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
   }

-  auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
-                     &dev_ctx, &executor, program,
-                     &prefetch_var_name_to_prepared_ctx, ckpt_pre_context,
-                     rpc_service_.get());
+  // parse attr of kSparseGradToParam  sparse_grad_name -> param_name
+  std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
+  auto sparse_grad_name_to_param_name_str =
+      Attr<std::vector<std::string>>(kSparseGradToParam);
+  for (const auto &sparse_grad_name_and_param_name :
+       sparse_grad_name_to_param_name_str) {
+    std::vector<std::string> pieces;
+    split(sparse_grad_name_and_param_name, ':', &pieces);
+    PADDLE_ENFORCE_EQ(pieces.size(), 2);
+    VLOG(3) << "after split, sparse_grad_name = " << pieces[0]
+            << ", param_name = " << pieces[1];
+    sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
+  }
+
+  auto f = std::bind(
+      FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, &executor,
+      program, &prefetch_var_name_to_prepared_ctx,
+      &sparse_grad_name_to_param_name, ckpt_pre_context, rpc_service_.get());

   f(request_send_handler_.get());
   f(request_get_handler_.get());

@@ -445,6 +464,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
     RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
                 prefetch_block_id_list, checkpoint_block_id);
   } else {
+    distributed::AsyncSparseParamUpdateRecorder::Init(
+        fan_in, sparse_grad_name_to_param_name);
     RunAsyncLoop(&executor, program, &recv_scope);
   }
 }

@@ -475,6 +496,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
                                       "prefetch blocks to run on server side.")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        kSparseGradToParam,
+        "sparse grad name to param name. like: 'emb@Grad:emb'")
+        .SetDefault({});
     AddAttr<int>("Fanin", "How many clients send to this server.")
         .SetDefault(1);
     AddAttr<int>(kCheckpointBlockId,
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h (view file @ 82cff5ec)

@@ -16,8 +16,10 @@ limitations under the License. */
 #include <stdint.h>
 #include <atomic>
+#include <memory>
 #include <set>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>

@@ -35,6 +37,7 @@ namespace operators {
 constexpr char kOptimizeBlocks[] = "optimize_blocks";
 constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
 constexpr char kCheckpointBlockId[] = "checkpint_block_id";
+constexpr char kSparseGradToParam[] = "sparse_grad_to_param";

 void RunServer(std::shared_ptr<distributed::RPCServer> service);
paddle/fluid/operators/distributed_ops/recv_op.cc (view file @ 82cff5ec)

@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &ctx = *pool.Get(place);

+    auto trainer_id = Attr<int>("trainer_id");
     distributed::RPCClient *rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-            Attr<int>("trainer_id"));
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);

     std::vector<std::string> recv_varnames =
         Attr<std::vector<std::string>>("recv_varnames");

     if (recv_varnames.size() > 0) {
       auto recv_functor = distributed::ParameterRecv<float>();
-      auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {});
+      auto rpc_ctx =
+          distributed::RpcContext(outs[0], recv_varnames, epmap, {}, trainer_id);
       recv_functor(rpc_ctx, scope);
     } else {
       if (with_barrier) {
paddle/fluid/operators/distributed_ops/send_op.cc (view file @ 82cff5ec)

@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
     auto epmap = Attr<std::vector<std::string>>("epmap");
     int sync_send = Attr<int>("sync_mode");

+    auto trainer_id = Attr<int>("trainer_id");
     auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
     auto height_sections = Attr<std::vector<int64_t>>("sections");

@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
       if (distributed::Communicator::GetInstance() == nullptr) {
         auto send_functor = distributed::ParameterSend<float>();
         auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
-                                               height_sections);
+                                               height_sections, trainer_id);
         send_functor(rpc_ctx, scope, true);
       } else {
         distributed::Communicator::GetInstance()->Send(ins[0], scope);

@@ -62,8 +63,7 @@ class SendOp : public framework::OperatorBase {
       auto &ctx = *pool.Get(place);

       distributed::RPCClient *rpc_client =
-          distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-              Attr<int>("trainer_id"));
+          distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);

       std::vector<distributed::VarHandlePtr> rets;
       for (size_t i = 0; i < ins.size(); i++) {
python/paddle/fluid/__init__.py (view file @ 82cff5ec)

@@ -175,6 +175,7 @@ def __bootstrap__():
         read_env_flags.append('communicator_thread_pool_size')
         read_env_flags.append('communicator_max_merge_var_num')
         read_env_flags.append('communicator_fake_rpc')
+        read_env_flags.append('communicator_send_wait_times')
     if core.is_compiled_with_brpc():
         read_env_flags.append('max_body_size')
         #set brpc max body size
python/paddle/fluid/transpiler/distribute_transpiler.py (view file @ 82cff5ec)

@@ -658,6 +658,7 @@ class DistributeTranspiler(object):
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
+                    "trainer_id": self.trainer_id,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

@@ -669,6 +670,7 @@ class DistributeTranspiler(object):
                 outputs={"Out": fetch_barrier_out},
                 attrs={
                     "endpoints": self.pserver_endpoints,
+                    "trainer_id": self.trainer_id,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

@@ -791,11 +793,15 @@ class DistributeTranspiler(object):
         global_ops = []

+        # sparse grad name to param name
+        sparse_grad_to_param = []
+
         def __append_optimize_op__(op, block, grad_to_block_id, merged_var,
                                    lr_ops):
             if self._is_optimizer_op(op):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
-                                         self.origin_program, merged_var)
+                                         self.origin_program, merged_var,
+                                         sparse_grad_to_param)
             elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)

@@ -911,6 +917,7 @@ class DistributeTranspiler(object):
             "Fanin": self.trainer_num,
             "sync_mode": self.sync_mode,
             "grad_to_block_id": grad_to_block_id,
+            "sparse_grad_to_param": sparse_grad_to_param,
         }

         if self.has_distributed_lookup_table:

@@ -1779,7 +1786,8 @@ class DistributeTranspiler(object):
         return o4

     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
-                            grad_to_block_id, origin_program, merged_var):
+                            grad_to_block_id, origin_program, merged_var,
+                            sparse_grad_to_param):
         program = optimize_block.program
         pserver_block = program.global_block()
         new_inputs = collections.OrderedDict()

@@ -1863,6 +1871,12 @@ class DistributeTranspiler(object):
                 outputs=outputs,
                 attrs=opt_op.all_attrs())

+        # record sparse grad to param name
+        if new_inputs["Grad"].type == core.VarDesc.VarType.SELECTED_ROWS:
+            sparse_grad_to_param.append(
+                str(new_inputs["Grad"].name) + ":" + str(new_inputs["Param"]
+                                                         .name))
+
     def _get_pserver_grad_param_var(self, var, var_dict):
         """
         Return pserver side grad/param variable, return None