Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a97ca56a
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a97ca56a
编写于
1月 13, 2021
作者:
T
tangwei12
提交者:
GitHub
1月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
split ps with distributed (#30337)
Change-Id: I3c788e7576688e63181e7f01562529b85a09cc59
上级
5eab1a38
变更
44
显示空白变更内容
内联
并排
Showing
44 changed file
with
251 addition
and
200 deletion
+251
-200
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/configure.cmake
cmake/configure.cmake
+5
-0
cmake/third_party.cmake
cmake/third_party.cmake
+1
-1
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-4
paddle/fluid/distributed/common/registerer.h
paddle/fluid/distributed/common/registerer.h
+7
-7
paddle/fluid/distributed/ps.proto
paddle/fluid/distributed/ps.proto
+1
-1
paddle/fluid/distributed/service/brpc_ps_client.cc
paddle/fluid/distributed/service/brpc_ps_client.cc
+3
-3
paddle/fluid/distributed/service/brpc_ps_client.h
paddle/fluid/distributed/service/brpc_ps_client.h
+2
-2
paddle/fluid/distributed/service/brpc_ps_server.cc
paddle/fluid/distributed/service/brpc_ps_server.cc
+106
-93
paddle/fluid/distributed/service/brpc_ps_server.h
paddle/fluid/distributed/service/brpc_ps_server.h
+5
-5
paddle/fluid/distributed/service/brpc_utils.cc
paddle/fluid/distributed/service/brpc_utils.cc
+6
-6
paddle/fluid/distributed/service/brpc_utils.h
paddle/fluid/distributed/service/brpc_utils.h
+2
-2
paddle/fluid/distributed/service/heter_client.cc
paddle/fluid/distributed/service/heter_client.cc
+2
-2
paddle/fluid/distributed/service/heter_client.h
paddle/fluid/distributed/service/heter_client.h
+2
-2
paddle/fluid/distributed/service/heter_server.h
paddle/fluid/distributed/service/heter_server.h
+5
-5
paddle/fluid/distributed/service/ps_client.cc
paddle/fluid/distributed/service/ps_client.cc
+4
-5
paddle/fluid/distributed/service/ps_client.h
paddle/fluid/distributed/service/ps_client.h
+4
-1
paddle/fluid/distributed/service/sendrecv.proto
paddle/fluid/distributed/service/sendrecv.proto
+1
-1
paddle/fluid/distributed/service/server.cc
paddle/fluid/distributed/service/server.cc
+6
-4
paddle/fluid/distributed/service/server.h
paddle/fluid/distributed/service/server.h
+6
-4
paddle/fluid/distributed/service/service.h
paddle/fluid/distributed/service/service.h
+4
-0
paddle/fluid/distributed/table/accessor.h
paddle/fluid/distributed/table/accessor.h
+1
-1
paddle/fluid/distributed/table/table.cc
paddle/fluid/distributed/table/table.cc
+12
-11
paddle/fluid/distributed/table/table.h
paddle/fluid/distributed/table/table.h
+1
-1
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
+5
-3
paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
...le/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
+7
-4
paddle/fluid/distributed/test/brpc_utils_test.cc
paddle/fluid/distributed/test/brpc_utils_test.cc
+2
-2
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+12
-3
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+3
-3
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+2
-2
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+3
-3
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+4
-1
paddle/fluid/operators/pscore/CMakeLists.txt
paddle/fluid/operators/pscore/CMakeLists.txt
+4
-0
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
+2
-2
paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
...le/fluid/operators/pscore/heter_listen_and_server_test.cc
+2
-2
paddle/fluid/operators/pscore/heter_server_test.cc
paddle/fluid/operators/pscore/heter_server_test.cc
+2
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-2
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+4
-2
paddle/testing/paddle_gtest_main.cc
paddle/testing/paddle_gtest_main.cc
+2
-1
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
a97ca56a
...
...
@@ -136,6 +136,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option
(
WITH_XBYAK
"Compile with xbyak support"
ON
)
option
(
WITH_CONTRIB
"Compile the third-party contributation"
OFF
)
option
(
WITH_GRPC
"Use grpc as the default rpc framework"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_PSCORE
"Compile with parameter server support"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_INFERENCE_API_TEST
"Test fluid inference C++ high-level api interface"
OFF
)
option
(
PY_VERSION
"Compile PaddlePaddle with python3 support"
${
PY_VERSION
}
)
option
(
WITH_DGC
"Use DGC(Deep Gradient Compression) or not"
${
WITH_DISTRIBUTE
}
)
...
...
cmake/configure.cmake
浏览文件 @
a97ca56a
...
...
@@ -156,6 +156,11 @@ if(WITH_DISTRIBUTE)
add_definitions
(
-DPADDLE_WITH_DISTRIBUTE
)
endif
()
if
(
WITH_PSCORE
)
add_definitions
(
-DPADDLE_WITH_PSCORE
)
endif
()
if
(
WITH_GRPC
)
add_definitions
(
-DPADDLE_WITH_GRPC
)
endif
(
WITH_GRPC
)
...
...
cmake/third_party.cmake
浏览文件 @
a97ca56a
...
...
@@ -280,7 +280,7 @@ if(WITH_BOX_PS)
list
(
APPEND third_party_deps extern_box_ps
)
endif
(
WITH_BOX_PS
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
include
(
external/snappy
)
list
(
APPEND third_party_deps extern_snappy
)
...
...
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
a97ca56a
if
(
WITH_PSLIB
)
return
()
endif
()
if
(
NOT WITH_DISTRIBUTE
)
if
(
NOT WITH_PSCORE
)
return
()
endif
()
...
...
paddle/fluid/distributed/common/registerer.h
浏览文件 @
a97ca56a
...
...
@@ -69,24 +69,24 @@ class ObjectFactory {
};
typedef
std
::
map
<
std
::
string
,
ObjectFactory
*>
FactoryMap
;
typedef
std
::
map
<
std
::
string
,
FactoryMap
>
Bas
eClassMap
;
typedef
std
::
map
<
std
::
string
,
FactoryMap
>
PsCor
eClassMap
;
#ifdef __cplusplus
extern
"C"
{
#endif
inline
Bas
eClassMap
&
global_factory_map
()
{
static
BaseClassMap
*
base_class
=
new
Bas
eClassMap
();
inline
PsCor
eClassMap
&
global_factory_map
()
{
static
PsCoreClassMap
*
base_class
=
new
PsCor
eClassMap
();
return
*
base_class
;
}
#ifdef __cplusplus
}
#endif
inline
Bas
eClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
}
inline
PsCor
eClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
}
// typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap;
#define REGISTER_
REGISTERER(base_class)
\
#define REGISTER_
PSCORE_REGISTERER(base_class)
\
class base_class##Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
...
...
@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
};
#define REGISTER_
CLASS(clazz, name)
\
#define REGISTER_
PSCORE_CLASS(clazz, name)
\
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { return Any(new name()); } \
...
...
@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_
PSCORE_
CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
}
// namespace distributed
...
...
paddle/fluid/distributed/ps.proto
浏览文件 @
a97ca56a
...
...
@@ -86,7 +86,7 @@ message SparseTableParameter {
message
ServerServiceParameter
{
optional
string
server_class
=
1
[
default
=
"BrpcPsServer"
];
optional
string
client_class
=
2
[
default
=
"BrpcPsClient"
];
optional
string
service_class
=
3
[
default
=
"PsService"
];
optional
string
service_class
=
3
[
default
=
"
Brpc
PsService"
];
optional
uint32
start_server_port
=
4
[
default
=
0
];
// will find a avaliable port from it
optional
uint32
server_thread_num
=
5
[
default
=
12
];
...
...
paddle/fluid/distributed/service/brpc_ps_client.cc
浏览文件 @
a97ca56a
...
...
@@ -17,8 +17,8 @@
#include <sstream>
#include <string>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
...
...
@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void
DownpourPsClientService
::
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
int
ret
=
_client
->
handle_client2client_msg
(
request
->
cmd_id
(),
request
->
client_id
(),
request
->
data
());
...
...
paddle/fluid/distributed/service/brpc_ps_client.h
浏览文件 @
a97ca56a
...
...
@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return
0
;
}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
;
protected:
...
...
paddle/fluid/distributed/service/brpc_ps_server.cc
浏览文件 @
a97ca56a
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "Eigen/Dense"
#include "butil/endpoint.h"
...
...
@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() {
LOG
(
ERROR
)
<<
"miss service_class in ServerServiceParameter"
;
return
-
1
;
}
auto
*
service
=
CREATE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
auto
*
service
=
CREATE_PSCORE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
if
(
service
==
NULL
)
{
LOG
(
ERROR
)
<<
"service is unregistered, service_name:"
<<
service_config
.
service_class
();
...
...
@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
int32_t
BrpcPsServer
::
port
()
{
return
_server
.
listen_address
().
port
;
}
int32_t
PsService
::
initialize
()
{
int32_t
Brpc
PsService
::
initialize
()
{
_is_initialize_shard_info
=
false
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
PsService
::
stop_server
;
_service_handler_map
[
PS_PULL_DENSE_TABLE
]
=
&
PsService
::
pull_dense
;
_service_handler_map
[
PS_PUSH_DENSE_TABLE
]
=
&
PsService
::
push_dense
;
_service_handler_map
[
PS_PULL_SPARSE_TABLE
]
=
&
PsService
::
pull_sparse
;
_service_handler_map
[
PS_PUSH_SPARSE_TABLE
]
=
&
PsService
::
push_sparse
;
_service_handler_map
[
PS_SAVE_ONE_TABLE
]
=
&
PsService
::
save_one_table
;
_service_handler_map
[
PS_SAVE_ALL_TABLE
]
=
&
PsService
::
save_all_table
;
_service_handler_map
[
PS_SHRINK_TABLE
]
=
&
PsService
::
shrink_table
;
_service_handler_map
[
PS_LOAD_ONE_TABLE
]
=
&
PsService
::
load_one_table
;
_service_handler_map
[
PS_LOAD_ALL_TABLE
]
=
&
PsService
::
load_all_table
;
_service_handler_map
[
PS_CLEAR_ONE_TABLE
]
=
&
PsService
::
clear_one_table
;
_service_handler_map
[
PS_CLEAR_ALL_TABLE
]
=
&
PsService
::
clear_all_table
;
_service_handler_map
[
PS_PUSH_DENSE_PARAM
]
=
&
PsService
::
push_dense_param
;
_service_handler_map
[
PS_PRINT_TABLE_STAT
]
=
&
PsService
::
print_table_stat
;
_service_handler_map
[
PS_PULL_GEO_PARAM
]
=
&
PsService
::
pull_geo_param
;
_service_handler_map
[
PS_PUSH_SPARSE_PARAM
]
=
&
PsService
::
push_sparse_param
;
_service_handler_map
[
PS_BARRIER
]
=
&
PsService
::
barrier
;
_service_handler_map
[
PS_
START_PROFILER
]
=
&
PsService
::
start_profil
er
;
_service_handler_map
[
PS_ST
OP_PROFILER
]
=
&
PsService
::
stop
_profiler
;
_service_handler_map
[
PS_
PUSH_GLOBAL_STEP
]
=
&
PsService
::
push_global_step
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
Brpc
PsService
::
stop_server
;
_service_handler_map
[
PS_PULL_DENSE_TABLE
]
=
&
Brpc
PsService
::
pull_dense
;
_service_handler_map
[
PS_PUSH_DENSE_TABLE
]
=
&
Brpc
PsService
::
push_dense
;
_service_handler_map
[
PS_PULL_SPARSE_TABLE
]
=
&
Brpc
PsService
::
pull_sparse
;
_service_handler_map
[
PS_PUSH_SPARSE_TABLE
]
=
&
Brpc
PsService
::
push_sparse
;
_service_handler_map
[
PS_SAVE_ONE_TABLE
]
=
&
Brpc
PsService
::
save_one_table
;
_service_handler_map
[
PS_SAVE_ALL_TABLE
]
=
&
Brpc
PsService
::
save_all_table
;
_service_handler_map
[
PS_SHRINK_TABLE
]
=
&
Brpc
PsService
::
shrink_table
;
_service_handler_map
[
PS_LOAD_ONE_TABLE
]
=
&
Brpc
PsService
::
load_one_table
;
_service_handler_map
[
PS_LOAD_ALL_TABLE
]
=
&
Brpc
PsService
::
load_all_table
;
_service_handler_map
[
PS_CLEAR_ONE_TABLE
]
=
&
Brpc
PsService
::
clear_one_table
;
_service_handler_map
[
PS_CLEAR_ALL_TABLE
]
=
&
Brpc
PsService
::
clear_all_table
;
_service_handler_map
[
PS_PUSH_DENSE_PARAM
]
=
&
Brpc
PsService
::
push_dense_param
;
_service_handler_map
[
PS_PRINT_TABLE_STAT
]
=
&
Brpc
PsService
::
print_table_stat
;
_service_handler_map
[
PS_PULL_GEO_PARAM
]
=
&
Brpc
PsService
::
pull_geo_param
;
_service_handler_map
[
PS_PUSH_SPARSE_PARAM
]
=
&
BrpcPsService
::
push_sparse_param
;
_service_handler_map
[
PS_
BARRIER
]
=
&
BrpcPsService
::
barri
er
;
_service_handler_map
[
PS_ST
ART_PROFILER
]
=
&
BrpcPsService
::
start
_profiler
;
_service_handler_map
[
PS_
STOP_PROFILER
]
=
&
BrpcPsService
::
stop_profiler
;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info
();
...
...
@@ -116,7 +118,7 @@ int32_t PsService::initialize() {
return -1; \
}
int32_t
PsService
::
initialize_shard_info
()
{
int32_t
Brpc
PsService
::
initialize_shard_info
()
{
if
(
!
_is_initialize_shard_info
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
_initialize_shard_mutex
);
if
(
_is_initialize_shard_info
)
{
...
...
@@ -132,7 +134,7 @@ int32_t PsService::initialize_shard_info() {
return
0
;
}
void
PsService
::
service
(
google
::
protobuf
::
RpcController
*
cntl_base
,
void
Brpc
PsService
::
service
(
google
::
protobuf
::
RpcController
*
cntl_base
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
...
...
@@ -163,7 +165,7 @@ void PsService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t
PsService
::
pull_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
Brpc
PsService
::
pull_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_dense"
);
...
...
@@ -191,7 +193,7 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
push_dense_param
(
Table
*
table
,
int32_t
Brpc
PsService
::
push_dense_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
...
...
@@ -218,7 +220,7 @@ int32_t PsService::push_dense_param(Table *table,
return
0
;
}
int32_t
PsService
::
push_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
Brpc
PsService
::
push_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_dense"
);
...
...
@@ -244,7 +246,7 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
barrier
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
Brpc
PsService
::
barrier
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
...
...
@@ -262,7 +264,7 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
push_sparse_param
(
Table
*
table
,
int32_t
Brpc
PsService
::
push_sparse_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
...
...
@@ -294,7 +296,8 @@ int32_t PsService::push_sparse_param(Table *table,
return
0
;
}
int32_t
PsService
::
pull_geo_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
pull_geo_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_geo_param"
);
...
...
@@ -316,7 +319,8 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
pull_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
pull_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_sparse"
);
...
...
@@ -353,7 +357,8 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
push_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
push_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_sparse"
);
...
...
@@ -384,7 +389,7 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
print_table_stat
(
Table
*
table
,
int32_t
Brpc
PsService
::
print_table_stat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
...
...
@@ -398,7 +403,8 @@ int32_t PsService::print_table_stat(Table *table,
return
0
;
}
int32_t
PsService
::
load_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
load_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
...
...
@@ -415,7 +421,8 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
load_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
load_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
...
...
@@ -428,7 +435,8 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
save_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
save_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
...
...
@@ -449,7 +457,8 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
return
feasign_size
;
}
int32_t
PsService
::
save_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
save_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
...
...
@@ -466,7 +475,8 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
shrink_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
shrink_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
...
...
@@ -477,7 +487,7 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
clear_one_table
(
Table
*
table
,
int32_t
Brpc
PsService
::
clear_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
...
...
@@ -487,7 +497,7 @@ int32_t PsService::clear_one_table(Table *table,
return
0
;
}
int32_t
PsService
::
clear_all_table
(
Table
*
table
,
int32_t
Brpc
PsService
::
clear_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
...
...
@@ -500,7 +510,8 @@ int32_t PsService::clear_all_table(Table *table,
return
0
;
}
int32_t
PsService
::
stop_server
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
stop_server
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
*
p_server
=
_server
;
...
...
@@ -512,7 +523,8 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
stop_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
stop_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
...
...
@@ -520,14 +532,15 @@ int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
start_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
start_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
return
0
;
}
int32_t
PsService
::
push_global_step
(
Table
*
table
,
int32_t
Brpc
PsService
::
push_global_step
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
...
...
paddle/fluid/distributed/service/brpc_ps_server.h
浏览文件 @
a97ca56a
...
...
@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_pserver_channels
;
};
class
PsService
;
class
Brpc
PsService
;
typedef
int32_t
(
PsService
::*
serviceHandlerFunc
)(
typedef
int32_t
(
Brpc
PsService
::*
serviceHandlerFunc
)(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
class
PsService
:
public
PsBaseService
{
class
Brpc
PsService
:
public
PsBaseService
{
public:
virtual
int32_t
initialize
()
override
;
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
;
private:
...
...
paddle/fluid/distributed/service/brpc_utils.cc
浏览文件 @
a97ca56a
...
...
@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
var_msg
,
butil
::
IOBuf
*
iobuf
)
{
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
var_msg
->
set_type
(
::
paddle
::
LOD_TENSOR
);
var_msg
->
set_type
(
::
paddle
::
distributed
::
LOD_TENSOR
);
const
framework
::
LoD
lod
=
tensor
->
lod
();
if
(
lod
.
size
()
>
0
)
{
var_msg
->
set_lod_level
(
lod
.
size
());
...
...
@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
rows
=
slr
->
mutable_rows
();
var_msg
->
set_type
(
::
paddle
::
SELECTED_ROWS
);
var_msg
->
set_type
(
::
paddle
::
distributed
::
SELECTED_ROWS
);
var_msg
->
set_slr_height
(
slr
->
height
());
auto
*
var_data
=
var_msg
->
mutable_data
();
...
...
@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++
recv_var_index
)
{
const
auto
&
msg
=
multi_msg
.
var_messages
(
recv_var_index
);
auto
*
var
=
scope
->
Var
(
msg
.
varname
());
if
(
msg
.
type
()
==
::
paddle
::
LOD_TENSOR
)
{
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
LOD_TENSOR
)
{
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
else
if
(
msg
.
type
()
==
::
paddle
::
SELECTED_ROWS
)
{
}
else
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
SELECTED_ROWS
)
{
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
}
...
...
@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE
(
var
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Not find variable %s in scope."
,
msg
.
varname
()));
if
(
msg
.
type
()
==
::
paddle
::
LOD_TENSOR
)
{
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
LOD_TENSOR
)
{
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
else
if
(
msg
.
type
()
==
::
paddle
::
SELECTED_ROWS
)
{
}
else
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
SELECTED_ROWS
)
{
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
}
...
...
paddle/fluid/distributed/service/brpc_utils.h
浏览文件 @
a97ca56a
...
...
@@ -44,8 +44,8 @@ class DeviceContext;
namespace
paddle
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
void
SerializeToMultiVarMsgAndIOBuf
(
const
std
::
string
&
message_name
,
...
...
paddle/fluid/distributed/service/heter_client.cc
浏览文件 @
a97ca56a
...
...
@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl
.
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
distributed
::
MultiVarMsg
request
,
response
;
auto
&
request_io_buffer
=
cntl
.
request_attachment
();
::
paddle
::
PsService_Stub
stub
(
xpu_channels_
[
num
].
get
());
::
paddle
::
distributed
::
PsService_Stub
stub
(
xpu_channels_
[
num
].
get
());
distributed
::
SerializeToMultiVarMsgAndIOBuf
(
message_name_val
,
send_var_name_val
,
recv_var_name_val
,
*
p_ctx
,
p_scope
,
&
request
,
&
request_io_buffer
);
...
...
@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for
(
const
auto
&
param
:
params
)
{
closure
->
request
(
i
)
->
add_params
(
param
);
}
::
paddle
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
::
paddle
::
distributed
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
closure
->
cntl
(
i
)
->
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
// cmd msg don't limit timeout for save/load
rpc_stub
.
service
(
closure
->
cntl
(
i
),
closure
->
request
(
i
),
...
...
paddle/fluid/distributed/service/heter_client.h
浏览文件 @
a97ca56a
...
...
@@ -35,8 +35,8 @@ limitations under the License. */
namespace
paddle
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
HeterRpcCallbackFunc
;
...
...
paddle/fluid/distributed/service/heter_server.h
浏览文件 @
a97ca56a
...
...
@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace
paddle
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
class
HeterService
;
typedef
int32_t
(
HeterService
::*
serviceHandlerFunc
)(
...
...
@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef
std
::
function
<
int
(
const
MultiVarMsg
*
,
MultiVarMsg
*
,
brpc
::
Controller
*
)
>
HeterServiceHandler
;
class
HeterService
:
public
::
paddle
::
PsService
{
class
HeterService
:
public
::
paddle
::
distributed
::
PsService
{
public:
HeterService
()
{
_service_handler_map
[
PS_STOP_SERVER
]
=
&
HeterService
::
stop_heter_worker
;
...
...
@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual
~
HeterService
()
{}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
std
::
string
log_label
(
"ReceiveCmd-"
);
...
...
paddle/fluid/distributed/service/ps_client.cc
浏览文件 @
a97ca56a
...
...
@@ -13,9 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include "brpc/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
...
...
@@ -23,7 +21,7 @@
namespace
paddle
{
namespace
distributed
{
REGISTER_CLASS
(
PSClient
,
BrpcPsClient
);
REGISTER_
PSCORE_
CLASS
(
PSClient
,
BrpcPsClient
);
int32_t
PSClient
::
configure
(
const
PSParameter
&
config
,
...
...
@@ -43,7 +41,7 @@ int32_t PSClient::configure(
const
auto
&
work_param
=
_config
.
worker_param
().
downpour_worker_param
();
for
(
size_t
i
=
0
;
i
<
work_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
accessor
=
CREATE_CLASS
(
auto
*
accessor
=
CREATE_
PSCORE_
CLASS
(
ValueAccessor
,
work_param
.
downpour_table_param
(
i
).
accessor
().
accessor_class
());
accessor
->
configure
(
work_param
.
downpour_table_param
(
i
).
accessor
());
...
...
@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
}
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
PSClient
*
client
=
CREATE_CLASS
(
PSClient
,
service_param
.
client_class
());
PSClient
*
client
=
CREATE_PSCORE_CLASS
(
PSClient
,
service_param
.
client_class
());
if
(
client
==
NULL
)
{
LOG
(
ERROR
)
<<
"client is not registered, server_name:"
<<
service_param
.
client_class
();
...
...
paddle/fluid/distributed/service/ps_client.h
浏览文件 @
a97ca56a
...
...
@@ -28,6 +28,9 @@
namespace
paddle
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
public:
...
...
@@ -206,7 +209,7 @@ class PSClient {
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
//处理client2client消息
};
REGISTER_REGISTERER
(
PSClient
);
REGISTER_
PSCORE_
REGISTERER
(
PSClient
);
class
PSClientFactory
{
public:
...
...
paddle/fluid/distributed/service/sendrecv.proto
浏览文件 @
a97ca56a
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
syntax
=
"proto2"
;
package
paddle
;
package
paddle
.
distributed
;
option
cc_generic_services
=
true
;
option
cc_enable_arenas
=
true
;
...
...
paddle/fluid/distributed/service/server.cc
浏览文件 @
a97ca56a
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
...
...
@@ -20,8 +21,8 @@
namespace
paddle
{
namespace
distributed
{
REGISTER_CLASS
(
PSServer
,
BrpcPsServer
);
REGISTER_
CLASS
(
PsBaseService
,
PsService
);
REGISTER_
PSCORE_
CLASS
(
PSServer
,
BrpcPsServer
);
REGISTER_
PSCORE_CLASS
(
PsBaseService
,
Brpc
PsService
);
PSServer
*
PSServerFactory
::
create
(
const
PSParameter
&
ps_config
)
{
const
auto
&
config
=
ps_config
.
server_param
();
...
...
@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
}
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
PSServer
*
server
=
CREATE_CLASS
(
PSServer
,
service_param
.
server_class
());
PSServer
*
server
=
CREATE_PSCORE_CLASS
(
PSServer
,
service_param
.
server_class
());
if
(
server
==
NULL
)
{
LOG
(
ERROR
)
<<
"server is not registered, server_name:"
<<
service_param
.
server_class
();
...
...
@@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t
global_step_table
=
UINT32_MAX
;
for
(
size_t
i
=
0
;
i
<
downpour_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
table
=
CREATE_CLASS
(
auto
*
table
=
CREATE_
PSCORE_
CLASS
(
Table
,
downpour_param
.
downpour_table_param
(
i
).
table_class
());
if
(
downpour_param
.
downpour_table_param
(
i
).
table_class
()
==
...
...
paddle/fluid/distributed/service/server.h
浏览文件 @
a97ca56a
...
...
@@ -46,6 +46,8 @@ namespace paddle {
namespace
distributed
{
class
Table
;
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
class
PSServer
{
public:
...
...
@@ -107,7 +109,7 @@ class PSServer {
platform
::
Place
place_
=
platform
::
CPUPlace
();
};
REGISTER_REGISTERER
(
PSServer
);
REGISTER_
PSCORE_
REGISTERER
(
PSServer
);
typedef
std
::
function
<
void
(
void
*
)
>
PServerCallBack
;
...
...
@@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return
0
;
}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
=
0
;
virtual
void
set_response_code
(
PsResponseMessage
&
response
,
int
err_code
,
...
...
@@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer
*
_server
;
const
ServerParameter
*
_config
;
};
REGISTER_REGISTERER
(
PsBaseService
);
REGISTER_
PSCORE_
REGISTERER
(
PsBaseService
);
class
PSServerFactory
{
public:
...
...
paddle/fluid/distributed/service/service.h
浏览文件 @
a97ca56a
...
...
@@ -28,6 +28,10 @@ limitations under the License. */
namespace
paddle
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
using
paddle
::
distributed
::
PsService
;
class
PSCore
{
public:
explicit
PSCore
()
{}
...
...
paddle/fluid/distributed/table/accessor.h
浏览文件 @
a97ca56a
...
...
@@ -165,6 +165,6 @@ class ValueAccessor {
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
_data_coverter_map
;
};
REGISTER_REGISTERER
(
ValueAccessor
);
REGISTER_
PSCORE_
REGISTERER
(
ValueAccessor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/table/table.cc
浏览文件 @
a97ca56a
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h"
...
...
@@ -27,14 +28,14 @@
namespace
paddle
{
namespace
distributed
{
REGISTER_CLASS
(
Table
,
CommonDenseTable
);
REGISTER_CLASS
(
Table
,
CommonSparseTable
);
REGISTER_CLASS
(
Table
,
SparseGeoTable
);
REGISTER_CLASS
(
Table
,
BarrierTable
);
REGISTER_CLASS
(
Table
,
TensorTable
);
REGISTER_CLASS
(
Table
,
DenseTensorTable
);
REGISTER_CLASS
(
Table
,
GlobalStepTable
);
REGISTER_CLASS
(
ValueAccessor
,
CommMergeAccessor
);
REGISTER_
PSCORE_
CLASS
(
Table
,
CommonDenseTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
CommonSparseTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
SparseGeoTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
BarrierTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
TensorTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
DenseTensorTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
GlobalStepTable
);
REGISTER_
PSCORE_
CLASS
(
ValueAccessor
,
CommMergeAccessor
);
int32_t
TableManager
::
initialize
()
{
static
bool
initialized
=
false
;
...
...
@@ -61,8 +62,8 @@ int32_t Table::initialize_accessor() {
<<
_config
.
table_id
();
return
-
1
;
}
auto
*
accessor
=
CREATE_CLASS
(
ValueAccessor
,
auto
*
accessor
=
CREATE_PSCORE_CLASS
(
ValueAccessor
,
_config
.
accessor
().
accessor_class
())
if
(
accessor
==
NULL
)
{
LOG
(
ERROR
)
<<
"accessor is unregisteg, table_id:"
<<
_config
.
table_id
()
<<
", accessor_name:"
<<
_config
.
accessor
().
accessor_class
();
...
...
paddle/fluid/distributed/table/table.h
浏览文件 @
a97ca56a
...
...
@@ -127,7 +127,7 @@ class Table {
float
*
_global_lr
=
nullptr
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
};
REGISTER_REGISTERER
(
Table
);
REGISTER_
PSCORE_
REGISTERER
(
Table
);
class
TableManager
{
public:
...
...
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
浏览文件 @
a97ca56a
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <string>
#include <thread> // NOLINT
...
...
@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
break
;
}
...
...
paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
浏览文件 @
a97ca56a
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -225,7 +226,8 @@ void RunBrpcPushSparse() {
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_SPARSE_PARAM
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_SPARSE_PARAM
)
!=
0
)
{
ret
=
-
1
;
break
;
}
...
...
@@ -252,7 +254,8 @@ void RunBrpcPushSparse() {
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_SPARSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_SPARSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
break
;
}
...
...
paddle/fluid/distributed/test/brpc_utils_test.cc
浏览文件 @
a97ca56a
...
...
@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) {
auto
&
ctx
=
*
pool
.
Get
(
place
);
CreateVarsOnScope
(
&
scope
,
&
place
,
ctx
);
::
paddle
::
MultiVariableMessage
multi_msg
;
::
paddle
::
distributed
::
MultiVariableMessage
multi_msg
;
std
::
string
message_name
(
"se_de_test"
);
std
::
vector
<
std
::
string
>
send_var_name
=
{
"x1"
,
"x2"
,
"x3"
};
std
::
vector
<
std
::
string
>
recv_var_name
=
{};
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
a97ca56a
...
...
@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method
sendrecvop_rpc communicator
collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
lod_rank_table feed_fetch_method collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto pslib_brpc
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
(
)
else
if
(
WITH_PSCORE
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
...
...
@@ -230,6 +230,16 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
multi_trainer.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
hogwild_worker.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor
)
endif
()
elseif
(
WITH_PSLIB
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
...
...
@@ -241,7 +251,6 @@ elseif(WITH_PSLIB)
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc
)
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
a97ca56a
...
...
@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
reduce_op_handle.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
a97ca56a
...
...
@@ -16,7 +16,7 @@
#include "paddle/fluid/framework/variable_helper.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
"results to be fetched!"
));
// init once
if
(
run_futures_
.
size
()
==
0
&&
places_
.
size
()
>
1
)
{
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
strategy_
.
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
places_
.
size
());
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
a97ca56a
...
...
@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
void
ThreadedSSAGraphExecutor
::
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
strategy_
.
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
a97ca56a
...
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
writer_
.
Flush
();
}
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
...
...
@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars
();
thread_scope_
->
DropKids
();
}
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
a97ca56a
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
workers_
.
resize
(
thread_num_
);
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
trainer_desc
.
thread_barrier
())
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
thread_num_
);
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
a97ca56a
...
...
@@ -77,7 +77,7 @@ set(SHARED_INFERENCE_SRCS
${
mkldnn_quantizer_src_file
}
)
# Create shared inference library defaultly
if
(
NOT WITH_
DISTRIBUT
E
)
if
(
NOT WITH_
PSCOR
E
)
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
analysis_predictor
)
else
()
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
a97ca56a
...
...
@@ -22,10 +22,13 @@ add_subdirectory(jit)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
pscore
)
add_subdirectory
(
collective
)
endif
()
if
(
WITH_PSCORE
)
add_subdirectory
(
pscore
)
endif
()
add_subdirectory
(
amp
)
add_subdirectory
(
reader
)
...
...
paddle/fluid/operators/pscore/CMakeLists.txt
浏览文件 @
a97ca56a
if
(
WITH_PSLIB
)
return
()
endif
()
include
(
operators
)
set
(
DISTRIBUTE_DEPS
""
)
...
...
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
浏览文件 @
a97ca56a
...
...
@@ -46,8 +46,8 @@ class DeviceContext;
namespace
paddle
{
namespace
operators
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
template
<
class
TKey
,
class
TValue
>
class
DoubleFindMap
:
public
std
::
unordered_map
<
TKey
,
TValue
>
{
...
...
paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
浏览文件 @
a97ca56a
...
...
@@ -36,8 +36,8 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
distributed
;
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
DECLARE_double
(
eager_delete_tensor_gb
);
USE_OP
(
scale
);
...
...
paddle/fluid/operators/pscore/heter_server_test.cc
浏览文件 @
a97ca56a
...
...
@@ -32,8 +32,8 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
distributed
;
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
USE_OP
(
scale
);
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
a97ca56a
...
...
@@ -49,7 +49,7 @@ if (WITH_CRYPTO)
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
crypto.cc
)
endif
(
WITH_CRYPTO
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
set_source_files_properties
(
fleet_py.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
list
(
APPEND PYBIND_DEPS fleet communicator
)
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
a97ca56a
...
...
@@ -106,7 +106,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/crypto.h"
#endif
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/pybind/fleet_py.h"
#endif
...
...
@@ -2833,7 +2833,7 @@ All parameter, weight, gradient are variables in Paddle.
BindCrypto
(
&
m
);
#endif
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
BindDistFleetWrapper
(
&
m
);
BindPSHost
(
&
m
);
BindCommunicatorContext
(
&
m
);
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
a97ca56a
...
...
@@ -236,6 +236,7 @@ function cmake_base() {
-DPY_VERSION=
${
PY_VERSION
:-
2
.7
}
-DCMAKE_INSTALL_PREFIX=
${
INSTALL_PREFIX
:-
/paddle/build
}
-DWITH_GRPC=
${
grpc_flag
}
-DWITH_PSCORE=
${
distibuted_flag
}
-DWITH_GLOO=
${
gloo_flag
}
-DWITH_LITE=
${
WITH_LITE
:-
OFF
}
-DWITH_XPU=
${
WITH_XPU
:-
OFF
}
...
...
@@ -269,6 +270,7 @@ EOF
-DPY_VERSION
=
${
PY_VERSION
:-
2
.7
}
\
-DCMAKE_INSTALL_PREFIX
=
${
INSTALL_PREFIX
:-
/paddle/build
}
\
-DWITH_GRPC
=
${
grpc_flag
}
\
-DWITH_PSCORE
=
${
distibuted_flag
}
\
-DWITH_GLOO
=
${
gloo_flag
}
\
-DLITE_GIT_TAG
=
develop
\
-DWITH_XPU
=
${
WITH_XPU
:-
OFF
}
\
...
...
paddle/testing/paddle_gtest_main.cc
浏览文件 @
a97ca56a
...
...
@@ -59,7 +59,8 @@ int main(int argc, char** argv) {
std
::
vector
<
std
::
string
>
envs
;
std
::
vector
<
std
::
string
>
undefok
;
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC)
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) && \
!defined(PADDLE_WITH_PSLIB)
std
::
string
str_max_body_size
;
if
(
google
::
GetCommandLineOption
(
"max_body_size"
,
&
str_max_body_size
))
{
setenv
(
"FLAGS_max_body_size"
,
"2147483647"
,
1
);
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
a97ca56a
...
...
@@ -268,7 +268,7 @@ class Service:
def
__init__
(
self
):
self
.
server_class
=
"BrpcPsServer"
self
.
client_class
=
"BrpcPsClient"
self
.
service_class
=
"PsService"
self
.
service_class
=
"
Brpc
PsService"
self
.
start_server_port
=
0
self
.
server_thread_num
=
12
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录