Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
82cff5ec
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82cff5ec
编写于
4月 14, 2019
作者:
乔
乔龙飞 Qiao Longfei
提交者:
GitHub
4月 14, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16762 from jacquesqiao/add-async_sparse_param_update_recorder
Add async sparse param update recorder
上级
4267a81a
1526a3e4
变更
25
隐藏空白更改
内联
并排
Showing
25 changed file
with
570 addition
and
85 deletion
+570
-85
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+7
-2
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+4
-1
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc
...erators/distributed/async_sparse_param_update_recorder.cc
+27
-0
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
...perators/distributed/async_sparse_param_update_recorder.h
+183
-0
paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
...rs/distributed/async_sparse_param_update_recorder_test.cc
+99
-0
paddle/fluid/operators/distributed/brpc/brpc_client.cc
paddle/fluid/operators/distributed/brpc/brpc_client.cc
+1
-0
paddle/fluid/operators/distributed/brpc/brpc_client.h
paddle/fluid/operators/distributed/brpc/brpc_client.h
+8
-7
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+26
-9
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+1
-1
paddle/fluid/operators/distributed/grpc/grpc_client.cc
paddle/fluid/operators/distributed/grpc/grpc_client.cc
+31
-25
paddle/fluid/operators/distributed/grpc/grpc_client.h
paddle/fluid/operators/distributed/grpc/grpc_client.h
+5
-1
paddle/fluid/operators/distributed/grpc/grpc_server.cc
paddle/fluid/operators/distributed/grpc/grpc_server.cc
+7
-2
paddle/fluid/operators/distributed/parameter_recv.cc
paddle/fluid/operators/distributed/parameter_recv.cc
+51
-17
paddle/fluid/operators/distributed/parameter_send.cc
paddle/fluid/operators/distributed/parameter_send.cc
+1
-1
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+7
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+46
-3
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+3
-0
paddle/fluid/operators/distributed/rpc_common.h
paddle/fluid/operators/distributed/rpc_common.h
+5
-2
paddle/fluid/operators/distributed_ops/CMakeLists.txt
paddle/fluid/operators/distributed_ops/CMakeLists.txt
+2
-2
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+29
-4
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
+3
-0
paddle/fluid/operators/distributed_ops/recv_op.cc
paddle/fluid/operators/distributed_ops/recv_op.cc
+4
-3
paddle/fluid/operators/distributed_ops/send_op.cc
paddle/fluid/operators/distributed_ops/send_op.cc
+3
-3
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+1
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+16
-2
未找到文件。
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
82cff5ec
...
...
@@ -64,9 +64,12 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
auto
height_section
=
boost
::
get
<
std
::
vector
<
int64_t
>>
(
node
->
Op
()
->
GetNullableAttr
(
"sections"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
send_varname_to_ctx
[
send_var_name
]
=
operators
::
distributed
::
RpcContext
(
send_var_name
,
send_varnames
,
epmap
,
height_section
);
epmap
,
height_section
,
trainer_id
);
VLOG
(
3
)
<<
"find and init an send op: "
<<
send_varname_to_ctx
[
send_var_name
];
}
else
if
(
node
->
Name
()
==
"recv"
)
{
...
...
@@ -75,9 +78,11 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node
->
Op
()
->
GetNullableAttr
(
"recv_varnames"
));
auto
epmap
=
boost
::
get
<
std
::
vector
<
std
::
string
>>
(
node
->
Op
()
->
GetNullableAttr
(
"epmap"
));
auto
trainer_id
=
boost
::
get
<
int
>
(
node
->
Op
()
->
GetNullableAttr
(
"trainer_id"
));
recv_varname_to_ctx
[
recv_var_name
]
=
operators
::
distributed
::
RpcContext
(
recv_var_name
,
recv_varnames
,
epmap
,
{});
epmap
,
{}
,
trainer_id
);
nodes_to_delete
.
push_back
(
node
);
VLOG
(
3
)
<<
"find and remove an recv op: "
<<
recv_varname_to_ctx
[
recv_var_name
];
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
82cff5ec
...
...
@@ -9,6 +9,9 @@ else()
endif
()
configure_file
(
send_recv.proto.in
${
CMAKE_CURRENT_SOURCE_DIR
}
/send_recv.proto @ONLY
)
cc_library
(
async_sparse_param_update_recorder SRCS async_sparse_param_update_recorder.cc DEPS enforce simple_threadpool
)
cc_test
(
async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder
)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
if
(
WITH_GRPC
)
...
...
@@ -20,7 +23,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${
GRPC_SRCS
}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope
${
GRPC_DEPS
}
)
DEPS lod_tensor selected_rows_functor memory scope
${
GRPC_DEPS
}
async_sparse_param_update_recorder
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
RPC_DEPS sendrecvop_rpc
${
GRPC_DEPS
}
)
...
...
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.cc
0 → 100644
浏览文件 @
82cff5ec
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
std
::
once_flag
AsyncSparseParamUpdateRecorder
::
init_flag_
;
std
::
unique_ptr
<
AsyncSparseParamUpdateRecorder
>
AsyncSparseParamUpdateRecorder
::
recorder_
(
nullptr
);
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h
0 → 100644
浏览文件 @
82cff5ec
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
class
ConcurrentSet
{
public:
ConcurrentSet
()
:
pool_
(
new
::
ThreadPool
(
1
))
{}
~
ConcurrentSet
()
{}
std
::
future
<
void
>
Update
(
const
std
::
vector
<
int64_t
>&
rows
)
{
auto
task
=
[
this
,
rows
]
{
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
id
:
rows
)
{
sstream
<<
id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"update ids -> "
<<
sstream
.
str
();
}
for
(
auto
row
:
rows
)
{
set_
.
insert
(
row
);
}
};
return
pool_
->
enqueue
(
std
::
move
(
task
));
}
std
::
future
<
void
>
GetAndClear
(
std
::
vector
<
int64_t
>*
result
)
{
auto
task
=
[
this
,
&
result
]
{
result
->
clear
();
for
(
auto
&
id
:
set_
)
{
result
->
push_back
(
id
);
}
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
id
:
*
result
)
{
sstream
<<
id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"result ids size: "
<<
result
->
size
()
<<
" "
<<
sstream
.
str
();
}
set_
.
clear
();
};
return
pool_
->
enqueue
(
std
::
move
(
task
));
}
private:
std
::
unordered_set
<
int64_t
>
set_
;
std
::
unique_ptr
<::
ThreadPool
>
pool_
{
nullptr
};
};
class
AsyncSparseParamUpdateRecorder
{
using
TrainerToRows
=
std
::
vector
<
std
::
unique_ptr
<
ConcurrentSet
>>
;
public:
AsyncSparseParamUpdateRecorder
(
int
trainer_num
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_to_param
)
:
trainer_num_
(
trainer_num
),
grad_to_param_
(
grad_to_param
)
{
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
item
:
grad_to_param
)
{
sstream
<<
item
.
first
<<
":"
<<
item
.
second
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"trainer_num: "
<<
trainer_num
<<
" grad_to_param_: "
<<
sstream
.
str
();
}
for
(
auto
&
iter
:
grad_to_param
)
{
param_to_grad_
[
iter
.
second
]
=
iter
.
first
;
auto
&
param_name
=
iter
.
second
;
param_to_updated_rows_
[
param_name
]
=
TrainerToRows
();
auto
&
trainer_to_rows
=
param_to_updated_rows_
[
param_name
];
for
(
auto
i
=
0
;
i
<
trainer_num
;
++
i
)
{
trainer_to_rows
.
emplace_back
(
new
ConcurrentSet
());
}
}
}
~
AsyncSparseParamUpdateRecorder
()
=
default
;
void
Update
(
const
std
::
string
&
grad_name
,
const
std
::
vector
<
int64_t
>&
update_rows
)
{
VLOG
(
3
)
<<
"update grad: "
<<
grad_name
<<
" row size: "
<<
update_rows
.
size
();
auto
&
param_name
=
grad_to_param_
.
at
(
grad_name
);
auto
&
trainer_to_rows
=
param_to_updated_rows_
.
at
(
param_name
);
std
::
vector
<
std
::
future
<
void
>>
fs
;
for
(
auto
&
set
:
trainer_to_rows
)
{
fs
.
push_back
(
set
->
Update
(
update_rows
));
}
for
(
auto
&
f
:
fs
)
{
f
.
wait
();
}
}
void
GetAndClear
(
const
std
::
string
&
param_name
,
int
trainer_id
,
std
::
vector
<
int64_t
>*
result
)
{
VLOG
(
3
)
<<
"GetAndClear param: "
<<
param_name
<<
" for trainer: "
<<
trainer_id
;
PADDLE_ENFORCE_LT
(
trainer_id
,
trainer_num_
);
param_to_updated_rows_
.
at
(
param_name
)[
trainer_id
]
->
GetAndClear
(
result
)
.
wait
();
}
bool
HasParam
(
const
std
::
string
&
param_name
)
{
return
param_to_grad_
.
find
(
param_name
)
!=
param_to_grad_
.
end
();
}
bool
HasGrad
(
const
std
::
string
&
grad_name
)
{
return
grad_to_param_
.
find
(
grad_name
)
!=
grad_to_param_
.
end
();
}
private:
const
int
trainer_num_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_to_param_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
param_to_grad_
;
std
::
unordered_map
<
std
::
string
,
TrainerToRows
>
param_to_updated_rows_
;
// init recorder
public:
static
void
Init
(
int
trainer_num
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_to_param
)
{
InitImpl
(
trainer_num
,
grad_to_param
);
}
static
AsyncSparseParamUpdateRecorder
*
GetInstance
()
{
return
recorder_
.
get
();
}
private:
// Init is called by GetInstance.
static
void
InitImpl
(
int
trainer_num
,
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
grad_to_param
)
{
if
(
recorder_
==
nullptr
)
{
recorder_
.
reset
(
new
AsyncSparseParamUpdateRecorder
(
trainer_num
,
grad_to_param
));
}
}
static
std
::
once_flag
init_flag_
;
static
std
::
unique_ptr
<
AsyncSparseParamUpdateRecorder
>
recorder_
;
};
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/async_sparse_param_update_recorder_test.cc
0 → 100644
浏览文件 @
82cff5ec
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include <algorithm>
#include "gtest/gtest.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
TEST
(
ConcurrentSet
,
All
)
{
ConcurrentSet
concurrent_set
;
std
::
vector
<
int64_t
>
in1
=
{
1
,
2
,
3
,
4
};
std
::
vector
<
int64_t
>
in2
=
{
2
,
3
,
5
,
6
};
std
::
vector
<
std
::
future
<
void
>>
futures
;
futures
.
push_back
(
concurrent_set
.
Update
(
in1
));
futures
.
push_back
(
concurrent_set
.
Update
(
in2
));
for
(
auto
&
f
:
futures
)
{
f
.
wait
();
}
std
::
unordered_set
<
int64_t
>
in
;
std
::
copy
(
in1
.
begin
(),
in1
.
end
(),
std
::
inserter
(
in
,
in
.
begin
()));
std
::
copy
(
in2
.
begin
(),
in2
.
end
(),
std
::
inserter
(
in
,
in
.
begin
()));
std
::
vector
<
int64_t
>
ret
;
concurrent_set
.
GetAndClear
(
&
ret
).
wait
();
std
::
unordered_set
<
int64_t
>
out
;
std
::
copy
(
ret
.
begin
(),
ret
.
end
(),
std
::
inserter
(
out
,
out
.
begin
()));
EXPECT_EQ
(
in
,
out
);
concurrent_set
.
GetAndClear
(
&
ret
).
wait
();
EXPECT_EQ
(
ret
.
size
(),
0
);
}
TEST
(
AsyncSparseParamUpdateRecorder
,
All
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
grad_to_param
;
grad_to_param
[
"grad1"
]
=
"param1"
;
grad_to_param
[
"grad2"
]
=
"param2"
;
int
trainer_num
=
10
;
AsyncSparseParamUpdateRecorder
recorder
(
trainer_num
,
grad_to_param
);
std
::
vector
<
int64_t
>
in1
=
{
1
,
2
,
3
,
4
};
std
::
vector
<
int64_t
>
in2
=
{
2
,
3
,
5
,
6
};
std
::
unordered_set
<
int64_t
>
in
;
std
::
copy
(
in1
.
begin
(),
in1
.
end
(),
std
::
inserter
(
in
,
in
.
begin
()));
std
::
copy
(
in2
.
begin
(),
in2
.
end
(),
std
::
inserter
(
in
,
in
.
begin
()));
recorder
.
Update
(
"grad1"
,
in1
);
recorder
.
Update
(
"grad1"
,
in2
);
EXPECT_TRUE
(
recorder
.
HasParam
(
"param1"
));
EXPECT_TRUE
(
recorder
.
HasParam
(
"param2"
));
EXPECT_FALSE
(
recorder
.
HasParam
(
"param3"
));
EXPECT_TRUE
(
recorder
.
HasGrad
(
"grad1"
));
EXPECT_TRUE
(
recorder
.
HasGrad
(
"grad2"
));
EXPECT_FALSE
(
recorder
.
HasGrad
(
"grad3"
));
std
::
vector
<
int64_t
>
ret
;
EXPECT_ANY_THROW
(
recorder
.
GetAndClear
(
"param1"
,
trainer_num
,
&
ret
));
for
(
int
i
=
0
;
i
<
trainer_num
;
++
i
)
{
std
::
vector
<
int64_t
>
ret
;
std
::
unordered_set
<
int64_t
>
out
;
recorder
.
GetAndClear
(
"param1"
,
i
,
&
ret
);
std
::
copy
(
ret
.
begin
(),
ret
.
end
(),
std
::
inserter
(
out
,
out
.
begin
()));
EXPECT_EQ
(
in
,
out
);
recorder
.
GetAndClear
(
"param1"
,
i
,
&
ret
);
EXPECT_EQ
(
ret
.
size
(),
0
);
}
}
}
// namespace distributed
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/distributed/brpc/brpc_client.cc
浏览文件 @
82cff5ec
...
...
@@ -234,6 +234,7 @@ VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
out_var_name
,
kGetRPC
,
time_out
);
...
...
paddle/fluid/operators/distributed/brpc/brpc_client.h
浏览文件 @
82cff5ec
...
...
@@ -21,8 +21,10 @@ limitations under the License. */
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
#include "brpc/channel.h"
...
...
@@ -66,6 +68,7 @@ class BRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetMonomerBarrier
(
...
...
@@ -107,13 +110,11 @@ class BRPCClient : public RPCClient {
void
SendComplete
()
override
;
private:
VarHandlePtr
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
method_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
VarHandlePtr
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
method_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
void
Proceed
();
ChannelQueuePtr
GetChannel
(
const
std
::
string
&
ep
);
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
82cff5ec
...
...
@@ -32,6 +32,9 @@ DEFINE_int32(communicator_send_queue_size, 20,
DEFINE_int32
(
communicator_max_send_grad_num_before_recv
,
20
,
"max grad num to send before recv parameters"
);
DEFINE_int32
(
communicator_thread_pool_size
,
5
,
"thread num to do send or recv"
);
DEFINE_int32
(
communicator_send_wait_times
,
5
,
"times that send thread will wait if merge num does not reach "
"max_merge_var_num"
);
DEFINE_int32
(
communicator_max_merge_var_num
,
20
,
"max var num to merge and send"
);
DEFINE_bool
(
communicator_fake_rpc
,
false
,
...
...
@@ -65,6 +68,8 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
<<
FLAGS_communicator_max_send_grad_num_before_recv
;
VLOG
(
0
)
<<
"communicator_thread_pool_size: "
<<
FLAGS_communicator_thread_pool_size
;
VLOG
(
0
)
<<
"communicator_send_wait_times: "
<<
FLAGS_communicator_send_wait_times
;
VLOG
(
0
)
<<
"communicator_max_merge_var_num: "
<<
FLAGS_communicator_max_merge_var_num
;
VLOG
(
0
)
<<
"communicator_fake_rpc: "
<<
FLAGS_communicator_fake_rpc
;
...
...
@@ -101,20 +106,32 @@ void Communicator::SendThread() {
VLOG
(
3
)
<<
var_name
<<
" merge and send"
;
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
size_t
merged_var_num
=
0
;
while
(
var_queue
->
Size
()
>
0
&&
merged_var_num
<
FLAGS_communicator_max_merge_var_num
)
{
vars
.
push_back
(
var_queue
->
Pop
());
// only count the send number of the first var
if
(
var_name
==
send_varname_to_queue_
.
begin
()
->
first
)
{
grad_num_
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
size_t
wait_times
=
0
;
while
(
merged_var_num
<
FLAGS_communicator_max_merge_var_num
)
{
if
(
var_queue
->
Size
()
==
0
)
{
VLOG
(
3
)
<<
"wait_times -> "
<<
wait_times
;
if
(
wait_times
>=
FLAGS_communicator_send_wait_times
)
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
wait_times
++
;
continue
;
}
else
{
wait_times
=
0
;
vars
.
push_back
(
var_queue
->
Pop
());
// only count the send number of the first var
if
(
var_name
==
send_varname_to_queue_
.
begin
()
->
first
)
{
grad_num_
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
}
merged_var_num
++
;
}
merged_var_num
++
;
}
auto
before_merge
=
GetCurrentUS
();
MergeVars
(
var_name
,
vars
,
send_scope_
.
get
());
auto
after_merge
=
GetCurrentUS
();
VLOG
(
3
)
<<
"merge "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
VLOG
(
3
)
<<
"merge "
<<
merged_var_num
<<
" "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
if
(
!
FLAGS_communicator_fake_rpc
)
{
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
82cff5ec
...
...
@@ -109,7 +109,7 @@ inline void MergeVars(const std::string& var_name,
auto
*
out_var
=
scope
->
Var
(
var_name
);
if
(
var0
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
dims
=
var0
->
Get
<
framework
::
LoDTensor
>
().
dims
();
VLOG
(
3
)
<<
"merge "
<<
var_name
<<
" LoDTensor "
<<
dims
;
VLOG
(
3
)
<<
"merge "
<<
var_name
<<
" LoDTensor
dims
"
<<
dims
;
// init output tensor
auto
*
out_t
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.cc
浏览文件 @
82cff5ec
...
...
@@ -128,9 +128,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetRPC
,
var_name
,
out_varname
,
"/sendrecv.SendRecvService/GetVariable"
,
time_out
);
"/sendrecv.SendRecvService/GetVariable"
,
table_name
,
time_out
);
}
VarHandlePtr
GRPCClient
::
AsyncGetVarNoBarrier
(
...
...
@@ -142,7 +144,7 @@ VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetNoBarrierRPC
,
var_name_no_barrier
,
out_varname
,
"/sendrecv.SendRecvService/GetVariableNoBarrier"
,
time_out
);
"/sendrecv.SendRecvService/GetVariableNoBarrier"
,
""
,
time_out
);
}
VarHandlePtr
GRPCClient
::
AsyncGetMonomerVariable
(
...
...
@@ -150,18 +152,21 @@ VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
kGetMonomerRPC
,
var_name
,
var_name
,
"/sendrecv.SendRecvService/GetMonomerVariable"
,
time_out
);
"/sendrecv.SendRecvService/GetMonomerVariable"
,
""
,
time_out
);
}
VarHandlePtr
GRPCClient
::
_AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
)
{
const
std
::
string
&
rpc_path
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
out_varname_val
=
out_varname
;
const
std
::
string
table_name_val
=
table_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
...
...
@@ -169,32 +174,33 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
VarHandlePtr
h
(
new
VarHandle
(
ep
,
method
,
out_varname_val
,
p_ctx
,
p_scope
));
s
->
Prepare
(
h
,
time_out
);
framework
::
AsyncIO
(
[
var_name_val
,
out_varname_val
,
s
,
method
,
p_ctx
,
h
,
rpc_path
,
this
]
{
// prepare input
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_out_varname
(
out_varname_val
);
req
.
set_trainer_id
(
trainer_id_
);
::
grpc
::
ByteBuffer
buf
;
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
framework
::
AsyncIO
([
var_name_val
,
out_varname_val
,
table_name_val
,
s
,
method
,
p_ctx
,
h
,
rpc_path
,
this
]
{
// prepare input
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_out_varname
(
out_varname_val
);
req
.
set_trainer_id
(
trainer_id_
);
req
.
set_table_name
(
table_name_val
);
::
grpc
::
ByteBuffer
buf
;
RequestToByteBuffer
<
sendrecv
::
VariableMessage
>
(
req
,
&
buf
);
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
// stub context
s
->
response_call_back_
=
ProcGetResponse
;
// stub context
s
->
response_call_back_
=
ProcGetResponse
;
platform
::
RecordRPCEvent
record_event
(
method
);
platform
::
RecordRPCEvent
record_event
(
method
);
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
rpc_path
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
rpc_path
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
});
if
(
UNLIKELY
(
platform
::
IsProfileEnabled
()))
{
h
->
Wait
();
}
});
req_count_
++
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_client.h
浏览文件 @
82cff5ec
...
...
@@ -23,9 +23,11 @@ limitations under the License. */
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <vector>
#include "grpc++/channel.h"
...
...
@@ -187,6 +189,7 @@ class GRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetVarNoBarrier
(
...
...
@@ -239,7 +242,8 @@ class GRPCClient : public RPCClient {
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
method
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
rpc_path
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
const
std
::
string
&
rpc_path
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
private:
grpc
::
CompletionQueue
cq_
;
...
...
paddle/fluid/operators/distributed/grpc/grpc_server.cc
浏览文件 @
82cff5ec
...
...
@@ -137,6 +137,7 @@ class RequestGet final : public RequestBase {
// proc request.
std
::
string
varname
=
request_
.
varname
();
std
::
string
out_varname
=
request_
.
out_varname
();
std
::
string
table_name
=
request_
.
table_name
();
int
trainer_id
=
request_
.
trainer_id
();
VLOG
(
4
)
<<
"RequestGet "
<<
out_varname
<<
" from "
<<
varname
;
...
...
@@ -145,19 +146,23 @@ class RequestGet final : public RequestBase {
framework
::
Variable
*
invar
=
nullptr
;
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_varname
);
tmp_scope_
=
std
::
move
(
scope
->
NewTmpScope
());
request_handler_
->
Handle
(
varname
,
tmp_scope_
.
get
(),
invar
,
&
outvar
,
trainer_id
,
out_varname
,
table_name
);
VLOG
(
1
)
<<
"before SerializeToByteBuffer"
;
if
(
outvar
)
{
SerializeToByteBuffer
(
out_varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
VLOG
(
1
)
<<
"after SerializeToByteBuffer"
;
Finish
(
reply_
,
&
responder_
);
}
protected:
sendrecv
::
VariableMessage
request_
;
::
grpc
::
ByteBuffer
reply_
;
std
::
unique_ptr
<
framework
::
Scope
>
tmp_scope_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
};
...
...
paddle/fluid/operators/distributed/parameter_recv.cc
浏览文件 @
82cff5ec
...
...
@@ -42,27 +42,23 @@ using DDim = framework::DDim;
template
<
typename
T
>
void
ParameterRecv
<
T
>::
operator
()(
const
RpcContext
&
rpc_ctx
,
const
framework
::
Scope
&
scope
)
{
VLOG
(
3
)
<<
"ParameterRecv in
"
;
VLOG
(
3
)
<<
"ParameterRecv in
"
<<
rpc_ctx
.
var_name
;
std
::
unique_ptr
<
framework
::
Scope
>
local_scope
=
scope
.
NewTmpScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
rpc_ctx
.
trainer_id
);
auto
*
recv_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
std
::
vector
<
framework
::
Tensor
*>
recved_tensors
;
// recv all vars to local scope
if
(
recv_var
->
IsType
<
framework
::
LoDTensor
>
())
{
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
rpc_ctx
.
splited_var_names
.
size
();
i
++
)
{
auto
&
recv_var_name
=
rpc_ctx
.
splited_var_names
[
i
];
framework
::
Tensor
*
t
=
local_scope
->
Var
(
recv_var_name
)
->
GetMutable
<
framework
::
LoDTensor
>
();
recved_tensors
.
push_back
(
t
);
local_scope
->
Var
(
recv_var_name
);
VLOG
(
3
)
<<
"recv "
<<
recv_var_name
<<
" from "
<<
rpc_ctx
.
epmap
[
i
];
rets
.
push_back
(
rpc_client
->
AsyncGetVar
(
rpc_ctx
.
epmap
[
i
],
cpu_ctx
,
*
local_scope
.
get
(),
recv_var_name
,
...
...
@@ -78,23 +74,61 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
// concat recved tensor into one var
{
size_t
output_offset
=
0
;
size_t
row_offset
=
0
;
framework
::
Tensor
*
recv_tensor
=
recv_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
dev_ctx
=
paddle
::
platform
::
CPUDeviceContext
();
int64_t
recv_numel
=
0
;
for
(
auto
*
in
:
recved_tensors
)
{
recv_numel
+=
in
->
numel
();
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
auto
out_stride
=
framework
::
stride_numel
(
recv_tensor
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
dev_ctx
,
0
,
recv_tensor
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
->
data
<
T
>
(),
in_stride
,
in_stride
[
0
]);
output_offset
+=
in_stride
[
0
];
for
(
auto
&
recv_var_name
:
rpc_ctx
.
splited_var_names
)
{
auto
*
recv_var
=
local_scope
->
FindVar
(
recv_var_name
);
if
(
recv_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
&
in
=
recv_var
->
Get
<
framework
::
LoDTensor
>
();
recv_numel
+=
in
.
numel
();
auto
in_stride
=
framework
::
stride_numel
(
in
.
dims
());
auto
out_stride
=
framework
::
stride_numel
(
recv_tensor
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
dev_ctx
,
0
,
recv_tensor
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
.
data
<
T
>
(),
in_stride
,
in_stride
[
0
]);
output_offset
+=
in_stride
[
0
];
}
else
if
(
recv_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
&
recv_slr
=
recv_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
recv_dims
=
recv_tensor
->
dims
();
int64_t
width
=
recv_dims
[
1
];
recv_numel
+=
recv_slr
.
height
()
*
width
;
PADDLE_ENFORCE_EQ
(
recv_slr
.
value
().
dims
()[
1
],
width
);
PADDLE_ENFORCE_EQ
(
recv_slr
.
value
().
dims
()[
0
],
recv_slr
.
rows
().
size
());
VLOG
(
3
)
<<
"recv slr "
<<
recv_var_name
<<
" dims "
<<
recv_slr
.
value
().
dims
();
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
recv_slr
.
rows
())
{
sstream
<<
row_id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"recv_slr size: "
<<
recv_slr
.
rows
().
size
()
<<
" "
<<
sstream
.
str
();
}
for
(
auto
i
=
0
;
i
<
recv_slr
.
rows
().
size
();
++
i
)
{
auto
row_id
=
recv_slr
.
rows
()[
i
]
+
row_offset
;
PADDLE_ENFORCE_LT
(
row_id
,
recv_dims
[
0
]);
memcpy
(
recv_tensor
->
data
<
T
>
()
+
row_id
*
width
,
recv_slr
.
value
().
data
<
T
>
()
+
i
*
width
,
sizeof
(
T
)
*
width
);
}
row_offset
+=
recv_slr
.
height
();
}
else
{
PADDLE_THROW
(
"unsupported recieved var type"
);
}
}
auto
numel
=
recv_tensor
->
numel
();
if
(
recv_numel
!=
numel
)
{
LOG
(
FATAL
)
<<
"recv_numel: "
<<
recv_numel
<<
" acture numel: "
<<
numel
;
}
PADDLE_ENFORCE_EQ
(
recv_numel
,
recv_tensor
->
numel
()
);
PADDLE_ENFORCE_EQ
(
recv_numel
,
numel
);
}
VLOG
(
3
)
<<
"ParameterRecv out
"
;
VLOG
(
3
)
<<
"ParameterRecv out
"
<<
rpc_ctx
.
var_name
;
}
template
struct
ParameterRecv
<
float
>;
...
...
paddle/fluid/operators/distributed/parameter_send.cc
浏览文件 @
82cff5ec
...
...
@@ -47,7 +47,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto
&
cpu_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
0
);
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
rpc_ctx
.
trainer_id
);
auto
*
send_var
=
scope
.
FindVar
(
rpc_ctx
.
var_name
);
size_t
out_num
=
rpc_ctx
.
splited_var_names
.
size
();
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
82cff5ec
...
...
@@ -18,7 +18,9 @@
#include <condition_variable> // NOLINT
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
...
...
@@ -180,6 +182,10 @@ class RequestHandler {
grad_to_prepared_ctx_
=
g
;
}
void
SetSparseGradToParam
(
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
g
)
{
sparse_grad_to_param_
=
g
;
}
void
SetRPCServer
(
RPCServer
*
rpc_server
)
{
rpc_server_
=
rpc_server
;
}
// Get attributes.
...
...
@@ -228,6 +234,7 @@ class RequestHandler {
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>*
grad_to_prepared_ctx_
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>*
sparse_grad_to_param_
;
RPCServer
*
rpc_server_
;
};
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
82cff5ec
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
...
...
@@ -59,6 +60,12 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"
);
}
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasGrad
(
varname
))
{
auto
&
grad_slr
=
scope
->
FindVar
(
varname
)
->
Get
<
framework
::
SelectedRows
>
();
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
Update
(
varname
,
grad_slr
.
rows
());
}
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
return
true
;
...
...
@@ -82,8 +89,9 @@ bool RequestGetHandler::Handle(const std::string& varname,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
;
VLOG
(
3
)
<<
"RequestGetHandler:"
<<
varname
<<
" out_var_name: "
<<
out_var_name
<<
" trainer_id: "
<<
trainer_id
<<
" table_name: "
<<
table_name
;
if
(
sync_mode_
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
...
...
@@ -108,7 +116,42 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG
(
3
)
<<
"copying "
<<
varname
<<
" to "
<<
param_bak_name
;
framework
::
TensorCopy
(
t_orig
,
dev_ctx_
->
GetPlace
(),
t
);
}
*
outvar
=
scope_
->
FindVar
(
varname
);
if
(
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
HasParam
(
varname
)
&&
!
table_name
.
empty
())
{
std
::
vector
<
int64_t
>
updated_rows
;
AsyncSparseParamUpdateRecorder
::
GetInstance
()
->
GetAndClear
(
varname
,
trainer_id
,
&
updated_rows
);
if
(
VLOG_IS_ON
(
3
))
{
std
::
ostringstream
sstream
;
sstream
<<
"["
;
for
(
auto
&
row_id
:
updated_rows
)
{
sstream
<<
row_id
<<
", "
;
}
sstream
<<
"]"
;
VLOG
(
3
)
<<
"updated_rows size: "
<<
updated_rows
.
size
()
<<
" "
<<
sstream
.
str
();
}
auto
&
origin_tensor
=
scope_
->
FindVar
(
varname
)
->
Get
<
framework
::
LoDTensor
>
();
auto
*
origin_tensor_data
=
origin_tensor
.
data
<
float
>
();
auto
&
dims
=
origin_tensor
.
dims
();
*
outvar
=
scope
->
Var
();
auto
*
out_slr
=
(
*
outvar
)
->
GetMutable
<
framework
::
SelectedRows
>
();
out_slr
->
set_rows
(
updated_rows
);
out_slr
->
set_height
(
dims
[
0
]);
auto
out_dims
=
framework
::
make_ddim
(
{
static_cast
<
int64_t
>
(
updated_rows
.
size
()),
dims
[
1
]});
auto
*
data
=
out_slr
->
mutable_value
()
->
mutable_data
<
float
>
(
out_dims
,
origin_tensor
.
place
());
auto
width
=
dims
[
1
];
for
(
auto
i
=
0
;
i
<
updated_rows
.
size
();
++
i
)
{
PADDLE_ENFORCE_LT
(
updated_rows
[
i
],
dims
[
0
]);
memcpy
(
data
+
i
*
width
,
origin_tensor_data
+
updated_rows
[
i
]
*
width
,
sizeof
(
float
)
*
width
);
}
}
else
{
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
}
return
true
;
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
82cff5ec
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <condition_variable> // NOLINT
#include <memory>
#include <string>
#include "gflags/gflags.h"
...
...
@@ -44,6 +45,7 @@ class RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncGetVarNoBarrier
(
...
...
@@ -96,6 +98,7 @@ class RPCClient {
// Init is called by GetInstance.
template
<
typename
T
>
static
void
Init
(
int
trainer_id
)
{
VLOG
(
0
)
<<
"init rpc client with trainer_id "
<<
trainer_id
;
trainer_id_
=
trainer_id
;
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
T
());
...
...
paddle/fluid/operators/distributed/rpc_common.h
浏览文件 @
82cff5ec
...
...
@@ -27,23 +27,26 @@ struct RpcContext {
RpcContext
(
const
std
::
string
&
name
,
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
std
::
string
>
&
emap
,
const
std
::
vector
<
int64_t
>
&
sections
)
const
std
::
vector
<
int64_t
>
&
sections
,
int
id
)
:
var_name
(
name
),
splited_var_names
(
names
),
epmap
(
emap
),
height_sections
(
sections
)
{}
height_sections
(
sections
),
trainer_id
(
id
)
{}
RpcContext
(
const
RpcContext
&
ctx
)
{
var_name
=
ctx
.
var_name
;
splited_var_names
=
ctx
.
splited_var_names
;
epmap
=
ctx
.
epmap
;
height_sections
=
ctx
.
height_sections
;
trainer_id
=
ctx
.
trainer_id
;
}
std
::
string
var_name
;
std
::
vector
<
std
::
string
>
splited_var_names
;
std
::
vector
<
std
::
string
>
epmap
;
std
::
vector
<
int64_t
>
height_sections
;
int
trainer_id
;
};
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RpcContext
&
rpc_ctx
)
{
...
...
paddle/fluid/operators/distributed_ops/CMakeLists.txt
浏览文件 @
82cff5ec
...
...
@@ -2,9 +2,9 @@ include(operators)
set
(
DISTRIBUTE_DEPS
""
)
if
(
WITH_GRPC
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator
async_sparse_param_update_recorder
grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node
)
else
()
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator brpc leveldb snappystream snappy protobuf ssl crypto zlib node
)
set
(
DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator
async_sparse_param_update_recorder
brpc leveldb snappystream snappy protobuf ssl crypto zlib node
)
if
(
WITH_BRPC_RDMA
)
find_library
(
IBVERBS_LIBRARY NAMES ibverbs
)
ADD_LIBRARY
(
ibverbs SHARED IMPORTED GLOBAL
)
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
82cff5ec
...
...
@@ -24,8 +24,10 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32
(
rpc_send_thread_num
,
12
,
"number of threads for rpc send"
);
...
...
@@ -292,6 +294,8 @@ static void FillRequestCtx(
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
*
prefetch_ctx
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
*
sparse_grad_name_to_param_name
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>
checkpoint_ctx
,
distributed
::
RPCServer
*
rpc_server
)
{
h
->
SetScope
(
scope
);
...
...
@@ -299,6 +303,7 @@ static void FillRequestCtx(
h
->
SetExecutor
(
executor
);
h
->
SetProgram
(
program
);
h
->
SetPrefetchPreparedCtx
(
prefetch_ctx
);
h
->
SetSparseGradToParam
(
sparse_grad_name_to_param_name
);
h
->
SetRPCServer
(
rpc_server
);
h
->
SetCheckpointNotifyPreparedCtx
(
checkpoint_ctx
);
}
...
...
@@ -414,10 +419,24 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx
[
prefetch_var_name
]
=
prefetch_prepared
[
i
];
}
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
ckpt_pre_context
,
rpc_service_
.
get
());
// parse attr of kSparseGradToParam sparse_grad_name -> param_name
std
::
unordered_map
<
std
::
string
,
std
::
string
>
sparse_grad_name_to_param_name
;
auto
sparse_grad_name_to_param_name_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
kSparseGradToParam
);
for
(
const
auto
&
sparse_grad_name_and_param_name
:
sparse_grad_name_to_param_name_str
)
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
sparse_grad_name_and_param_name
,
':'
,
&
pieces
);
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
VLOG
(
3
)
<<
"after split, sparse_grad_name = "
<<
pieces
[
0
]
<<
", param_name = "
<<
pieces
[
1
];
sparse_grad_name_to_param_name
[
pieces
[
0
]]
=
pieces
[
1
];
}
auto
f
=
std
::
bind
(
FillRequestCtx
,
std
::
placeholders
::
_1
,
&
recv_scope
,
&
dev_ctx
,
&
executor
,
program
,
&
prefetch_var_name_to_prepared_ctx
,
&
sparse_grad_name_to_param_name
,
ckpt_pre_context
,
rpc_service_
.
get
());
f
(
request_send_handler_
.
get
());
f
(
request_get_handler_
.
get
());
...
...
@@ -445,6 +464,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
RunSyncLoop
(
&
executor
,
program
,
&
recv_scope
,
&
dev_ctx
,
prefetch_block_id_list
,
checkpoint_block_id
);
}
else
{
distributed
::
AsyncSparseParamUpdateRecorder
::
Init
(
fan_in
,
sparse_grad_name_to_param_name
);
RunAsyncLoop
(
&
executor
,
program
,
&
recv_scope
);
}
}
...
...
@@ -475,6 +496,10 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
std
::
vector
<
std
::
string
>>
(
kPrefetchVarNameToBlockId
,
"prefetch blocks to run on server side."
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
kSparseGradToParam
,
"sparse grad name to param name. like: 'emb@Grad:emb'"
)
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
.
SetDefault
(
1
);
AddAttr
<
int
>
(
kCheckpointBlockId
,
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.h
浏览文件 @
82cff5ec
...
...
@@ -16,8 +16,10 @@ limitations under the License. */
#include <stdint.h>
#include <atomic>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
...
...
@@ -35,6 +37,7 @@ namespace operators {
constexpr
char
kOptimizeBlocks
[]
=
"optimize_blocks"
;
constexpr
char
kPrefetchVarNameToBlockId
[]
=
"prefetch_var_name_to_block_id"
;
constexpr
char
kCheckpointBlockId
[]
=
"checkpint_block_id"
;
constexpr
char
kSparseGradToParam
[]
=
"sparse_grad_to_param"
;
void
RunServer
(
std
::
shared_ptr
<
distributed
::
RPCServer
>
service
);
...
...
paddle/fluid/operators/distributed_ops/recv_op.cc
浏览文件 @
82cff5ec
...
...
@@ -50,17 +50,18 @@ class RecvOp : public framework::OperatorBase {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
std
::
vector
<
std
::
string
>
recv_varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"recv_varnames"
);
if
(
recv_varnames
.
size
()
>
0
)
{
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
auto
rpc_ctx
=
distributed
::
RpcContext
(
outs
[
0
],
recv_varnames
,
epmap
,
{});
auto
rpc_ctx
=
distributed
::
RpcContext
(
outs
[
0
],
recv_varnames
,
epmap
,
{},
trainer_id
);
recv_functor
(
rpc_ctx
,
scope
);
}
else
{
if
(
with_barrier
)
{
...
...
paddle/fluid/operators/distributed_ops/send_op.cc
浏览文件 @
82cff5ec
...
...
@@ -42,6 +42,7 @@ class SendOp : public framework::OperatorBase {
auto
epmap
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
int
sync_send
=
Attr
<
int
>
(
"sync_mode"
);
auto
trainer_id
=
Attr
<
int
>
(
"trainer_id"
);
auto
send_varnames
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"send_varnames"
);
auto
height_sections
=
Attr
<
std
::
vector
<
int64_t
>>
(
"sections"
);
...
...
@@ -51,7 +52,7 @@ class SendOp : public framework::OperatorBase {
if
(
distributed
::
Communicator
::
GetInstance
()
==
nullptr
)
{
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
rpc_ctx
=
distributed
::
RpcContext
(
ins
[
0
],
send_varnames
,
epmap
,
height_sections
);
height_sections
,
trainer_id
);
send_functor
(
rpc_ctx
,
scope
,
true
);
}
else
{
distributed
::
Communicator
::
GetInstance
()
->
Send
(
ins
[
0
],
scope
);
...
...
@@ -62,8 +63,7 @@ class SendOp : public framework::OperatorBase {
auto
&
ctx
=
*
pool
.
Get
(
place
);
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
Attr
<
int
>
(
"trainer_id"
));
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
trainer_id
);
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
i
++
)
{
...
...
python/paddle/fluid/__init__.py
浏览文件 @
82cff5ec
...
...
@@ -175,6 +175,7 @@ def __bootstrap__():
read_env_flags
.
append
(
'communicator_thread_pool_size'
)
read_env_flags
.
append
(
'communicator_max_merge_var_num'
)
read_env_flags
.
append
(
'communicator_fake_rpc'
)
read_env_flags
.
append
(
'communicator_send_wait_times'
)
if
core
.
is_compiled_with_brpc
():
read_env_flags
.
append
(
'max_body_size'
)
#set brpc max body size
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
82cff5ec
...
...
@@ -658,6 +658,7 @@ class DistributeTranspiler(object):
outputs
=
{
"Out"
:
splited_var
},
attrs
=
{
"epmap"
:
eps
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
...
...
@@ -669,6 +670,7 @@ class DistributeTranspiler(object):
outputs
=
{
"Out"
:
fetch_barrier_out
},
attrs
=
{
"endpoints"
:
self
.
pserver_endpoints
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
})
...
...
@@ -791,11 +793,15 @@ class DistributeTranspiler(object):
global_ops
=
[]
# sparse grad name to param name
sparse_grad_to_param
=
[]
def
__append_optimize_op__
(
op
,
block
,
grad_to_block_id
,
merged_var
,
lr_ops
):
if
self
.
_is_optimizer_op
(
op
):
self
.
_append_pserver_ops
(
block
,
op
,
endpoint
,
grad_to_block_id
,
self
.
origin_program
,
merged_var
)
self
.
origin_program
,
merged_var
,
sparse_grad_to_param
)
elif
op
not
in
lr_ops
:
self
.
_append_pserver_non_opt_ops
(
block
,
op
)
...
...
@@ -911,6 +917,7 @@ class DistributeTranspiler(object):
"Fanin"
:
self
.
trainer_num
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
grad_to_block_id
,
"sparse_grad_to_param"
:
sparse_grad_to_param
,
}
if
self
.
has_distributed_lookup_table
:
...
...
@@ -1779,7 +1786,8 @@ class DistributeTranspiler(object):
return
o4
def
_append_pserver_ops
(
self
,
optimize_block
,
opt_op
,
endpoint
,
grad_to_block_id
,
origin_program
,
merged_var
):
grad_to_block_id
,
origin_program
,
merged_var
,
sparse_grad_to_param
):
program
=
optimize_block
.
program
pserver_block
=
program
.
global_block
()
new_inputs
=
collections
.
OrderedDict
()
...
...
@@ -1863,6 +1871,12 @@ class DistributeTranspiler(object):
outputs
=
outputs
,
attrs
=
opt_op
.
all_attrs
())
# record sparse grad to param name
if
new_inputs
[
"Grad"
].
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
sparse_grad_to_param
.
append
(
str
(
new_inputs
[
"Grad"
].
name
)
+
":"
+
str
(
new_inputs
[
"Param"
]
.
name
))
def
_get_pserver_grad_param_var
(
self
,
var
,
var_dict
):
"""
Return pserver side grad/param variable, return None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录