Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
44debca8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
44debca8
编写于
11月 29, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
11月 29, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14589 from jacquesqiao/refactor-prefetch
Refactor prefetch
上级
20120d9c
839193fd
变更
26
隐藏空白更改
内联
并排
Showing
26 changed file
with
853 addition
and
123 deletion
+853
-123
paddle/fluid/framework/details/multi_devices_graph_pass.cc
paddle/fluid/framework/details/multi_devices_graph_pass.cc
+1
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+7
-1
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+18
-17
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+5
-2
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+1
-0
paddle/fluid/operators/distributed/grpc_serde.cc
paddle/fluid/operators/distributed/grpc_serde.cc
+5
-1
paddle/fluid/operators/distributed/grpc_serde.h
paddle/fluid/operators/distributed/grpc_serde.h
+2
-1
paddle/fluid/operators/distributed/grpc_serde_test.cc
paddle/fluid/operators/distributed/grpc_serde_test.cc
+2
-1
paddle/fluid/operators/distributed/grpc_server.cc
paddle/fluid/operators/distributed/grpc_server.cc
+2
-1
paddle/fluid/operators/distributed/grpc_variable_response.cc
paddle/fluid/operators/distributed/grpc_variable_response.cc
+14
-0
paddle/fluid/operators/distributed/parameter_prefetch.cc
paddle/fluid/operators/distributed/parameter_prefetch.cc
+255
-0
paddle/fluid/operators/distributed/parameter_prefetch.h
paddle/fluid/operators/distributed/parameter_prefetch.h
+34
-0
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+2
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+21
-9
paddle/fluid/operators/distributed/request_handler_impl.h
paddle/fluid/operators/distributed/request_handler_impl.h
+32
-8
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+1
-1
paddle/fluid/operators/distributed/send_recv.proto.in
paddle/fluid/operators/distributed/send_recv.proto.in
+1
-0
paddle/fluid/operators/distributed/variable_response.h
paddle/fluid/operators/distributed/variable_response.h
+1
-0
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+19
-0
paddle/fluid/operators/lookup_table_op.cu
paddle/fluid/operators/lookup_table_op.cu
+42
-21
paddle/fluid/operators/lookup_table_op.h
paddle/fluid/operators/lookup_table_op.h
+63
-37
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+6
-0
python/paddle/fluid/tests/unittests/dist_ctr.py
python/paddle/fluid/tests/unittests/dist_ctr.py
+2
-0
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+41
-0
python/paddle/fluid/tests/unittests/test_lookup_remote_table_op.py
...ddle/fluid/tests/unittests/test_lookup_remote_table_op.py
+203
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+73
-21
未找到文件。
paddle/fluid/framework/details/multi_devices_graph_pass.cc
浏览文件 @
44debca8
...
@@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
...
@@ -862,7 +862,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
if
(
node
->
Op
()
->
Type
()
==
"fetch_barrier"
)
{
outvar_dev_id
=
outvar_dev_id
=
GetVarDeviceID
(
*
result
,
output
->
Name
(),
*
sharded_var_device
);
GetVarDeviceID
(
*
result
,
output
->
Name
(),
*
sharded_var_device
);
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
);
PADDLE_ENFORCE_NE
(
outvar_dev_id
,
-
1
,
"output name %s"
,
output
->
Name
()
);
}
}
p
=
places_
[
outvar_dev_id
];
p
=
places_
[
outvar_dev_id
];
ir
::
Node
*
new_node
=
nullptr
;
ir
::
Node
*
new_node
=
nullptr
;
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
44debca8
...
@@ -37,7 +37,13 @@ if (WITH_GPU)
...
@@ -37,7 +37,13 @@ if (WITH_GPU)
SET
(
OP_HEADER_DEPS
${
OP_HEADER_DEPS
}
cub
)
SET
(
OP_HEADER_DEPS
${
OP_HEADER_DEPS
}
cub
)
endif
()
endif
()
register_operators
(
EXCLUDES warpctc_op conv_fusion_op DEPS
${
OP_HEADER_DEPS
}
)
SET
(
OP_PREFETCH_DEPS
""
)
if
(
WITH_DISTRIBUTE
)
SET
(
OP_PREFETCH_DEPS
${
OP_PREFETCH_DEPS
}
parameter_prefetch
)
endif
()
register_operators
(
EXCLUDES warpctc_op conv_fusion_op DEPS
${
OP_HEADER_DEPS
}
${
OP_PREFETCH_DEPS
}
)
# warpctc_op needs cudnn 7 above
# warpctc_op needs cudnn 7 above
if
(
WITH_GPU AND NOT WIN32
)
if
(
WITH_GPU AND NOT WIN32
)
...
...
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
44debca8
...
@@ -9,36 +9,37 @@ else()
...
@@ -9,36 +9,37 @@ else()
endif
()
endif
()
configure_file
(
send_recv.proto.in
${
CMAKE_CURRENT_SOURCE_DIR
}
/send_recv.proto @ONLY
)
configure_file
(
send_recv.proto.in
${
CMAKE_CURRENT_SOURCE_DIR
}
/send_recv.proto @ONLY
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
if
(
WITH_GRPC
)
if
(
WITH_GRPC
)
grpc_library
(
sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_library
(
sendrecvop_grpc SRCS grpc_bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc grpc_server.cc variable_response.cc grpc_variable_response.cc grpc_serde.cc
PROTO send_recv.proto
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
DEPS lod_tensor selected_rows memory
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
cc_test
(
grpc_serde_test SRCS grpc_serde_test.cc
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL
)
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
cc_test
(
rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL
)
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL
)
cc_test
(
varhandle_test SRCS varhandle_test.cc DEPS profiler
)
cc_test
(
varhandle_test SRCS varhandle_test.cc DEPS profiler
)
return
()
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_grpc memory
)
endif
()
else
()
set_source_files_properties
(
brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
brpc_server.cc brpc_client.cc rpc_server_test.cc brpc_serde_test.cc
brpc_library
(
sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
brpc_variable_response.cc brpc_sendrecvop_utils.cc brpc_rdma_pool.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
brpc_library
(
sendrecvop_brpc SRCS brpc_client.cc brpc_server.cc rpc_server.cc rpc_client.cc request_handler_impl.cc brpc_sendrecvop_utils.cc
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_brpc memory
)
brpc_variable_response.cc variable_response.cc sendrecvop_utils.cc brpc_rdma_pool.cc
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
set
(
brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy
)
set
(
brpc_test_depends sendrecvop_brpc brpc ssl crypto protobuf leveldb gflags glog executor proto_desc lookup_table_op snappystream snappy
)
cc_test
(
brpc_server_test SRCS rpc_server_test.cc
cc_test
(
brpc_server_test SRCS rpc_server_test.cc
DEPS
${
brpc_test_depends
}
SERIAL
)
DEPS
${
brpc_test_depends
}
SERIAL
)
cc_test
(
brpc_serde_test SRCS brpc_serde_test.cc
cc_test
(
brpc_serde_test SRCS brpc_serde_test.cc
DEPS
${
brpc_test_depends
}
SERIAL
)
DEPS
${
brpc_test_depends
}
SERIAL
)
endif
()
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
44debca8
...
@@ -171,11 +171,13 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
...
@@ -171,11 +171,13 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
,
int64_t
time_out
)
{
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
in_var_name_val
=
in_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
std
::
string
out_var_name_val
=
out_var_name
;
const
std
::
string
table_name_val
=
table_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch
=
GetChannel
(
ep_val
);
const
auto
ch
=
GetChannel
(
ep_val
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
GetProcessor
*
s
=
new
GetProcessor
(
ch
);
...
@@ -186,11 +188,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
...
@@ -186,11 +188,12 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
s
->
Prepare
(
h
,
time_out
);
s
->
Prepare
(
h
,
time_out
);
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
framework
::
AsyncIO
([
in_var_name_val
,
out_var_name_val
,
ep_val
,
p_scope
,
p_ctx
,
s
,
method
,
h
,
this
]
{
s
,
method
,
h
,
t
able_name_val
,
t
his
]
{
auto
*
var
=
p_scope
->
FindVar
(
in_var_name_val
);
auto
*
var
=
p_scope
->
FindVar
(
in_var_name_val
);
::
grpc
::
ByteBuffer
req
;
::
grpc
::
ByteBuffer
req
;
SerializeToByteBuffer
(
in_var_name_val
,
var
,
*
p_ctx
,
&
req
,
out_var_name_val
);
SerializeToByteBuffer
(
in_var_name_val
,
var
,
*
p_ctx
,
&
req
,
out_var_name_val
,
0
,
table_name_val
);
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
VLOG
(
3
)
<<
s
->
GetVarHandlePtr
()
->
String
()
<<
" begin"
;
...
...
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
44debca8
...
@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
...
@@ -194,6 +194,7 @@ class GRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncSendBatchBarrier
(
VarHandlePtr
AsyncSendBatchBarrier
(
...
...
paddle/fluid/operators/distributed/grpc_serde.cc
浏览文件 @
44debca8
...
@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
...
@@ -42,7 +42,8 @@ static void SerializeDestroyCallback(void* payload) {
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
void
SerializeToByteBuffer
(
const
std
::
string
&
name
,
framework
::
Variable
*
var
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_name
,
const
int
trainer_id
)
{
const
int
trainer_id
,
const
std
::
string
&
table_name
)
{
platform
::
RecordRPCEvent
record_event
(
"serial"
,
&
ctx
);
platform
::
RecordRPCEvent
record_event
(
"serial"
,
&
ctx
);
VarMsg
request
;
VarMsg
request
;
TensorPayload
*
payload
=
nullptr
;
TensorPayload
*
payload
=
nullptr
;
...
@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -63,6 +64,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if
(
!
out_name
.
empty
())
{
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
request
.
set_out_varname
(
out_name
);
}
}
if
(
!
table_name
.
empty
())
{
request
.
set_table_name
(
table_name
);
}
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
request
.
set_type
(
::
sendrecv
::
LOD_TENSOR
);
payload
=
new
TensorPayload
(
GetTensorPayload
(
var
,
ctx
,
&
request
));
payload
=
new
TensorPayload
(
GetTensorPayload
(
var
,
ctx
,
&
request
));
...
...
paddle/fluid/operators/distributed/grpc_serde.h
浏览文件 @
44debca8
...
@@ -40,7 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
...
@@ -40,7 +40,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
::
grpc
::
ByteBuffer
*
msg
,
::
grpc
::
ByteBuffer
*
msg
,
const
std
::
string
&
out_varname
=
std
::
string
(),
const
std
::
string
&
out_varname
=
std
::
string
(),
const
int
trainer_id
=
0
);
const
int
trainer_id
=
0
,
const
std
::
string
&
table_name
=
std
::
string
());
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
void
DeserializeFromByteBuffer
(
const
::
grpc
::
ByteBuffer
&
msg
,
const
platform
::
DeviceContext
&
ctx
,
const
platform
::
DeviceContext
&
ctx
,
...
...
paddle/fluid/operators/distributed/grpc_serde_test.cc
浏览文件 @
44debca8
...
@@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
...
@@ -130,7 +130,8 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
math
::
set_constant
(
ctx
,
tensor
,
31.9
);
math
::
set_constant
(
ctx
,
tensor
,
31.9
);
::
grpc
::
ByteBuffer
msg
;
::
grpc
::
ByteBuffer
msg
;
operators
::
distributed
::
SerializeToByteBuffer
(
"myvar"
,
&
var
,
ctx
,
&
msg
);
operators
::
distributed
::
SerializeToByteBuffer
(
"myvar"
,
&
var
,
ctx
,
&
msg
,
"outvar"
,
0
,
"table_name"
);
EXPECT_GT
(
msg
.
Length
(),
static_cast
<
size_t
>
(
0
));
EXPECT_GT
(
msg
.
Length
(),
static_cast
<
size_t
>
(
0
));
// deserialize
// deserialize
...
...
paddle/fluid/operators/distributed/grpc_server.cc
浏览文件 @
44debca8
...
@@ -183,6 +183,7 @@ class RequestPrefetch final : public RequestBase {
...
@@ -183,6 +183,7 @@ class RequestPrefetch final : public RequestBase {
// prefetch process...
// prefetch process...
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
in_var_name
=
request_
->
Varname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
std
::
string
out_var_name
=
request_
->
OutVarname
();
std
::
string
table_name
=
request_
->
TableName
();
int
trainer_id
=
request_
->
GetTrainerId
();
int
trainer_id
=
request_
->
GetTrainerId
();
VLOG
(
4
)
<<
"RequestPrefetch, in_var_name: "
<<
in_var_name
VLOG
(
4
)
<<
"RequestPrefetch, in_var_name: "
<<
in_var_name
<<
" out_var_name: "
<<
out_var_name
;
<<
" out_var_name: "
<<
out_var_name
;
...
@@ -193,7 +194,7 @@ class RequestPrefetch final : public RequestBase {
...
@@ -193,7 +194,7 @@ class RequestPrefetch final : public RequestBase {
framework
::
Variable
*
outvar
=
scope
->
Var
(
out_var_name
);
framework
::
Variable
*
outvar
=
scope
->
Var
(
out_var_name
);
request_handler_
->
Handle
(
in_var_name
,
scope
,
invar
,
&
outvar
,
trainer_id
,
request_handler_
->
Handle
(
in_var_name
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_var_name
);
out_var_name
,
table_name
);
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
SerializeToByteBuffer
(
out_var_name
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
&
reply_
);
...
...
paddle/fluid/operators/distributed/grpc_variable_response.cc
浏览文件 @
44debca8
...
@@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) {
...
@@ -301,6 +301,20 @@ int GRPCVariableResponse::Parse(Source* source) {
meta_
.
set_trainer_id
(
trainer_id
);
meta_
.
set_trainer_id
(
trainer_id
);
break
;
break
;
}
}
case
sendrecv
::
VariableMessage
::
kTableNameFieldNumber
:
{
uint32_t
length
;
if
((
wt
!=
WIRETYPE_LENGTH_DELIMITED
)
||
!
input
.
ReadVarint32
(
&
length
))
{
return
tag
;
}
std
::
string
temp
;
if
(
!
input
.
ReadString
(
&
temp
,
length
))
{
return
tag
;
}
meta_
.
set_table_name
(
temp
);
break
;
}
default:
{
default:
{
// Unknown tag, return unknown error.
// Unknown tag, return unknown error.
return
-
1
;
return
-
1
;
...
...
paddle/fluid/operators/distributed/parameter_prefetch.cc
0 → 100644
浏览文件 @
44debca8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include <string>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
SelectedRows
=
framework
::
SelectedRows
;
using
DDim
=
framework
::
DDim
;
static
size_t
GetSectionIndex
(
int64_t
id
,
const
std
::
vector
<
int64_t
>&
abs_sections
)
{
for
(
size_t
i
=
1
;
i
<
abs_sections
.
size
();
++
i
)
{
if
(
id
<
abs_sections
[
i
])
{
return
i
-
1
;
}
}
return
abs_sections
.
size
()
-
1
;
}
static
std
::
vector
<
int64_t
>
ToAbsoluteSection
(
const
std
::
vector
<
int
>&
height_sections
)
{
std
::
vector
<
int64_t
>
abs_sections
;
abs_sections
.
resize
(
height_sections
.
size
());
abs_sections
[
0
]
=
0
;
for
(
size_t
i
=
1
;
i
<
height_sections
.
size
();
++
i
)
{
abs_sections
[
i
]
=
height_sections
[
i
-
1
]
+
abs_sections
[
i
-
1
];
}
return
abs_sections
;
}
static
std
::
vector
<
std
::
vector
<
int64_t
>>
SplitIds
(
const
std
::
vector
<
int64_t
>&
ids_vector
,
const
std
::
vector
<
int
>&
height_section
,
framework
::
Scope
*
scope
)
{
std
::
set
<
int64_t
>
all_ids
;
for
(
auto
id
:
ids_vector
)
{
all_ids
.
insert
(
id
);
}
auto
abs_sections
=
ToAbsoluteSection
(
height_section
);
std
::
vector
<
std
::
vector
<
int64_t
>>
splited_ids
;
splited_ids
.
resize
(
height_section
.
size
()
+
1
);
for
(
auto
&
id
:
all_ids
)
{
auto
section_index
=
GetSectionIndex
(
id
,
abs_sections
);
splited_ids
[
section_index
].
push_back
(
id
-
abs_sections
[
section_index
]);
}
return
splited_ids
;
}
static
void
SplitIdsIntoMultipleVarsBySection
(
const
std
::
vector
<
std
::
string
>&
in_var_names
,
const
std
::
vector
<
int
>&
height_section
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
splited_ids
,
framework
::
Scope
*
scope
)
{
PADDLE_ENFORCE_EQ
(
in_var_names
.
size
(),
height_section
.
size
(),
""
);
auto
place
=
platform
::
CPUPlace
();
for
(
size_t
i
=
0
;
i
<
in_var_names
.
size
();
++
i
)
{
auto
*
id_tensor
=
scope
->
Var
(
in_var_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
&
ids
=
splited_ids
[
i
];
if
(
!
ids
.
empty
())
{
auto
*
id_tensor_data
=
id_tensor
->
mutable_data
<
int64_t
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
ids
.
size
()),
1
}),
place
);
memcpy
(
id_tensor_data
,
ids
.
data
(),
sizeof
(
int64_t
)
*
ids
.
size
());
}
}
}
static
void
MergeMultipleVarsIntoOneBySection
(
const
std
::
string
&
id_name
,
const
std
::
vector
<
int64_t
>&
ids_vector
,
const
std
::
string
&
out_name
,
const
std
::
vector
<
std
::
string
>&
out_var_names
,
const
std
::
vector
<
int
>&
height_section
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
splited_ids
,
const
framework
::
ExecutionContext
&
context
,
framework
::
Scope
*
scope
,
platform
::
DeviceContext
*
actual_ctx
)
{
PADDLE_ENFORCE_EQ
(
out_var_names
.
size
(),
height_section
.
size
(),
""
);
auto
cpu_place
=
platform
::
CPUPlace
();
auto
abs_sections
=
ToAbsoluteSection
(
height_section
);
std
::
unordered_map
<
int64_t
,
std
::
vector
<
size_t
>>
id_to_offset
;
for
(
size_t
i
=
0
;
i
<
ids_vector
.
size
();
++
i
)
{
id_to_offset
[
ids_vector
[
i
]].
push_back
(
i
);
}
auto
&
id_tensor
=
scope
->
FindVar
(
id_name
)
->
Get
<
framework
::
LoDTensor
>
();
auto
*
out_tensor
=
scope
->
FindVar
(
out_name
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
out_tensor_data
=
out_tensor
->
mutable_data
<
float
>
(
id_tensor
.
place
());
bool
is_on_cpu_place
=
true
;
if
(
!
platform
::
is_cpu_place
(
id_tensor
.
place
()))
{
is_on_cpu_place
=
false
;
}
for
(
size_t
section_idx
=
0
;
section_idx
<
out_var_names
.
size
();
++
section_idx
)
{
auto
&
ids_in_this_section
=
splited_ids
[
section_idx
];
if
(
!
ids_in_this_section
.
empty
())
{
auto
&
prefetch_out_var
=
scope
->
Var
(
out_var_names
[
section_idx
])
->
Get
<
framework
::
LoDTensor
>
();
const
auto
*
out_var_data
=
prefetch_out_var
.
data
<
float
>
();
auto
&
dims
=
prefetch_out_var
.
dims
();
PADDLE_ENFORCE_EQ
(
dims
.
size
(),
2
,
""
);
PADDLE_ENFORCE_EQ
(
ids_in_this_section
.
size
(),
dims
[
0
]);
auto
row_numel
=
dims
[
1
];
for
(
size_t
i
=
0
;
i
<
dims
[
0
];
++
i
)
{
auto
id
=
ids_in_this_section
[
i
];
auto
origin_id
=
id
+
abs_sections
[
section_idx
];
auto
&
offsets
=
id_to_offset
[
origin_id
];
for
(
auto
&
offset
:
offsets
)
{
// should support GPU tensor
if
(
is_on_cpu_place
)
{
memory
::
Copy
(
cpu_place
,
out_tensor_data
+
offset
*
row_numel
,
cpu_place
,
out_var_data
+
i
*
row_numel
,
sizeof
(
float
)
*
row_numel
);
}
else
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"paddle is not compiled with CUDA!"
);
#else
auto
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
actual_ctx
)
->
stream
();
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
id_tensor
.
place
()),
out_tensor_data
+
offset
*
row_numel
,
cpu_place
,
out_var_data
+
i
*
row_numel
,
sizeof
(
float
)
*
row_numel
,
stream
);
#endif
}
}
}
}
else
{
VLOG
(
3
)
<<
"ids in this section is empty"
;
}
}
}
void
prefetch
(
const
std
::
string
&
id_name
,
const
std
::
string
&
out_name
,
const
std
::
vector
<
std
::
string
>&
table_names
,
const
std
::
vector
<
std
::
string
>&
epmap
,
const
std
::
vector
<
int
>&
height_sections
,
const
framework
::
ExecutionContext
&
context
)
{
auto
&
local_scope
=
context
.
scope
().
NewScope
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
cpu_ctx
=
*
pool
.
Get
(
platform
::
CPUPlace
());
auto
&
actual_ctx
=
*
pool
.
Get
(
context
.
GetPlace
());
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
::
GetInstance
<
RPCCLIENT_T
>
(
context
.
Attr
<
int
>
(
"trainer_id"
));
std
::
vector
<
std
::
string
>
in_var_names
;
std
::
vector
<
std
::
string
>
out_var_names
;
for
(
size_t
i
=
0
;
i
<
epmap
.
size
();
++
i
)
{
in_var_names
.
push_back
(
id_name
+
"@"
+
epmap
[
i
]);
out_var_names
.
push_back
(
out_name
+
"@"
+
epmap
[
i
]);
}
auto
&
id_tensor
=
local_scope
.
FindVar
(
id_name
)
->
Get
<
framework
::
LoDTensor
>
();
std
::
vector
<
int64_t
>
ids_vector
;
if
(
platform
::
is_cpu_place
(
id_tensor
.
place
()))
{
auto
*
id_data
=
id_tensor
.
data
<
int64_t
>
();
for
(
size_t
i
=
0
;
i
<
id_tensor
.
numel
();
++
i
)
{
ids_vector
.
push_back
(
id_data
[
i
]);
}
}
else
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"paddle is not compiled with CUDA!"
);
#else
auto
cpu_place
=
platform
::
CPUPlace
();
framework
::
Tensor
cpu_tensor
;
auto
*
cpu_tensor_data
=
cpu_tensor
.
mutable_data
<
int64_t
>
(
id_tensor
.
dims
(),
cpu_place
);
auto
stream
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
&
actual_ctx
)
->
stream
();
memory
::
Copy
(
cpu_place
,
cpu_tensor_data
,
boost
::
get
<
platform
::
CUDAPlace
>
(
id_tensor
.
place
()),
id_tensor
.
data
<
int64_t
>
(),
sizeof
(
int64_t
)
*
id_tensor
.
numel
(),
stream
);
for
(
size_t
i
=
0
;
i
<
cpu_tensor
.
numel
();
++
i
)
{
ids_vector
.
push_back
(
cpu_tensor_data
[
i
]);
}
#endif
}
auto
splited_ids
=
SplitIds
(
ids_vector
,
height_sections
,
&
local_scope
);
SplitIdsIntoMultipleVarsBySection
(
in_var_names
,
height_sections
,
splited_ids
,
&
local_scope
);
// create output var in local scope
for
(
auto
&
name
:
out_var_names
)
{
local_scope
.
Var
(
name
)
->
GetMutable
<
framework
::
LoDTensor
>
();
}
std
::
vector
<
distributed
::
VarHandlePtr
>
rets
;
for
(
size_t
i
=
0
;
i
<
in_var_names
.
size
();
i
++
)
{
if
(
NeedSend
(
local_scope
,
in_var_names
[
i
]))
{
VLOG
(
3
)
<<
"sending "
<<
in_var_names
[
i
]
<<
" to "
<<
epmap
[
i
]
<<
" to get "
<<
out_var_names
[
i
]
<<
" back"
;
rets
.
push_back
(
rpc_client
->
AsyncPrefetchVar
(
epmap
[
i
],
cpu_ctx
,
local_scope
,
in_var_names
[
i
],
out_var_names
[
i
],
table_names
[
i
]));
}
else
{
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
out_var_names
[
i
];
}
}
for
(
size_t
i
=
0
;
i
<
rets
.
size
();
i
++
)
{
PADDLE_ENFORCE
(
rets
[
i
]
->
Wait
(),
"internal error in RPCClient"
);
}
MergeMultipleVarsIntoOneBySection
(
id_name
,
ids_vector
,
out_name
,
out_var_names
,
height_sections
,
splited_ids
,
context
,
&
local_scope
,
&
actual_ctx
);
context
.
scope
().
DeleteScope
(
&
local_scope
);
}
};
// namespace distributed
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/distributed/parameter_prefetch.h
0 → 100644
浏览文件 @
44debca8
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
operators
{
namespace
distributed
{
void
prefetch
(
const
std
::
string
&
id_name
,
const
std
::
string
&
out_name
,
const
std
::
vector
<
std
::
string
>&
table_names
,
const
std
::
vector
<
std
::
string
>&
epmap
,
const
std
::
vector
<
int
>&
height_sections
,
const
framework
::
ExecutionContext
&
context
);
};
// namespace distributed
};
// namespace operators
};
// namespace paddle
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
44debca8
...
@@ -191,7 +191,8 @@ class RequestHandler {
...
@@ -191,7 +191,8 @@ class RequestHandler {
virtual
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
virtual
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
)
=
0
;
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
table_name
=
""
)
=
0
;
protected:
protected:
const
bool
sync_mode_
;
const
bool
sync_mode_
;
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
44debca8
...
@@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -37,7 +37,8 @@ bool RequestSendHandler::Handle(const std::string& varname,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Sync
// Sync
...
@@ -77,8 +78,10 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -77,8 +78,10 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
if
(
sync_mode_
)
{
if
(
sync_mode_
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
...
@@ -113,14 +116,22 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
...
@@ -113,14 +116,22 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
VLOG
(
4
)
<<
"RequestPrefetchHandler "
<<
varname
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
out_var_name
);
if
(
table_name
.
empty
())
{
InitializeVariable
(
*
outvar
,
var_desc
->
GetType
());
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
out_var_name
);
executor_
->
RunPreparedContext
(
InitializeVariable
(
*
outvar
,
var_desc
->
GetType
());
(
*
prefetch_var_name_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
executor_
->
RunPreparedContext
(
(
*
prefetch_var_name_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
}
else
{
(
*
outvar
)
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
lookup_table_op
=
BuildLookupTableOp
(
table_name
,
varname
,
out_var_name
);
paddle
::
platform
::
CPUPlace
cpu_place
;
lookup_table_op
->
Run
(
*
scope
,
cpu_place
);
}
return
true
;
return
true
;
}
}
...
@@ -129,7 +140,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
...
@@ -129,7 +140,8 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
framework
::
Variable
*
invar
,
framework
::
Variable
*
invar
,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
)
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
checkpoint_notify_id
!=
-
1
,
checkpoint_notify_id
!=
-
1
,
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
"when checkpoint_notify_id = -1, there should be no RPC invoke."
);
...
...
paddle/fluid/operators/distributed/request_handler_impl.h
浏览文件 @
44debca8
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
...
@@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler {
...
@@ -43,8 +44,8 @@ class RequestSendHandler final : public RequestHandler {
virtual
~
RequestSendHandler
()
{}
virtual
~
RequestSendHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
out_var
_name
=
""
)
override
;
const
std
::
string
&
table
_name
=
""
)
override
;
private:
private:
bool
enable_dc_asgd_
;
bool
enable_dc_asgd_
;
...
@@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler {
...
@@ -59,21 +60,44 @@ class RequestGetHandler final : public RequestHandler {
virtual
~
RequestGetHandler
()
{}
virtual
~
RequestGetHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
out_var
_name
=
""
)
override
;
const
std
::
string
&
table
_name
=
""
)
override
;
private:
private:
bool
enable_dc_asgd_
;
bool
enable_dc_asgd_
;
};
};
static
inline
void
BuildVar
(
const
std
::
string
&
param_name
,
std
::
initializer_list
<
const
char
*>
arguments
,
paddle
::
framework
::
proto
::
OpDesc
::
Var
*
var
)
{
var
->
set_parameter
(
param_name
);
for
(
auto
&
arg_name
:
arguments
)
{
*
var
->
mutable_arguments
()
->
Add
()
=
arg_name
;
}
}
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
class
RequestPrefetchHandler
final
:
public
RequestHandler
{
public:
public:
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
explicit
RequestPrefetchHandler
(
bool
sync_mode
)
:
RequestHandler
(
sync_mode
)
{}
virtual
~
RequestPrefetchHandler
()
{}
virtual
~
RequestPrefetchHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
out_var_name
=
""
)
override
;
const
std
::
string
&
table_name
=
""
)
override
;
private:
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>
BuildLookupTableOp
(
const
std
::
string
&
table_name
,
const
std
::
string
&
id_name
,
const
std
::
string
&
out_name
)
{
paddle
::
framework
::
proto
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"lookup_table"
);
BuildVar
(
"W"
,
{
table_name
.
data
()},
op_desc
.
add_inputs
());
BuildVar
(
"Ids"
,
{
id_name
.
data
()},
op_desc
.
add_inputs
());
BuildVar
(
"Out"
,
{
out_name
.
data
()},
op_desc
.
add_outputs
());
auto
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
return
op
;
}
};
};
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
class
RequestCheckpointHandler
final
:
public
RequestHandler
{
...
@@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler {
...
@@ -85,8 +109,8 @@ class RequestCheckpointHandler final : public RequestHandler {
virtual
~
RequestCheckpointHandler
()
{}
virtual
~
RequestCheckpointHandler
()
{}
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
bool
Handle
(
const
std
::
string
&
varname
,
framework
::
Scope
*
scope
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
framework
::
Variable
*
var
,
framework
::
Variable
**
outvar
,
const
int
trainer_id
,
const
int
trainer_id
,
const
std
::
string
&
out_var_name
=
""
,
const
std
::
string
&
out_var
_name
=
""
)
override
;
const
std
::
string
&
table
_name
=
""
)
override
;
private:
private:
int
checkpoint_notify_id
;
int
checkpoint_notify_id
;
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
44debca8
...
@@ -48,7 +48,7 @@ class RPCClient {
...
@@ -48,7 +48,7 @@ class RPCClient {
virtual
VarHandlePtr
AsyncPrefetchVar
(
virtual
VarHandlePtr
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
in_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
table_name
=
""
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
virtual
VarHandlePtr
AsyncSendBatchBarrier
(
virtual
VarHandlePtr
AsyncSendBatchBarrier
(
...
...
paddle/fluid/operators/distributed/send_recv.proto.in
浏览文件 @
44debca8
...
@@ -80,6 +80,7 @@ message VariableMessage {
...
@@ -80,6 +80,7 @@ message VariableMessage {
// when profile switches from 1 to 2.
// when profile switches from 1 to 2.
int64 profile = 11;
int64 profile = 11;
int64 trainer_id = 12;
int64 trainer_id = 12;
string table_name = 13;
}
}
message VoidMessage {}
message VoidMessage {}
paddle/fluid/operators/distributed/variable_response.h
浏览文件 @
44debca8
...
@@ -85,6 +85,7 @@ class VariableResponse {
...
@@ -85,6 +85,7 @@ class VariableResponse {
inline
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
framework
::
Scope
*
GetMutableLocalScope
()
const
{
return
local_scope_
;
}
inline
std
::
string
Varname
()
const
{
return
meta_
.
varname
();
}
inline
std
::
string
Varname
()
const
{
return
meta_
.
varname
();
}
inline
std
::
string
OutVarname
()
const
{
return
meta_
.
out_varname
();
}
inline
std
::
string
OutVarname
()
const
{
return
meta_
.
out_varname
();
}
inline
std
::
string
TableName
()
const
{
return
meta_
.
table_name
();
}
// should call parse first.
// should call parse first.
framework
::
Variable
*
GetVar
()
{
framework
::
Variable
*
GetVar
()
{
...
...
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
44debca8
...
@@ -87,6 +87,25 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -87,6 +87,25 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"(boolean, default false) "
"If the grad op reuse the input's variable."
)
"If the grad op reuse the input's variable."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
// for parameter prefetch
AddAttr
<
bool
>
(
"remote_prefetch"
,
""
).
SetDefault
(
false
);
AddAttr
<
int
>
(
"trainer_id"
,
"trainer id from 0 ~ worker_num."
).
SetDefault
(
0
);
AddAttr
<
std
::
vector
<
int
>>
(
"height_sections"
,
"Height for each output SelectedRows."
)
.
SetDefault
(
std
::
vector
<
int
>
({}));
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
,
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping"
)
.
SetDefault
({});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"table_names"
,
"(string vector, the splited table names that will be fetched from "
"parameter server)"
"in the order of input variables for mapping"
)
.
SetDefault
({});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Lookup Table Operator.
Lookup Table Operator.
...
...
paddle/fluid/operators/lookup_table_op.cu
浏览文件 @
44debca8
...
@@ -78,27 +78,47 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -78,27 +78,47 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
size_t
N
=
table_t
->
dims
()[
0
];
auto
id_name
=
context
.
Inputs
(
"Ids"
).
front
();
size_t
D
=
table_t
->
dims
()[
1
];
auto
out_name
=
context
.
Outputs
(
"Out"
).
front
();
size_t
K
=
ids_t
->
numel
();
// for remote prefetch
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
epmap
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
auto
*
table
=
table_t
->
data
<
T
>
();
auto
height_sections
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"height_sections"
);
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
table_names
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"table_names"
);
dim3
threads
(
128
,
8
);
if
(
!
epmap
.
empty
())
{
dim3
grids
(
8
,
1
);
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
if
(
padding_idx
==
-
1
)
// server
LookupTable
<
#ifdef PADDLE_WITH_DISTRIBUTE
T
,
128
,
8
,
8
,
operators
::
distributed
::
prefetch
(
id_name
,
out_name
,
table_names
,
epmap
,
false
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
height_sections
,
context
);
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
#else
else
PADDLE_THROW
(
LookupTable
<
"paddle is not compiled with distribute support, can not do "
T
,
128
,
8
,
8
,
"parameter prefetch!"
);
true
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
#endif
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
else
{
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
K
=
ids_t
->
numel
();
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
if
(
padding_idx
==
-
1
)
LookupTable
<
T
,
128
,
8
,
8
,
false
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
else
LookupTable
<
T
,
128
,
8
,
8
,
true
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
}
}
};
};
...
@@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -109,6 +129,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
if
(
is_sparse
)
{
...
...
paddle/fluid/operators/lookup_table_op.h
浏览文件 @
44debca8
...
@@ -23,6 +23,10 @@ limitations under the License. */
...
@@ -23,6 +23,10 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -41,44 +45,66 @@ class LookupTableKernel : public framework::OpKernel<T> {
...
@@ -41,44 +45,66 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
table_var
=
context
.
InputVar
(
"W"
);
auto
*
table_var
=
context
.
InputVar
(
"W"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
auto
id_name
=
context
.
Inputs
(
"Ids"
).
front
();
int64_t
*
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
auto
out_name
=
context
.
Outputs
(
"Out"
).
front
();
int64_t
ids_numel
=
ids_t
->
numel
();
// for remote prefetch
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
auto
epmap
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"epmap"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
height_sections
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"height_sections"
);
int64_t
row_number
=
table_t
->
dims
()[
0
];
auto
table_names
=
context
.
Attr
<
std
::
vector
<
std
::
string
>>
(
"table_names"
);
int64_t
row_width
=
table_t
->
dims
()[
1
];
if
(
!
epmap
.
empty
())
{
auto
*
table
=
table_t
->
data
<
T
>
();
// if epmap is not empty, then the parameter will be fetched from remote
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// parameter
// server
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
#ifdef PADDLE_WITH_DISTRIBUTE
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
operators
::
distributed
::
prefetch
(
id_name
,
out_name
,
table_names
,
epmap
,
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
height_sections
,
context
);
}
else
{
#else
PADDLE_ENFORCE_LT
(
ids
[
i
],
row_number
);
PADDLE_THROW
(
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
,
"ids %d"
,
i
);
"paddle is not compiled with distribute support, can not do "
memcpy
(
output
+
i
*
row_width
,
table
+
ids
[
i
]
*
row_width
,
"parameter prefetch!"
);
row_width
*
sizeof
(
T
));
#endif
}
else
{
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
*
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
int64_t
ids_numel
=
ids_t
->
numel
();
if
(
table_var
->
IsType
<
LoDTensor
>
())
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
int64_t
row_number
=
table_t
->
dims
()[
0
];
int64_t
row_width
=
table_t
->
dims
()[
1
];
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
row_number
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
,
"ids %d"
,
i
);
memcpy
(
output
+
i
*
row_width
,
table
+
ids
[
i
]
*
row_width
,
row_width
*
sizeof
(
T
));
}
}
}
}
}
else
if
(
table_var
->
IsType
<
SelectedRows
>
())
{
}
else
if
(
table_var
->
IsType
<
SelectedRows
>
())
{
const
auto
&
table_t
=
table_var
->
Get
<
SelectedRows
>
();
const
auto
&
table_t
=
table_var
->
Get
<
SelectedRows
>
()
;
int64_t
row_width
=
table_t
.
value
().
dims
()[
1
]
;
int64_t
row_width
=
table_t
.
value
().
dims
()[
1
]
;
const
auto
*
table
=
table_t
.
value
().
data
<
T
>
()
;
const
auto
*
table
=
table_t
.
value
().
data
<
T
>
(
);
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
()
);
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
}
else
{
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
auto
id_index
=
table_t
.
Index
(
ids
[
i
]
);
auto
id_index
=
table_t
.
Index
(
ids
[
i
]
);
PADDLE_ENFORCE_GE
(
id_index
,
0
,
"the input key should be exists."
);
PADDLE_ENFORCE_GE
(
id_index
,
0
,
"the input key should be exists."
);
blas
.
VCOPY
(
row_width
,
table
+
id_index
*
row_width
,
blas
.
VCOPY
(
row_width
,
table
+
id_index
*
row_width
,
output
+
i
*
row_width
);
output
+
i
*
row_width
);
}
}
}
}
}
}
}
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
44debca8
...
@@ -326,6 +326,11 @@ def embedding(input,
...
@@ -326,6 +326,11 @@ def embedding(input,
"""
"""
helper
=
LayerHelper
(
'embedding'
,
**
locals
())
helper
=
LayerHelper
(
'embedding'
,
**
locals
())
remote_prefetch
=
False
if
os
.
environ
.
get
(
'PADDLE_ENABLE_REMOTE_PREFETCH'
):
remote_prefetch
=
True
if
remote_prefetch
:
assert
is_sparse
is
True
and
is_distributed
is
False
w
=
helper
.
create_parameter
(
w
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
tmp
=
helper
.
create_variable_for_type_inference
(
dtype
)
tmp
=
helper
.
create_variable_for_type_inference
(
dtype
)
...
@@ -339,6 +344,7 @@ def embedding(input,
...
@@ -339,6 +344,7 @@ def embedding(input,
attrs
=
{
attrs
=
{
'is_sparse'
:
is_sparse
,
'is_sparse'
:
is_sparse
,
'is_distributed'
:
is_distributed
,
'is_distributed'
:
is_distributed
,
'remote_prefetch'
:
remote_prefetch
,
'padding_idx'
:
padding_idx
'padding_idx'
:
padding_idx
})
})
return
tmp
return
tmp
...
...
python/paddle/fluid/tests/unittests/dist_ctr.py
浏览文件 @
44debca8
...
@@ -16,11 +16,13 @@ from __future__ import print_function
...
@@ -16,11 +16,13 @@ from __future__ import print_function
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
os
import
dist_ctr_reader
import
dist_ctr_reader
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
from
test_dist_base
import
TestDistRunnerBase
,
runtime_main
IS_SPARSE
=
True
IS_SPARSE
=
True
os
.
environ
[
'PADDLE_ENABLE_REMOTE_PREFETCH'
]
=
"1"
# Fix seed for test
# Fix seed for test
fluid
.
default_startup_program
().
random_seed
=
1
fluid
.
default_startup_program
().
random_seed
=
1
...
...
python/paddle/fluid/tests/unittests/test_dist_transpiler.py
浏览文件 @
44debca8
...
@@ -782,5 +782,46 @@ class TestNCCL2Transpile(TranspilerTest):
...
@@ -782,5 +782,46 @@ class TestNCCL2Transpile(TranspilerTest):
pass
pass
# test for remote prefetch
class
TestRemoteLookupTable
(
TestDistLookupTableBase
):
def
net_conf
(
self
):
import
os
os
.
environ
[
'PADDLE_ENABLE_REMOTE_PREFETCH'
]
=
"1"
self
.
network_with_table
(
is_sparse
=
True
,
is_distributed
=
False
)
def
transpiler_test_impl
(
self
):
pserver1
,
startup1
=
self
.
get_pserver
(
self
.
pserver1_ep
)
self
.
assertEqual
(
len
(
pserver1
.
blocks
),
4
)
# 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
1
].
ops
],
[
"sum"
,
"scale"
,
"adam"
,
"scale"
,
"scale"
])
# 2 optimize for table adam
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
2
].
ops
],
[
"sum"
,
"scale"
,
"adam"
,
"scale"
,
"scale"
])
# 3 optimize for table 2 adam
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
self
.
assertEqual
([
op
.
type
for
op
in
pserver1
.
blocks
[
3
].
ops
],
[
"sum"
,
"scale"
,
"adam"
,
"scale"
,
"scale"
])
trainer
,
_
=
self
.
get_trainer
()
self
.
assertEqual
(
len
(
trainer
.
blocks
),
1
)
ops
=
[
'lookup_table'
,
'sequence_pool'
,
'lookup_table'
,
'sequence_pool'
,
'lookup_table'
,
'sequence_pool'
,
'concat'
,
'mul'
,
'elementwise_add'
,
'cross_entropy'
,
'mean'
,
'fill_constant'
,
'mean_grad'
,
'cross_entropy_grad'
,
'elementwise_add_grad'
,
'send'
,
'mul_grad'
,
'send'
,
'concat_grad'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'split_selected_rows'
,
'send'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'sequence_pool_grad'
,
'lookup_table_grad'
,
'sum'
,
'split_selected_rows'
,
'send'
,
'send_barrier'
,
'recv'
,
'recv'
,
'fetch_barrier'
]
self
.
assertEqual
([
op
.
type
for
op
in
trainer
.
blocks
[
0
].
ops
],
ops
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_lookup_remote_table_op.py
0 → 100644
浏览文件 @
44debca8
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
signal
import
time
import
unittest
from
multiprocessing
import
Process
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
paddle.fluid.framework
import
Program
,
program_guard
def
run_pserver
(
pserver_id
,
use_cuda
,
sync_mode
):
scope
=
fluid
.
core
.
Scope
()
program
=
Program
()
with
fluid
.
scope_guard
(
scope
):
with
program_guard
(
program
,
startup_program
=
Program
()):
# create table parameter in scope
place
=
fluid
.
CUDAPlace
(
0
)
if
use_cuda
else
fluid
.
CPUPlace
()
# create and initialize Param Variable
param
=
scope
.
var
(
'table'
).
get_tensor
()
param_array
=
np
.
ones
((
10
,
8
)).
astype
(
"float32"
)
for
i
in
range
(
len
(
param_array
)):
param_array
[
i
]
*=
param_array
[
i
]
*
i
+
pserver_id
*
10
param
.
set
(
param_array
,
place
)
optimize_block
=
program
.
_create_block
(
program
.
global_block
().
idx
)
program
.
global_block
().
append_op
(
type
=
"listen_and_serv"
,
inputs
=
{
'X'
:
[]},
outputs
=
{},
attrs
=
{
"optimize_blocks"
:
[
optimize_block
],
"endpoint"
:
'127.0.0.1:0'
,
"Fanin"
:
1
,
"sync_mode"
:
True
,
"grad_to_block_id"
:
[]
})
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
program
)
class
TestListenAndServOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
ps_timeout
=
5
def
_start_pserver
(
self
,
pserver_id
,
use_cuda
,
sync_mode
,
pserver_func
):
p
=
Process
(
target
=
pserver_func
,
args
=
(
pserver_id
,
use_cuda
,
sync_mode
))
p
.
daemon
=
True
p
.
start
()
return
p
def
_wait_ps_ready
(
self
,
pid
):
start_left_time
=
self
.
ps_timeout
sleep_time
=
0.5
while
True
:
assert
start_left_time
>=
0
,
"wait ps ready failed"
time
.
sleep
(
sleep_time
)
try
:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os
.
stat
(
"/tmp/paddle.%d.port"
%
pid
)
return
except
os
.
error
:
start_left_time
-=
sleep_time
def
_get_pserver_port
(
self
,
pid
):
with
open
(
"/tmp/paddle.%d.port"
%
pid
,
'r'
)
as
f
:
port
=
int
(
f
.
read
().
strip
())
return
port
def
_run_lookup_table_op_one_pserver
(
self
,
place
,
port
):
scope
=
fluid
.
core
.
Scope
()
program
=
Program
()
with
fluid
.
scope_guard
(
scope
):
with
program_guard
(
program
,
startup_program
=
Program
()):
# create and initialize Param Variable
param
=
scope
.
var
(
'W'
).
get_tensor
()
param_array
=
np
.
full
((
10
,
8
),
1.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
ids
=
scope
.
var
(
'Ids'
).
get_tensor
()
ids_array
=
np
.
array
([[
1
],
[
2
],
[
5
]]).
astype
(
"int64"
)
ids
.
set
(
ids_array
,
place
)
ids_lod
=
[[
0
,
1
,
2
,
3
]]
ids
.
set_lod
(
ids_lod
)
out
=
scope
.
var
(
'Out'
).
get_tensor
()
emaps
=
[
'127.0.0.1:'
+
str
(
port
)]
table_names
=
[
'table'
]
height_sections
=
[
10
]
# create and run sgd operator
lookup_table_op
=
Operator
(
"lookup_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
,
remote_prefetch
=
True
,
epmap
=
emaps
,
table_names
=
table_names
,
height_sections
=
height_sections
)
lookup_table_op
.
run
(
scope
,
place
)
# get and compare result
result_array
=
np
.
array
(
out
)
self
.
assertEqual
(
out
.
lod
(),
ids_lod
)
self
.
assertEqual
(
list
(
result_array
.
shape
),
[
len
(
ids_array
),
8
])
for
i
in
range
(
len
(
ids_array
)):
id
=
ids_array
[
i
][
0
]
self
.
assertTrue
((
result_array
[
i
]
==
id
).
all
())
def
_run_lookup_table_op_two_pserver
(
self
,
place
,
port0
,
port1
):
scope
=
fluid
.
core
.
Scope
()
program
=
Program
()
with
fluid
.
scope_guard
(
scope
):
with
program_guard
(
program
,
startup_program
=
Program
()):
# create and initialize Param Variable
param
=
scope
.
var
(
'W'
).
get_tensor
()
param_array
=
np
.
full
((
10
,
8
),
1.0
).
astype
(
"float32"
)
param
.
set
(
param_array
,
place
)
ids
=
scope
.
var
(
'Ids'
).
get_tensor
()
ids_array
=
np
.
array
([[
1
],
[
2
],
[
11
],
[
13
]]).
astype
(
"int64"
)
ids
.
set
(
ids_array
,
place
)
ids_lod
=
[[
0
,
2
,
3
,
4
]]
ids
.
set_lod
(
ids_lod
)
out
=
scope
.
var
(
'Out'
).
get_tensor
()
emaps
=
[
'127.0.0.1:'
+
str
(
port0
),
'127.0.0.1:'
+
str
(
port1
)]
table_names
=
[
'table'
,
'table'
]
height_sections
=
[
10
,
20
]
# create and run sgd operator
lookup_table_op
=
Operator
(
"lookup_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
,
remote_prefetch
=
True
,
epmap
=
emaps
,
table_names
=
table_names
,
height_sections
=
height_sections
)
lookup_table_op
.
run
(
scope
,
place
)
# get and compare result
result_array
=
np
.
array
(
out
)
self
.
assertEqual
(
out
.
lod
(),
ids_lod
)
self
.
assertEqual
(
list
(
result_array
.
shape
),
[
len
(
ids_array
),
8
])
for
i
in
range
(
len
(
ids_array
)):
id
=
ids_array
[
i
][
0
]
self
.
assertTrue
((
result_array
[
i
]
==
id
).
all
())
def
test_lookup_remote_table
(
self
):
os
.
environ
[
'PADDLE_ENABLE_REMOTE_PREFETCH'
]
=
"1"
# run pserver on CPU in sync mode
p0
=
self
.
_start_pserver
(
0
,
False
,
True
,
run_pserver
)
self
.
_wait_ps_ready
(
p0
.
pid
)
port0
=
self
.
_get_pserver_port
(
p0
.
pid
)
p1
=
self
.
_start_pserver
(
1
,
False
,
True
,
run_pserver
)
self
.
_wait_ps_ready
(
p1
.
pid
)
port1
=
self
.
_get_pserver_port
(
p1
.
pid
)
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
_run_lookup_table_op_one_pserver
(
place
,
port0
)
self
.
_run_lookup_table_op_two_pserver
(
place
,
port0
,
port1
)
# raise SIGTERM to pserver
os
.
kill
(
p0
.
pid
,
signal
.
SIGINT
)
p0
.
join
()
os
.
kill
(
p1
.
pid
,
signal
.
SIGINT
)
p1
.
join
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
44debca8
...
@@ -236,6 +236,31 @@ class DistributeTranspiler(object):
...
@@ -236,6 +236,31 @@ class DistributeTranspiler(object):
else
:
else
:
raise
ValueError
(
"must set trainer_id > 0"
)
raise
ValueError
(
"must set trainer_id > 0"
)
def
_get_all_remote_sparse_update_op
(
self
,
main_program
):
sparse_update_ops
=
[]
sparse_update_op_types
=
[
"lookup_table"
]
for
op
in
main_program
.
global_block
().
ops
:
if
op
.
type
in
sparse_update_op_types
and
op
.
attr
(
'remote_prefetch'
)
is
True
and
not
op
.
attr
(
'is_distributed'
):
sparse_update_ops
.
append
(
op
)
return
sparse_update_ops
def
_update_remote_sparse_update_op
(
self
,
param_varname
,
height_sections
,
endpint_map
,
table_names
):
for
op
in
self
.
sparse_update_ops
:
if
param_varname
in
op
.
input_arg_names
:
op
.
_set_attr
(
'epmap'
,
endpint_map
)
op
.
_set_attr
(
'table_names'
,
table_names
)
op
.
_set_attr
(
'height_sections'
,
height_sections
)
op
.
_set_attr
(
'trainer_id'
,
self
.
trainer_id
)
def
_is_input_of_remote_sparse_update_op
(
self
,
param_name
):
for
op
in
self
.
sparse_update_ops
:
if
param_name
in
op
.
input_arg_names
:
return
True
return
False
def
transpile
(
self
,
def
transpile
(
self
,
trainer_id
,
trainer_id
,
program
=
None
,
program
=
None
,
...
@@ -299,6 +324,12 @@ class DistributeTranspiler(object):
...
@@ -299,6 +324,12 @@ class DistributeTranspiler(object):
self
.
param_name_to_grad_name
[
param_var
.
name
]
=
grad_var
.
name
self
.
param_name_to_grad_name
[
param_var
.
name
]
=
grad_var
.
name
self
.
grad_name_to_param_name
[
grad_var
.
name
]
=
param_var
.
name
self
.
grad_name_to_param_name
[
grad_var
.
name
]
=
param_var
.
name
# get all sparse update ops
self
.
sparse_update_ops
=
self
.
_get_all_remote_sparse_update_op
(
self
.
origin_program
)
# use_sparse_update_param_name -> split_height_section
self
.
sparse_param_to_height_sections
=
dict
()
# add distributed attrs to program
# add distributed attrs to program
self
.
origin_program
.
_is_distributed
=
True
self
.
origin_program
.
_is_distributed
=
True
self
.
origin_program
.
_endpoints
=
self
.
pserver_endpoints
self
.
origin_program
.
_endpoints
=
self
.
pserver_endpoints
...
@@ -336,6 +367,13 @@ class DistributeTranspiler(object):
...
@@ -336,6 +367,13 @@ class DistributeTranspiler(object):
splited_grad_varname
=
splited_vars
[
0
].
name
splited_grad_varname
=
splited_vars
[
0
].
name
index
=
find_op_by_output_arg
(
index
=
find_op_by_output_arg
(
program
.
global_block
(),
splited_grad_varname
,
reverse
=
True
)
program
.
global_block
(),
splited_grad_varname
,
reverse
=
True
)
if
splited_vars
[
0
].
type
==
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
:
sparse_param_name
=
self
.
grad_name_to_param_name
[
grad_varname
]
if
self
.
_is_input_of_remote_sparse_update_op
(
sparse_param_name
):
self
.
sparse_param_to_height_sections
[
sparse_param_name
]
=
[
splited_vars
[
0
].
shape
[
0
]]
elif
len
(
splited_vars
)
>
1
:
elif
len
(
splited_vars
)
>
1
:
orig_var
=
program
.
global_block
().
vars
[
splited_grad_varname
]
orig_var
=
program
.
global_block
().
vars
[
splited_grad_varname
]
index
=
find_op_by_output_arg
(
index
=
find_op_by_output_arg
(
...
@@ -406,16 +444,18 @@ class DistributeTranspiler(object):
...
@@ -406,16 +444,18 @@ class DistributeTranspiler(object):
all_recv_outputs
=
[]
all_recv_outputs
=
[]
for
param_varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
for
param_varname
,
splited_var
in
six
.
iteritems
(
self
.
param_var_mapping
):
eps
=
[]
eps
=
[]
table_names
=
[]
for
var
in
splited_var
:
for
var
in
splited_var
:
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
index
=
[
v
.
name
for
v
in
recv_vars
].
index
(
var
.
name
)
eps
.
append
(
eplist
[
index
])
eps
.
append
(
eplist
[
index
])
table_names
.
append
(
var
.
name
)
if
self
.
sync_mode
:
if
self
.
sync_mode
:
recv_dep_in
=
send_barrier_out
recv_dep_in
=
send_barrier_out
else
:
else
:
# connect deps to send op in async mode
# connect deps to send op in async mode
recv_dep_in
=
self
.
grad_name_to_send_dummy_out
[
recv_dep_in
=
self
.
grad_name_to_send_dummy_out
[
self
.
param_name_to_grad_name
[
param_varname
]]
self
.
param_name_to_grad_name
[
param_varname
]]
all_recv_outputs
.
extend
(
splited_var
)
# get recv op_role_var, if not splited, the grad should have .trainer suffix
# get recv op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name. ParallelExecutor
# if splited, grad should be the original grad var name. ParallelExecutor
# will use op_role_var to get expected device place to run this op.
# will use op_role_var to get expected device place to run this op.
...
@@ -425,18 +465,25 @@ class DistributeTranspiler(object):
...
@@ -425,18 +465,25 @@ class DistributeTranspiler(object):
if
len
(
splited_trainer_grad
)
==
1
:
if
len
(
splited_trainer_grad
)
==
1
:
recv_op_role_var_name
=
splited_trainer_grad
[
0
].
name
recv_op_role_var_name
=
splited_trainer_grad
[
0
].
name
program
.
global_block
().
append_op
(
if
param_varname
in
self
.
sparse_param_to_height_sections
:
type
=
"recv"
,
height_sections
=
self
.
sparse_param_to_height_sections
[
inputs
=
{
"X"
:
[
recv_dep_in
]},
param_varname
]
outputs
=
{
"Out"
:
splited_var
},
self
.
_update_remote_sparse_update_op
(
attrs
=
{
param_varname
,
height_sections
,
eps
,
table_names
)
"epmap"
:
eps
,
else
:
"trainer_id"
:
self
.
trainer_id
,
all_recv_outputs
.
extend
(
splited_var
)
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
program
.
global_block
().
append_op
(
OP_ROLE_VAR_ATTR_NAME
:
type
=
"recv"
,
[
param_varname
,
recv_op_role_var_name
],
inputs
=
{
"X"
:
[
recv_dep_in
]},
"sync_mode"
:
not
self
.
sync_mode
outputs
=
{
"Out"
:
splited_var
},
})
attrs
=
{
"epmap"
:
eps
,
"trainer_id"
:
self
.
trainer_id
,
RPC_OP_ROLE_ATTR_NAME
:
RPC_OP_ROLE_ATTR_VALUE
,
OP_ROLE_VAR_ATTR_NAME
:
[
param_varname
,
recv_op_role_var_name
],
"sync_mode"
:
not
self
.
sync_mode
})
if
self
.
sync_mode
:
if
self
.
sync_mode
:
# form a WAW dependency
# form a WAW dependency
...
@@ -454,14 +501,15 @@ class DistributeTranspiler(object):
...
@@ -454,14 +501,15 @@ class DistributeTranspiler(object):
if
len
(
splited_var
)
<=
1
:
if
len
(
splited_var
)
<=
1
:
continue
continue
orig_param
=
program
.
global_block
().
vars
[
param_varname
]
orig_param
=
program
.
global_block
().
vars
[
param_varname
]
program
.
global_block
().
append_op
(
if
param_varname
not
in
self
.
sparse_param_to_height_sections
:
type
=
"concat"
,
program
.
global_block
().
append_op
(
inputs
=
{
"X"
:
splited_var
},
type
=
"concat"
,
outputs
=
{
"Out"
:
[
orig_param
]},
inputs
=
{
"X"
:
splited_var
},
attrs
=
{
outputs
=
{
"Out"
:
[
orig_param
]},
"axis"
:
0
,
attrs
=
{
RPC_OP_ROLE_ATTR_NAME
:
DIST_OP_ROLE_ATTR_VALUE
"axis"
:
0
,
})
RPC_OP_ROLE_ATTR_NAME
:
DIST_OP_ROLE_ATTR_VALUE
})
self
.
_get_trainer_startup_program
(
recv_vars
=
recv_vars
,
eplist
=
eplist
)
self
.
_get_trainer_startup_program
(
recv_vars
=
recv_vars
,
eplist
=
eplist
)
...
@@ -1420,6 +1468,10 @@ to transpile() call.")
...
@@ -1420,6 +1468,10 @@ to transpile() call.")
height_sections
=
[]
height_sections
=
[]
for
v
in
splited_vars
:
for
v
in
splited_vars
:
height_sections
.
append
(
v
.
shape
[
0
])
height_sections
.
append
(
v
.
shape
[
0
])
sparse_param_name
=
self
.
grad_name_to_param_name
[
orig_var
.
name
]
if
self
.
_is_input_of_remote_sparse_update_op
(
sparse_param_name
):
self
.
sparse_param_to_height_sections
[
sparse_param_name
]
=
height_sections
program
.
global_block
().
_insert_op
(
program
.
global_block
().
_insert_op
(
index
=
index
+
1
,
index
=
index
+
1
,
type
=
"split_selected_rows"
,
type
=
"split_selected_rows"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录