Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
25f80fd3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
25f80fd3
编写于
1月 12, 2021
作者:
T
tangwei12
提交者:
GitHub
1月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix/distributed proto (#29981)
* rename sendrecv.proto to namespace paddle.distributed * split ps with distributed
上级
d479ae17
变更
44
隐藏空白更改
内联
并排
Showing
44 changed file
with
249 addition
and
198 deletion
+249
-198
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/configure.cmake
cmake/configure.cmake
+5
-0
cmake/third_party.cmake
cmake/third_party.cmake
+1
-1
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-4
paddle/fluid/distributed/common/registerer.h
paddle/fluid/distributed/common/registerer.h
+7
-7
paddle/fluid/distributed/ps.proto
paddle/fluid/distributed/ps.proto
+1
-1
paddle/fluid/distributed/service/brpc_ps_client.cc
paddle/fluid/distributed/service/brpc_ps_client.cc
+3
-3
paddle/fluid/distributed/service/brpc_ps_client.h
paddle/fluid/distributed/service/brpc_ps_client.h
+2
-2
paddle/fluid/distributed/service/brpc_ps_server.cc
paddle/fluid/distributed/service/brpc_ps_server.cc
+106
-93
paddle/fluid/distributed/service/brpc_ps_server.h
paddle/fluid/distributed/service/brpc_ps_server.h
+5
-5
paddle/fluid/distributed/service/brpc_utils.cc
paddle/fluid/distributed/service/brpc_utils.cc
+6
-6
paddle/fluid/distributed/service/brpc_utils.h
paddle/fluid/distributed/service/brpc_utils.h
+2
-2
paddle/fluid/distributed/service/heter_client.cc
paddle/fluid/distributed/service/heter_client.cc
+2
-2
paddle/fluid/distributed/service/heter_client.h
paddle/fluid/distributed/service/heter_client.h
+2
-2
paddle/fluid/distributed/service/heter_server.h
paddle/fluid/distributed/service/heter_server.h
+5
-5
paddle/fluid/distributed/service/ps_client.cc
paddle/fluid/distributed/service/ps_client.cc
+4
-5
paddle/fluid/distributed/service/ps_client.h
paddle/fluid/distributed/service/ps_client.h
+4
-1
paddle/fluid/distributed/service/sendrecv.proto
paddle/fluid/distributed/service/sendrecv.proto
+1
-1
paddle/fluid/distributed/service/server.cc
paddle/fluid/distributed/service/server.cc
+6
-4
paddle/fluid/distributed/service/server.h
paddle/fluid/distributed/service/server.h
+6
-4
paddle/fluid/distributed/service/service.h
paddle/fluid/distributed/service/service.h
+4
-0
paddle/fluid/distributed/table/accessor.h
paddle/fluid/distributed/table/accessor.h
+1
-1
paddle/fluid/distributed/table/table.cc
paddle/fluid/distributed/table/table.cc
+12
-11
paddle/fluid/distributed/table/table.h
paddle/fluid/distributed/table/table.h
+1
-1
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
+5
-3
paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
...le/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
+7
-4
paddle/fluid/distributed/test/brpc_utils_test.cc
paddle/fluid/distributed/test/brpc_utils_test.cc
+2
-2
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+12
-3
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+3
-3
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+2
-2
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+3
-3
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+4
-1
paddle/fluid/operators/pscore/CMakeLists.txt
paddle/fluid/operators/pscore/CMakeLists.txt
+4
-0
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
+2
-2
paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
...le/fluid/operators/pscore/heter_listen_and_server_test.cc
+2
-2
paddle/fluid/operators/pscore/heter_server_test.cc
paddle/fluid/operators/pscore/heter_server_test.cc
+2
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-2
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+2
-0
paddle/testing/paddle_gtest_main.cc
paddle/testing/paddle_gtest_main.cc
+2
-1
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
25f80fd3
...
...
@@ -160,6 +160,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option
(
WITH_XBYAK
"Compile with xbyak support"
ON
)
option
(
WITH_CONTRIB
"Compile the third-party contributation"
OFF
)
option
(
WITH_GRPC
"Use grpc as the default rpc framework"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_PSCORE
"Compile with parameter server support"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_INFERENCE_API_TEST
"Test fluid inference C++ high-level api interface"
OFF
)
option
(
PY_VERSION
"Compile PaddlePaddle with python3 support"
${
PY_VERSION
}
)
option
(
WITH_DGC
"Use DGC(Deep Gradient Compression) or not"
${
WITH_DISTRIBUTE
}
)
...
...
cmake/configure.cmake
浏览文件 @
25f80fd3
...
...
@@ -160,6 +160,11 @@ if(WITH_DISTRIBUTE)
add_definitions
(
-DPADDLE_WITH_DISTRIBUTE
)
endif
()
if
(
WITH_PSCORE
)
add_definitions
(
-DPADDLE_WITH_PSCORE
)
endif
()
if
(
WITH_GRPC
)
add_definitions
(
-DPADDLE_WITH_GRPC
)
endif
(
WITH_GRPC
)
...
...
cmake/third_party.cmake
浏览文件 @
25f80fd3
...
...
@@ -274,7 +274,7 @@ if(WITH_BOX_PS)
list
(
APPEND third_party_deps extern_box_ps
)
endif
(
WITH_BOX_PS
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
include
(
external/snappy
)
list
(
APPEND third_party_deps extern_snappy
)
...
...
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
25f80fd3
if
(
WITH_PSLIB
)
return
()
endif
()
if
(
NOT WITH_DISTRIBUTE
)
if
(
NOT WITH_PSCORE
)
return
()
endif
()
...
...
paddle/fluid/distributed/common/registerer.h
浏览文件 @
25f80fd3
...
...
@@ -69,24 +69,24 @@ class ObjectFactory {
};
typedef
std
::
map
<
std
::
string
,
ObjectFactory
*>
FactoryMap
;
typedef
std
::
map
<
std
::
string
,
FactoryMap
>
Bas
eClassMap
;
typedef
std
::
map
<
std
::
string
,
FactoryMap
>
PsCor
eClassMap
;
#ifdef __cplusplus
extern
"C"
{
#endif
inline
Bas
eClassMap
&
global_factory_map
()
{
static
BaseClassMap
*
base_class
=
new
Bas
eClassMap
();
inline
PsCor
eClassMap
&
global_factory_map
()
{
static
PsCoreClassMap
*
base_class
=
new
PsCor
eClassMap
();
return
*
base_class
;
}
#ifdef __cplusplus
}
#endif
inline
Bas
eClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
}
inline
PsCor
eClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
}
// typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap;
#define REGISTER_
REGISTERER(base_class)
\
#define REGISTER_
PSCORE_REGISTERER(base_class)
\
class base_class##Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
...
...
@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
};
#define REGISTER_
CLASS(clazz, name)
\
#define REGISTER_
PSCORE_CLASS(clazz, name)
\
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { return Any(new name()); } \
...
...
@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_
PSCORE_
CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
}
// namespace distributed
...
...
paddle/fluid/distributed/ps.proto
浏览文件 @
25f80fd3
...
...
@@ -86,7 +86,7 @@ message SparseTableParameter {
message
ServerServiceParameter
{
optional
string
server_class
=
1
[
default
=
"BrpcPsServer"
];
optional
string
client_class
=
2
[
default
=
"BrpcPsClient"
];
optional
string
service_class
=
3
[
default
=
"PsService"
];
optional
string
service_class
=
3
[
default
=
"
Brpc
PsService"
];
optional
uint32
start_server_port
=
4
[
default
=
0
];
// will find a avaliable port from it
optional
uint32
server_thread_num
=
5
[
default
=
12
];
...
...
paddle/fluid/distributed/service/brpc_ps_client.cc
浏览文件 @
25f80fd3
...
...
@@ -17,8 +17,8 @@
#include <sstream>
#include <string>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
...
...
@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void
DownpourPsClientService
::
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
int
ret
=
_client
->
handle_client2client_msg
(
request
->
cmd_id
(),
request
->
client_id
(),
request
->
data
());
...
...
paddle/fluid/distributed/service/brpc_ps_client.h
浏览文件 @
25f80fd3
...
...
@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return
0
;
}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
;
protected:
...
...
paddle/fluid/distributed/service/brpc_ps_server.cc
浏览文件 @
25f80fd3
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "Eigen/Dense"
#include "butil/endpoint.h"
...
...
@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() {
LOG
(
ERROR
)
<<
"miss service_class in ServerServiceParameter"
;
return
-
1
;
}
auto
*
service
=
CREATE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
auto
*
service
=
CREATE_PSCORE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
if
(
service
==
NULL
)
{
LOG
(
ERROR
)
<<
"service is unregistered, service_name:"
<<
service_config
.
service_class
();
...
...
@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
int32_t
BrpcPsServer
::
port
()
{
return
_server
.
listen_address
().
port
;
}
int32_t
PsService
::
initialize
()
{
int32_t
Brpc
PsService
::
initialize
()
{
_is_initialize_shard_info
=
false
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
PsService
::
stop_server
;
_service_handler_map
[
PS_PULL_DENSE_TABLE
]
=
&
PsService
::
pull_dense
;
_service_handler_map
[
PS_PUSH_DENSE_TABLE
]
=
&
PsService
::
push_dense
;
_service_handler_map
[
PS_PULL_SPARSE_TABLE
]
=
&
PsService
::
pull_sparse
;
_service_handler_map
[
PS_PUSH_SPARSE_TABLE
]
=
&
PsService
::
push_sparse
;
_service_handler_map
[
PS_SAVE_ONE_TABLE
]
=
&
PsService
::
save_one_table
;
_service_handler_map
[
PS_SAVE_ALL_TABLE
]
=
&
PsService
::
save_all_table
;
_service_handler_map
[
PS_SHRINK_TABLE
]
=
&
PsService
::
shrink_table
;
_service_handler_map
[
PS_LOAD_ONE_TABLE
]
=
&
PsService
::
load_one_table
;
_service_handler_map
[
PS_LOAD_ALL_TABLE
]
=
&
PsService
::
load_all_table
;
_service_handler_map
[
PS_CLEAR_ONE_TABLE
]
=
&
PsService
::
clear_one_table
;
_service_handler_map
[
PS_CLEAR_ALL_TABLE
]
=
&
PsService
::
clear_all_table
;
_service_handler_map
[
PS_PUSH_DENSE_PARAM
]
=
&
PsService
::
push_dense_param
;
_service_handler_map
[
PS_PRINT_TABLE_STAT
]
=
&
PsService
::
print_table_stat
;
_service_handler_map
[
PS_PULL_GEO_PARAM
]
=
&
PsService
::
pull_geo_param
;
_service_handler_map
[
PS_PUSH_SPARSE_PARAM
]
=
&
PsService
::
push_sparse_param
;
_service_handler_map
[
PS_BARRIER
]
=
&
PsService
::
barrier
;
_service_handler_map
[
PS_
START_PROFILER
]
=
&
PsService
::
start_profil
er
;
_service_handler_map
[
PS_ST
OP_PROFILER
]
=
&
PsService
::
stop
_profiler
;
_service_handler_map
[
PS_
PUSH_GLOBAL_STEP
]
=
&
PsService
::
push_global_step
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
Brpc
PsService
::
stop_server
;
_service_handler_map
[
PS_PULL_DENSE_TABLE
]
=
&
Brpc
PsService
::
pull_dense
;
_service_handler_map
[
PS_PUSH_DENSE_TABLE
]
=
&
Brpc
PsService
::
push_dense
;
_service_handler_map
[
PS_PULL_SPARSE_TABLE
]
=
&
Brpc
PsService
::
pull_sparse
;
_service_handler_map
[
PS_PUSH_SPARSE_TABLE
]
=
&
Brpc
PsService
::
push_sparse
;
_service_handler_map
[
PS_SAVE_ONE_TABLE
]
=
&
Brpc
PsService
::
save_one_table
;
_service_handler_map
[
PS_SAVE_ALL_TABLE
]
=
&
Brpc
PsService
::
save_all_table
;
_service_handler_map
[
PS_SHRINK_TABLE
]
=
&
Brpc
PsService
::
shrink_table
;
_service_handler_map
[
PS_LOAD_ONE_TABLE
]
=
&
Brpc
PsService
::
load_one_table
;
_service_handler_map
[
PS_LOAD_ALL_TABLE
]
=
&
Brpc
PsService
::
load_all_table
;
_service_handler_map
[
PS_CLEAR_ONE_TABLE
]
=
&
Brpc
PsService
::
clear_one_table
;
_service_handler_map
[
PS_CLEAR_ALL_TABLE
]
=
&
Brpc
PsService
::
clear_all_table
;
_service_handler_map
[
PS_PUSH_DENSE_PARAM
]
=
&
Brpc
PsService
::
push_dense_param
;
_service_handler_map
[
PS_PRINT_TABLE_STAT
]
=
&
Brpc
PsService
::
print_table_stat
;
_service_handler_map
[
PS_PULL_GEO_PARAM
]
=
&
Brpc
PsService
::
pull_geo_param
;
_service_handler_map
[
PS_PUSH_SPARSE_PARAM
]
=
&
BrpcPsService
::
push_sparse_param
;
_service_handler_map
[
PS_
BARRIER
]
=
&
BrpcPsService
::
barri
er
;
_service_handler_map
[
PS_ST
ART_PROFILER
]
=
&
BrpcPsService
::
start
_profiler
;
_service_handler_map
[
PS_
STOP_PROFILER
]
=
&
BrpcPsService
::
stop_profiler
;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info
();
...
...
@@ -116,7 +118,7 @@ int32_t PsService::initialize() {
return -1; \
}
int32_t
PsService
::
initialize_shard_info
()
{
int32_t
Brpc
PsService
::
initialize_shard_info
()
{
if
(
!
_is_initialize_shard_info
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
_initialize_shard_mutex
);
if
(
_is_initialize_shard_info
)
{
...
...
@@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() {
return
0
;
}
void
PsService
::
service
(
google
::
protobuf
::
RpcController
*
cntl_base
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
void
Brpc
PsService
::
service
(
google
::
protobuf
::
RpcController
*
cntl_base
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
std
::
string
log_label
(
"ReceiveCmd-"
);
if
(
!
request
->
has_table_id
())
{
...
...
@@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t
PsService
::
pull_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
pull_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_dense"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
1
)
{
...
...
@@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
push_dense_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
push_dense_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_dense_param"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
thread_local
std
::
string
push_buffer
;
...
...
@@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table,
return
0
;
}
int32_t
PsService
::
push_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
push_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_dense"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
auto
req_buffer_size
=
request
.
data
().
size
();
...
...
@@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
barrier
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
barrier
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
1
)
{
...
...
@@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
push_sparse_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
push_sparse_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_sparse_param"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
auto
&
push_data
=
request
.
data
();
...
...
@@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table,
return
0
;
}
int32_t
PsService
::
pull_geo_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
pull_geo_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_geo_param"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
thread_local
std
::
string
push_sparse_request_buffer
;
...
...
@@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
pull_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
pull_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_sparse"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
thread_local
std
::
string
push_sparse_request_buffer
;
...
...
@@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
push_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
push_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_sparse"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
auto
&
push_data
=
request
.
data
();
...
...
@@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
print_table_stat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
print_table_stat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
std
::
pair
<
int64_t
,
int64_t
>
ret
=
table
->
print_table_stat
();
paddle
::
framework
::
BinaryArchive
ar
;
...
...
@@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table,
return
0
;
}
int32_t
PsService
::
load_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
load_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
2
)
{
set_response_code
(
...
...
@@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
load_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
load_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
for
(
auto
&
itr
:
table_map
)
{
if
(
load_one_table
(
itr
.
second
.
get
(),
request
,
response
,
cntl
)
!=
0
)
{
...
...
@@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
save_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
save_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
2
)
{
set_response_code
(
...
...
@@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
return
feasign_size
;
}
int32_t
PsService
::
save_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
save_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
int32_t
all_feasign_size
=
0
;
int32_t
feasign_size
=
0
;
...
...
@@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
shrink_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
shrink_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
table
->
flush
();
if
(
table
->
shrink
()
!=
0
)
{
...
...
@@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
clear_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
clear_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
table
->
flush
();
table
->
clear
();
return
0
;
}
int32_t
PsService
::
clear_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
clear_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
for
(
auto
&
itr
:
table_map
)
{
if
(
clear_one_table
(
itr
.
second
.
get
(),
request
,
response
,
cntl
)
!=
0
)
{
...
...
@@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table,
return
0
;
}
int32_t
PsService
::
stop_server
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
stop_server
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
*
p_server
=
_server
;
std
::
thread
t_stop
([
p_server
]()
{
p_server
->
stop
();
...
...
@@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
return
0
;
}
int32_t
PsService
::
stop_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
stop_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
string
::
Sprintf
(
"server_%s_profile"
,
_rank
));
return
0
;
}
int32_t
PsService
::
start_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
BrpcPsService
::
start_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
return
0
;
}
int32_t
PsService
::
push_global_step
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int32_t
Brpc
PsService
::
push_global_step
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
);
auto
req_buffer_size
=
request
.
data
().
size
();
if
(
req_buffer_size
<
1
)
{
...
...
paddle/fluid/distributed/service/brpc_ps_server.h
浏览文件 @
25f80fd3
...
...
@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_pserver_channels
;
};
class
PsService
;
class
Brpc
PsService
;
typedef
int32_t
(
PsService
::*
serviceHandlerFunc
)(
typedef
int32_t
(
Brpc
PsService
::*
serviceHandlerFunc
)(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
class
PsService
:
public
PsBaseService
{
class
Brpc
PsService
:
public
PsBaseService
{
public:
virtual
int32_t
initialize
()
override
;
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
;
private:
...
...
paddle/fluid/distributed/service/brpc_utils.cc
浏览文件 @
25f80fd3
...
...
@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
var_msg
,
butil
::
IOBuf
*
iobuf
)
{
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
var_msg
->
set_type
(
::
paddle
::
LOD_TENSOR
);
var_msg
->
set_type
(
::
paddle
::
distributed
::
LOD_TENSOR
);
const
framework
::
LoD
lod
=
tensor
->
lod
();
if
(
lod
.
size
()
>
0
)
{
var_msg
->
set_lod_level
(
lod
.
size
());
...
...
@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
rows
=
slr
->
mutable_rows
();
var_msg
->
set_type
(
::
paddle
::
SELECTED_ROWS
);
var_msg
->
set_type
(
::
paddle
::
distributed
::
SELECTED_ROWS
);
var_msg
->
set_slr_height
(
slr
->
height
());
auto
*
var_data
=
var_msg
->
mutable_data
();
...
...
@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++
recv_var_index
)
{
const
auto
&
msg
=
multi_msg
.
var_messages
(
recv_var_index
);
auto
*
var
=
scope
->
Var
(
msg
.
varname
());
if
(
msg
.
type
()
==
::
paddle
::
LOD_TENSOR
)
{
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
LOD_TENSOR
)
{
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
else
if
(
msg
.
type
()
==
::
paddle
::
SELECTED_ROWS
)
{
}
else
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
SELECTED_ROWS
)
{
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
}
...
...
@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE
(
var
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
"Not find variable %s in scope."
,
msg
.
varname
()));
if
(
msg
.
type
()
==
::
paddle
::
LOD_TENSOR
)
{
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
LOD_TENSOR
)
{
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
else
if
(
msg
.
type
()
==
::
paddle
::
SELECTED_ROWS
)
{
}
else
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
SELECTED_ROWS
)
{
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
}
...
...
paddle/fluid/distributed/service/brpc_utils.h
浏览文件 @
25f80fd3
...
...
@@ -44,8 +44,8 @@ class DeviceContext;
namespace
paddle
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
void
SerializeToMultiVarMsgAndIOBuf
(
const
std
::
string
&
message_name
,
...
...
paddle/fluid/distributed/service/heter_client.cc
浏览文件 @
25f80fd3
...
...
@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl
.
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
distributed
::
MultiVarMsg
request
,
response
;
auto
&
request_io_buffer
=
cntl
.
request_attachment
();
::
paddle
::
PsService_Stub
stub
(
xpu_channels_
[
num
].
get
());
::
paddle
::
distributed
::
PsService_Stub
stub
(
xpu_channels_
[
num
].
get
());
distributed
::
SerializeToMultiVarMsgAndIOBuf
(
message_name_val
,
send_var_name_val
,
recv_var_name_val
,
*
p_ctx
,
p_scope
,
&
request
,
&
request_io_buffer
);
...
...
@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for
(
const
auto
&
param
:
params
)
{
closure
->
request
(
i
)
->
add_params
(
param
);
}
::
paddle
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
::
paddle
::
distributed
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
closure
->
cntl
(
i
)
->
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
// cmd msg don't limit timeout for save/load
rpc_stub
.
service
(
closure
->
cntl
(
i
),
closure
->
request
(
i
),
...
...
paddle/fluid/distributed/service/heter_client.h
浏览文件 @
25f80fd3
...
...
@@ -35,8 +35,8 @@ limitations under the License. */
namespace
paddle
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
HeterRpcCallbackFunc
;
...
...
paddle/fluid/distributed/service/heter_server.h
浏览文件 @
25f80fd3
...
...
@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace
paddle
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
class
HeterService
;
typedef
int32_t
(
HeterService
::*
serviceHandlerFunc
)(
...
...
@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef
std
::
function
<
int
(
const
MultiVarMsg
*
,
MultiVarMsg
*
,
brpc
::
Controller
*
)
>
HeterServiceHandler
;
class
HeterService
:
public
::
paddle
::
PsService
{
class
HeterService
:
public
::
paddle
::
distributed
::
PsService
{
public:
HeterService
()
{
_service_handler_map
[
PS_STOP_SERVER
]
=
&
HeterService
::
stop_heter_worker
;
...
...
@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual
~
HeterService
()
{}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
std
::
string
log_label
(
"ReceiveCmd-"
);
...
...
paddle/fluid/distributed/service/ps_client.cc
浏览文件 @
25f80fd3
...
...
@@ -13,9 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include "brpc/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
...
...
@@ -23,7 +21,7 @@
namespace
paddle
{
namespace
distributed
{
REGISTER_CLASS
(
PSClient
,
BrpcPsClient
);
REGISTER_
PSCORE_
CLASS
(
PSClient
,
BrpcPsClient
);
int32_t
PSClient
::
configure
(
const
PSParameter
&
config
,
...
...
@@ -43,7 +41,7 @@ int32_t PSClient::configure(
const
auto
&
work_param
=
_config
.
worker_param
().
downpour_worker_param
();
for
(
size_t
i
=
0
;
i
<
work_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
accessor
=
CREATE_CLASS
(
auto
*
accessor
=
CREATE_
PSCORE_
CLASS
(
ValueAccessor
,
work_param
.
downpour_table_param
(
i
).
accessor
().
accessor_class
());
accessor
->
configure
(
work_param
.
downpour_table_param
(
i
).
accessor
());
...
...
@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
}
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
PSClient
*
client
=
CREATE_CLASS
(
PSClient
,
service_param
.
client_class
());
PSClient
*
client
=
CREATE_PSCORE_CLASS
(
PSClient
,
service_param
.
client_class
());
if
(
client
==
NULL
)
{
LOG
(
ERROR
)
<<
"client is not registered, server_name:"
<<
service_param
.
client_class
();
...
...
paddle/fluid/distributed/service/ps_client.h
浏览文件 @
25f80fd3
...
...
@@ -28,6 +28,9 @@
namespace
paddle
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
public:
...
...
@@ -206,7 +209,7 @@ class PSClient {
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
//处理client2client消息
};
REGISTER_REGISTERER
(
PSClient
);
REGISTER_
PSCORE_
REGISTERER
(
PSClient
);
class
PSClientFactory
{
public:
...
...
paddle/fluid/distributed/service/sendrecv.proto
浏览文件 @
25f80fd3
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
syntax
=
"proto2"
;
package
paddle
;
package
paddle
.
distributed
;
option
cc_generic_services
=
true
;
option
cc_enable_arenas
=
true
;
...
...
paddle/fluid/distributed/service/server.cc
浏览文件 @
25f80fd3
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
...
...
@@ -20,8 +21,8 @@
namespace
paddle
{
namespace
distributed
{
REGISTER_CLASS
(
PSServer
,
BrpcPsServer
);
REGISTER_
CLASS
(
PsBaseService
,
PsService
);
REGISTER_
PSCORE_
CLASS
(
PSServer
,
BrpcPsServer
);
REGISTER_
PSCORE_CLASS
(
PsBaseService
,
Brpc
PsService
);
PSServer
*
PSServerFactory
::
create
(
const
PSParameter
&
ps_config
)
{
const
auto
&
config
=
ps_config
.
server_param
();
...
...
@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
}
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
PSServer
*
server
=
CREATE_CLASS
(
PSServer
,
service_param
.
server_class
());
PSServer
*
server
=
CREATE_PSCORE_CLASS
(
PSServer
,
service_param
.
server_class
());
if
(
server
==
NULL
)
{
LOG
(
ERROR
)
<<
"server is not registered, server_name:"
<<
service_param
.
server_class
();
...
...
@@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t
global_step_table
=
UINT32_MAX
;
for
(
size_t
i
=
0
;
i
<
downpour_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
table
=
CREATE_CLASS
(
auto
*
table
=
CREATE_
PSCORE_
CLASS
(
Table
,
downpour_param
.
downpour_table_param
(
i
).
table_class
());
if
(
downpour_param
.
downpour_table_param
(
i
).
table_class
()
==
...
...
paddle/fluid/distributed/service/server.h
浏览文件 @
25f80fd3
...
...
@@ -46,6 +46,8 @@ namespace paddle {
namespace
distributed
{
class
Table
;
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
class
PSServer
{
public:
...
...
@@ -107,7 +109,7 @@ class PSServer {
platform
::
Place
place_
=
platform
::
CPUPlace
();
};
REGISTER_REGISTERER
(
PSServer
);
REGISTER_
PSCORE_
REGISTERER
(
PSServer
);
typedef
std
::
function
<
void
(
void
*
)
>
PServerCallBack
;
...
...
@@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return
0
;
}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
=
0
;
virtual
void
set_response_code
(
PsResponseMessage
&
response
,
int
err_code
,
...
...
@@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer
*
_server
;
const
ServerParameter
*
_config
;
};
REGISTER_REGISTERER
(
PsBaseService
);
REGISTER_
PSCORE_
REGISTERER
(
PsBaseService
);
class
PSServerFactory
{
public:
...
...
paddle/fluid/distributed/service/service.h
浏览文件 @
25f80fd3
...
...
@@ -28,6 +28,10 @@ limitations under the License. */
namespace
paddle
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
using
paddle
::
distributed
::
PsService
;
class
PSCore
{
public:
explicit
PSCore
()
{}
...
...
paddle/fluid/distributed/table/accessor.h
浏览文件 @
25f80fd3
...
...
@@ -165,6 +165,6 @@ class ValueAccessor {
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
_data_coverter_map
;
};
REGISTER_REGISTERER
(
ValueAccessor
);
REGISTER_
PSCORE_
REGISTERER
(
ValueAccessor
);
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/table/table.cc
浏览文件 @
25f80fd3
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h"
...
...
@@ -27,14 +28,14 @@
namespace
paddle
{
namespace
distributed
{
REGISTER_CLASS
(
Table
,
CommonDenseTable
);
REGISTER_CLASS
(
Table
,
CommonSparseTable
);
REGISTER_CLASS
(
Table
,
SparseGeoTable
);
REGISTER_CLASS
(
Table
,
BarrierTable
);
REGISTER_CLASS
(
Table
,
TensorTable
);
REGISTER_CLASS
(
Table
,
DenseTensorTable
);
REGISTER_CLASS
(
Table
,
GlobalStepTable
);
REGISTER_CLASS
(
ValueAccessor
,
CommMergeAccessor
);
REGISTER_
PSCORE_
CLASS
(
Table
,
CommonDenseTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
CommonSparseTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
SparseGeoTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
BarrierTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
TensorTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
DenseTensorTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
GlobalStepTable
);
REGISTER_
PSCORE_
CLASS
(
ValueAccessor
,
CommMergeAccessor
);
int32_t
TableManager
::
initialize
()
{
static
bool
initialized
=
false
;
...
...
@@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() {
<<
_config
.
table_id
();
return
-
1
;
}
auto
*
accessor
=
CREATE_CLASS
(
ValueAccessor
,
_config
.
accessor
().
accessor_class
())
if
(
accessor
==
NULL
)
{
auto
*
accessor
=
CREATE_PSCORE_CLASS
(
ValueAccessor
,
_config
.
accessor
().
accessor_class
())
if
(
accessor
==
NULL
)
{
LOG
(
ERROR
)
<<
"accessor is unregisteg, table_id:"
<<
_config
.
table_id
()
<<
", accessor_name:"
<<
_config
.
accessor
().
accessor_class
();
return
-
1
;
...
...
paddle/fluid/distributed/table/table.h
浏览文件 @
25f80fd3
...
...
@@ -127,7 +127,7 @@ class Table {
float
*
_global_lr
=
nullptr
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
};
REGISTER_REGISTERER
(
Table
);
REGISTER_
PSCORE_
REGISTERER
(
Table
);
class
TableManager
{
public:
...
...
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
浏览文件 @
25f80fd3
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <string>
#include <thread> // NOLINT
...
...
@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
break
;
}
...
...
paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
浏览文件 @
25f80fd3
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
...
...
@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto(
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
...
...
@@ -225,7 +226,8 @@ void RunBrpcPushSparse() {
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_SPARSE_PARAM
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_SPARSE_PARAM
)
!=
0
)
{
ret
=
-
1
;
break
;
}
...
...
@@ -252,7 +254,8 @@ void RunBrpcPushSparse() {
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_SPARSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_SPARSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
break
;
}
...
...
paddle/fluid/distributed/test/brpc_utils_test.cc
浏览文件 @
25f80fd3
...
...
@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) {
auto
&
ctx
=
*
pool
.
Get
(
place
);
CreateVarsOnScope
(
&
scope
,
&
place
,
ctx
);
::
paddle
::
MultiVariableMessage
multi_msg
;
::
paddle
::
distributed
::
MultiVariableMessage
multi_msg
;
std
::
string
message_name
(
"se_de_test"
);
std
::
vector
<
std
::
string
>
send_var_name
=
{
"x1"
,
"x2"
,
"x3"
};
std
::
vector
<
std
::
string
>
recv_var_name
=
{};
...
...
@@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) {
// platform::CUDAPlace place;
// RunMultiVarMsg(place);
// }
// #endif
\ No newline at end of file
// #endif
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
25f80fd3
...
...
@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method
sendrecvop_rpc communicator
collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
lod_rank_table feed_fetch_method collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto pslib_brpc
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
(
)
else
if
(
WITH_PSCORE
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
...
...
@@ -228,6 +228,16 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
multi_trainer.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
hogwild_worker.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor
)
endif
()
elseif
(
WITH_PSLIB
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
...
...
@@ -239,7 +249,6 @@ elseif(WITH_PSLIB)
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc
)
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
25f80fd3
...
...
@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
reduce_op_handle.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
25f80fd3
...
...
@@ -16,7 +16,7 @@
#include "paddle/fluid/framework/variable_helper.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
"results to be fetched!"
));
// init once
if
(
run_futures_
.
size
()
==
0
&&
places_
.
size
()
>
1
)
{
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
strategy_
.
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
places_
.
size
());
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
25f80fd3
...
...
@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
void
ThreadedSSAGraphExecutor
::
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
strategy_
.
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
25f80fd3
...
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
writer_
.
Flush
();
}
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
...
...
@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars
();
thread_scope_
->
DropKids
();
}
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
25f80fd3
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#endif
...
...
@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
workers_
.
resize
(
thread_num_
);
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
trainer_desc
.
thread_barrier
())
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
thread_num_
);
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
25f80fd3
...
...
@@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS
${
mkldnn_quantizer_src_file
}
)
# Create shared inference library defaultly
if
(
NOT WITH_
DISTRIBUT
E
)
if
(
NOT WITH_
PSCOR
E
)
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
analysis_predictor
)
DEPS
${
fluid_modules
}
analysis_predictor
)
else
()
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
analysis_predictor fleet ps_service
)
DEPS
${
fluid_modules
}
analysis_predictor fleet ps_service
)
endif
()
get_property
(
os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES
)
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
25f80fd3
...
...
@@ -22,10 +22,13 @@ add_subdirectory(jit)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
pscore
)
add_subdirectory
(
collective
)
endif
()
if
(
WITH_PSCORE
)
add_subdirectory
(
pscore
)
endif
()
add_subdirectory
(
amp
)
add_subdirectory
(
reader
)
...
...
paddle/fluid/operators/pscore/CMakeLists.txt
浏览文件 @
25f80fd3
if
(
WITH_PSLIB
)
return
()
endif
()
include
(
operators
)
set
(
DISTRIBUTE_DEPS
""
)
...
...
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
浏览文件 @
25f80fd3
...
...
@@ -46,8 +46,8 @@ class DeviceContext;
namespace
paddle
{
namespace
operators
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
template
<
class
TKey
,
class
TValue
>
class
DoubleFindMap
:
public
std
::
unordered_map
<
TKey
,
TValue
>
{
...
...
paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
浏览文件 @
25f80fd3
...
...
@@ -36,8 +36,8 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
distributed
;
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
DECLARE_double
(
eager_delete_tensor_gb
);
USE_OP
(
scale
);
...
...
paddle/fluid/operators/pscore/heter_server_test.cc
浏览文件 @
25f80fd3
...
...
@@ -32,8 +32,8 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
distributed
;
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
USE_OP
(
scale
);
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
25f80fd3
...
...
@@ -49,7 +49,7 @@ if (WITH_CRYPTO)
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
crypto.cc
)
endif
(
WITH_CRYPTO
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
set_source_files_properties
(
fleet_py.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
list
(
APPEND PYBIND_DEPS fleet communicator
)
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
25f80fd3
...
...
@@ -107,7 +107,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/crypto.h"
#endif
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/pybind/fleet_py.h"
#endif
...
...
@@ -2841,7 +2841,7 @@ All parameter, weight, gradient are variables in Paddle.
BindCrypto
(
&
m
);
#endif
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
BindDistFleetWrapper
(
&
m
);
BindPSHost
(
&
m
);
BindCommunicatorContext
(
&
m
);
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
25f80fd3
...
...
@@ -249,6 +249,7 @@ function cmake_base() {
-DPY_VERSION=
${
PY_VERSION
:-
2
.7
}
-DCMAKE_INSTALL_PREFIX=
${
INSTALL_PREFIX
:-
/paddle/build
}
-DWITH_GRPC=
${
grpc_flag
}
-DWITH_PSCORE=
${
distibuted_flag
}
-DWITH_GLOO=
${
gloo_flag
}
-DWITH_LITE=
${
WITH_LITE
:-
OFF
}
-DWITH_XPU=
${
WITH_XPU
:-
OFF
}
...
...
@@ -284,6 +285,7 @@ EOF
-DPY_VERSION
=
${
PY_VERSION
:-
2
.7
}
\
-DCMAKE_INSTALL_PREFIX
=
${
INSTALL_PREFIX
:-
/paddle/build
}
\
-DWITH_GRPC
=
${
grpc_flag
}
\
-DWITH_PSCORE
=
${
distibuted_flag
}
\
-DWITH_GLOO
=
${
gloo_flag
}
\
-DLITE_GIT_TAG
=
develop
\
-DWITH_XPU
=
${
WITH_XPU
:-
OFF
}
\
...
...
paddle/testing/paddle_gtest_main.cc
浏览文件 @
25f80fd3
...
...
@@ -59,7 +59,8 @@ int main(int argc, char** argv) {
std
::
vector
<
std
::
string
>
envs
;
std
::
vector
<
std
::
string
>
undefok
;
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC)
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) && \
!defined(PADDLE_WITH_PSLIB)
std
::
string
str_max_body_size
;
if
(
google
::
GetCommandLineOption
(
"max_body_size"
,
&
str_max_body_size
))
{
setenv
(
"FLAGS_max_body_size"
,
"2147483647"
,
1
);
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
25f80fd3
...
...
@@ -268,7 +268,7 @@ class Service:
def
__init__
(
self
):
self
.
server_class
=
"BrpcPsServer"
self
.
client_class
=
"BrpcPsClient"
self
.
service_class
=
"PsService"
self
.
service_class
=
"
Brpc
PsService"
self
.
start_server_port
=
0
self
.
server_thread_num
=
12
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录