Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a97ca56a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a97ca56a
编写于
1月 13, 2021
作者:
T
tangwei12
提交者:
GitHub
1月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
split ps with distributed (#30337)
Change-Id: I3c788e7576688e63181e7f01562529b85a09cc59
上级
5eab1a38
变更
44
隐藏空白更改
内联
并排
Showing
44 changed file
with
251 addition
and
200 deletion
+251
-200
CMakeLists.txt
CMakeLists.txt
+1
-0
cmake/configure.cmake
cmake/configure.cmake
+5
-0
cmake/third_party.cmake
cmake/third_party.cmake
+1
-1
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-4
paddle/fluid/distributed/common/registerer.h
paddle/fluid/distributed/common/registerer.h
+7
-7
paddle/fluid/distributed/ps.proto
paddle/fluid/distributed/ps.proto
+1
-1
paddle/fluid/distributed/service/brpc_ps_client.cc
paddle/fluid/distributed/service/brpc_ps_client.cc
+3
-3
paddle/fluid/distributed/service/brpc_ps_client.h
paddle/fluid/distributed/service/brpc_ps_client.h
+2
-2
paddle/fluid/distributed/service/brpc_ps_server.cc
paddle/fluid/distributed/service/brpc_ps_server.cc
+106
-93
paddle/fluid/distributed/service/brpc_ps_server.h
paddle/fluid/distributed/service/brpc_ps_server.h
+5
-5
paddle/fluid/distributed/service/brpc_utils.cc
paddle/fluid/distributed/service/brpc_utils.cc
+6
-6
paddle/fluid/distributed/service/brpc_utils.h
paddle/fluid/distributed/service/brpc_utils.h
+2
-2
paddle/fluid/distributed/service/heter_client.cc
paddle/fluid/distributed/service/heter_client.cc
+2
-2
paddle/fluid/distributed/service/heter_client.h
paddle/fluid/distributed/service/heter_client.h
+2
-2
paddle/fluid/distributed/service/heter_server.h
paddle/fluid/distributed/service/heter_server.h
+5
-5
paddle/fluid/distributed/service/ps_client.cc
paddle/fluid/distributed/service/ps_client.cc
+4
-5
paddle/fluid/distributed/service/ps_client.h
paddle/fluid/distributed/service/ps_client.h
+4
-1
paddle/fluid/distributed/service/sendrecv.proto
paddle/fluid/distributed/service/sendrecv.proto
+1
-1
paddle/fluid/distributed/service/server.cc
paddle/fluid/distributed/service/server.cc
+6
-4
paddle/fluid/distributed/service/server.h
paddle/fluid/distributed/service/server.h
+6
-4
paddle/fluid/distributed/service/service.h
paddle/fluid/distributed/service/service.h
+4
-0
paddle/fluid/distributed/table/accessor.h
paddle/fluid/distributed/table/accessor.h
+1
-1
paddle/fluid/distributed/table/table.cc
paddle/fluid/distributed/table/table.cc
+12
-11
paddle/fluid/distributed/table/table.h
paddle/fluid/distributed/table/table.h
+1
-1
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
+5
-3
paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
...le/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
+7
-4
paddle/fluid/distributed/test/brpc_utils_test.cc
paddle/fluid/distributed/test/brpc_utils_test.cc
+2
-2
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+12
-3
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+1
-1
paddle/fluid/framework/details/async_ssa_graph_executor.cc
paddle/fluid/framework/details/async_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
...le/fluid/framework/details/threaded_ssa_graph_executor.cc
+2
-2
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+3
-3
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+2
-2
paddle/fluid/inference/CMakeLists.txt
paddle/fluid/inference/CMakeLists.txt
+3
-3
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+4
-1
paddle/fluid/operators/pscore/CMakeLists.txt
paddle/fluid/operators/pscore/CMakeLists.txt
+4
-0
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
+2
-2
paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
...le/fluid/operators/pscore/heter_listen_and_server_test.cc
+2
-2
paddle/fluid/operators/pscore/heter_server_test.cc
paddle/fluid/operators/pscore/heter_server_test.cc
+2
-2
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+1
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-2
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+4
-2
paddle/testing/paddle_gtest_main.cc
paddle/testing/paddle_gtest_main.cc
+2
-1
python/paddle/distributed/fleet/runtime/the_one_ps.py
python/paddle/distributed/fleet/runtime/the_one_ps.py
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
a97ca56a
...
@@ -136,6 +136,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
...
@@ -136,6 +136,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option
(
WITH_XBYAK
"Compile with xbyak support"
ON
)
option
(
WITH_XBYAK
"Compile with xbyak support"
ON
)
option
(
WITH_CONTRIB
"Compile the third-party contributation"
OFF
)
option
(
WITH_CONTRIB
"Compile the third-party contributation"
OFF
)
option
(
WITH_GRPC
"Use grpc as the default rpc framework"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_GRPC
"Use grpc as the default rpc framework"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_PSCORE
"Compile with parameter server support"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_INFERENCE_API_TEST
"Test fluid inference C++ high-level api interface"
OFF
)
option
(
WITH_INFERENCE_API_TEST
"Test fluid inference C++ high-level api interface"
OFF
)
option
(
PY_VERSION
"Compile PaddlePaddle with python3 support"
${
PY_VERSION
}
)
option
(
PY_VERSION
"Compile PaddlePaddle with python3 support"
${
PY_VERSION
}
)
option
(
WITH_DGC
"Use DGC(Deep Gradient Compression) or not"
${
WITH_DISTRIBUTE
}
)
option
(
WITH_DGC
"Use DGC(Deep Gradient Compression) or not"
${
WITH_DISTRIBUTE
}
)
...
...
cmake/configure.cmake
浏览文件 @
a97ca56a
...
@@ -156,6 +156,11 @@ if(WITH_DISTRIBUTE)
...
@@ -156,6 +156,11 @@ if(WITH_DISTRIBUTE)
add_definitions
(
-DPADDLE_WITH_DISTRIBUTE
)
add_definitions
(
-DPADDLE_WITH_DISTRIBUTE
)
endif
()
endif
()
if
(
WITH_PSCORE
)
add_definitions
(
-DPADDLE_WITH_PSCORE
)
endif
()
if
(
WITH_GRPC
)
if
(
WITH_GRPC
)
add_definitions
(
-DPADDLE_WITH_GRPC
)
add_definitions
(
-DPADDLE_WITH_GRPC
)
endif
(
WITH_GRPC
)
endif
(
WITH_GRPC
)
...
...
cmake/third_party.cmake
浏览文件 @
a97ca56a
...
@@ -280,7 +280,7 @@ if(WITH_BOX_PS)
...
@@ -280,7 +280,7 @@ if(WITH_BOX_PS)
list
(
APPEND third_party_deps extern_box_ps
)
list
(
APPEND third_party_deps extern_box_ps
)
endif
(
WITH_BOX_PS
)
endif
(
WITH_BOX_PS
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
include
(
external/snappy
)
include
(
external/snappy
)
list
(
APPEND third_party_deps extern_snappy
)
list
(
APPEND third_party_deps extern_snappy
)
...
...
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
a97ca56a
if
(
WITH_PSLIB
)
if
(
NOT WITH_PSCORE
)
return
()
endif
()
if
(
NOT WITH_DISTRIBUTE
)
return
()
return
()
endif
()
endif
()
...
...
paddle/fluid/distributed/common/registerer.h
浏览文件 @
a97ca56a
...
@@ -69,24 +69,24 @@ class ObjectFactory {
...
@@ -69,24 +69,24 @@ class ObjectFactory {
};
};
typedef
std
::
map
<
std
::
string
,
ObjectFactory
*>
FactoryMap
;
typedef
std
::
map
<
std
::
string
,
ObjectFactory
*>
FactoryMap
;
typedef
std
::
map
<
std
::
string
,
FactoryMap
>
Bas
eClassMap
;
typedef
std
::
map
<
std
::
string
,
FactoryMap
>
PsCor
eClassMap
;
#ifdef __cplusplus
#ifdef __cplusplus
extern
"C"
{
extern
"C"
{
#endif
#endif
inline
Bas
eClassMap
&
global_factory_map
()
{
inline
PsCor
eClassMap
&
global_factory_map
()
{
static
BaseClassMap
*
base_class
=
new
Bas
eClassMap
();
static
PsCoreClassMap
*
base_class
=
new
PsCor
eClassMap
();
return
*
base_class
;
return
*
base_class
;
}
}
#ifdef __cplusplus
#ifdef __cplusplus
}
}
#endif
#endif
inline
Bas
eClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
}
inline
PsCor
eClassMap
&
global_factory_map_cpp
()
{
return
global_factory_map
();
}
// typedef pa::Any Any;
// typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap;
// typedef ::FactoryMap FactoryMap;
#define REGISTER_
REGISTERER(base_class)
\
#define REGISTER_
PSCORE_REGISTERER(base_class)
\
class base_class##Registerer { \
class base_class##Registerer { \
public: \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
static base_class *CreateInstanceByName(const ::std::string &name) { \
...
@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
...
@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
} \
};
};
#define REGISTER_
CLASS(clazz, name)
\
#define REGISTER_
PSCORE_CLASS(clazz, name)
\
class ObjectFactory##name : public ObjectFactory { \
class ObjectFactory##name : public ObjectFactory { \
public: \
public: \
Any NewInstance() { return Any(new name()); } \
Any NewInstance() { return Any(new name()); } \
...
@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
...
@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
} \
void register_factory_##name() __attribute__((constructor));
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_
PSCORE_
CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
base_class##Registerer::CreateInstanceByName(name);
}
// namespace distributed
}
// namespace distributed
...
...
paddle/fluid/distributed/ps.proto
浏览文件 @
a97ca56a
...
@@ -86,7 +86,7 @@ message SparseTableParameter {
...
@@ -86,7 +86,7 @@ message SparseTableParameter {
message
ServerServiceParameter
{
message
ServerServiceParameter
{
optional
string
server_class
=
1
[
default
=
"BrpcPsServer"
];
optional
string
server_class
=
1
[
default
=
"BrpcPsServer"
];
optional
string
client_class
=
2
[
default
=
"BrpcPsClient"
];
optional
string
client_class
=
2
[
default
=
"BrpcPsClient"
];
optional
string
service_class
=
3
[
default
=
"PsService"
];
optional
string
service_class
=
3
[
default
=
"
Brpc
PsService"
];
optional
uint32
start_server_port
=
4
optional
uint32
start_server_port
=
4
[
default
=
0
];
// will find a avaliable port from it
[
default
=
0
];
// will find a avaliable port from it
optional
uint32
server_thread_num
=
5
[
default
=
12
];
optional
uint32
server_thread_num
=
5
[
default
=
12
];
...
...
paddle/fluid/distributed/service/brpc_ps_client.cc
浏览文件 @
a97ca56a
...
@@ -17,8 +17,8 @@
...
@@ -17,8 +17,8 @@
#include <sstream>
#include <sstream>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "Eigen/Dense"
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/archive.h"
...
@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
...
@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void
DownpourPsClientService
::
service
(
void
DownpourPsClientService
::
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
::
paddle
::
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
::
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
brpc
::
ClosureGuard
done_guard
(
done
);
int
ret
=
_client
->
handle_client2client_msg
(
int
ret
=
_client
->
handle_client2client_msg
(
request
->
cmd_id
(),
request
->
client_id
(),
request
->
data
());
request
->
cmd_id
(),
request
->
client_id
(),
request
->
data
());
...
...
paddle/fluid/distributed/service/brpc_ps_client.h
浏览文件 @
a97ca56a
...
@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
...
@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return
0
;
return
0
;
}
}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
const
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
;
::
google
::
protobuf
::
Closure
*
done
)
override
;
protected:
protected:
...
...
paddle/fluid/distributed/service/brpc_ps_server.cc
浏览文件 @
a97ca56a
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include <thread> // NOLINT
#include "Eigen/Dense"
#include "Eigen/Dense"
#include "butil/endpoint.h"
#include "butil/endpoint.h"
...
@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() {
...
@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() {
LOG
(
ERROR
)
<<
"miss service_class in ServerServiceParameter"
;
LOG
(
ERROR
)
<<
"miss service_class in ServerServiceParameter"
;
return
-
1
;
return
-
1
;
}
}
auto
*
service
=
CREATE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
auto
*
service
=
CREATE_PSCORE_CLASS
(
PsBaseService
,
service_config
.
service_class
());
if
(
service
==
NULL
)
{
if
(
service
==
NULL
)
{
LOG
(
ERROR
)
<<
"service is unregistered, service_name:"
LOG
(
ERROR
)
<<
"service is unregistered, service_name:"
<<
service_config
.
service_class
();
<<
service_config
.
service_class
();
...
@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
...
@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
int32_t
BrpcPsServer
::
port
()
{
return
_server
.
listen_address
().
port
;
}
int32_t
BrpcPsServer
::
port
()
{
return
_server
.
listen_address
().
port
;
}
int32_t
PsService
::
initialize
()
{
int32_t
Brpc
PsService
::
initialize
()
{
_is_initialize_shard_info
=
false
;
_is_initialize_shard_info
=
false
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
PsService
::
stop_server
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
Brpc
PsService
::
stop_server
;
_service_handler_map
[
PS_PULL_DENSE_TABLE
]
=
&
PsService
::
pull_dense
;
_service_handler_map
[
PS_PULL_DENSE_TABLE
]
=
&
Brpc
PsService
::
pull_dense
;
_service_handler_map
[
PS_PUSH_DENSE_TABLE
]
=
&
PsService
::
push_dense
;
_service_handler_map
[
PS_PUSH_DENSE_TABLE
]
=
&
Brpc
PsService
::
push_dense
;
_service_handler_map
[
PS_PULL_SPARSE_TABLE
]
=
&
PsService
::
pull_sparse
;
_service_handler_map
[
PS_PULL_SPARSE_TABLE
]
=
&
Brpc
PsService
::
pull_sparse
;
_service_handler_map
[
PS_PUSH_SPARSE_TABLE
]
=
&
PsService
::
push_sparse
;
_service_handler_map
[
PS_PUSH_SPARSE_TABLE
]
=
&
Brpc
PsService
::
push_sparse
;
_service_handler_map
[
PS_SAVE_ONE_TABLE
]
=
&
PsService
::
save_one_table
;
_service_handler_map
[
PS_SAVE_ONE_TABLE
]
=
&
Brpc
PsService
::
save_one_table
;
_service_handler_map
[
PS_SAVE_ALL_TABLE
]
=
&
PsService
::
save_all_table
;
_service_handler_map
[
PS_SAVE_ALL_TABLE
]
=
&
Brpc
PsService
::
save_all_table
;
_service_handler_map
[
PS_SHRINK_TABLE
]
=
&
PsService
::
shrink_table
;
_service_handler_map
[
PS_SHRINK_TABLE
]
=
&
Brpc
PsService
::
shrink_table
;
_service_handler_map
[
PS_LOAD_ONE_TABLE
]
=
&
PsService
::
load_one_table
;
_service_handler_map
[
PS_LOAD_ONE_TABLE
]
=
&
Brpc
PsService
::
load_one_table
;
_service_handler_map
[
PS_LOAD_ALL_TABLE
]
=
&
PsService
::
load_all_table
;
_service_handler_map
[
PS_LOAD_ALL_TABLE
]
=
&
Brpc
PsService
::
load_all_table
;
_service_handler_map
[
PS_CLEAR_ONE_TABLE
]
=
&
PsService
::
clear_one_table
;
_service_handler_map
[
PS_CLEAR_ONE_TABLE
]
=
&
Brpc
PsService
::
clear_one_table
;
_service_handler_map
[
PS_CLEAR_ALL_TABLE
]
=
&
PsService
::
clear_all_table
;
_service_handler_map
[
PS_CLEAR_ALL_TABLE
]
=
&
Brpc
PsService
::
clear_all_table
;
_service_handler_map
[
PS_PUSH_DENSE_PARAM
]
=
&
PsService
::
push_dense_param
;
_service_handler_map
[
PS_PUSH_DENSE_PARAM
]
=
&
Brpc
PsService
::
push_dense_param
;
_service_handler_map
[
PS_PRINT_TABLE_STAT
]
=
&
PsService
::
print_table_stat
;
_service_handler_map
[
PS_PRINT_TABLE_STAT
]
=
&
Brpc
PsService
::
print_table_stat
;
_service_handler_map
[
PS_PULL_GEO_PARAM
]
=
&
PsService
::
pull_geo_param
;
_service_handler_map
[
PS_PULL_GEO_PARAM
]
=
&
Brpc
PsService
::
pull_geo_param
;
_service_handler_map
[
PS_PUSH_SPARSE_PARAM
]
=
&
PsService
::
push_sparse_param
;
_service_handler_map
[
PS_PUSH_SPARSE_PARAM
]
=
_service_handler_map
[
PS_BARRIER
]
=
&
PsService
::
barrier
;
&
BrpcPsService
::
push_sparse_param
;
_service_handler_map
[
PS_
START_PROFILER
]
=
&
PsService
::
start_profil
er
;
_service_handler_map
[
PS_
BARRIER
]
=
&
BrpcPsService
::
barri
er
;
_service_handler_map
[
PS_ST
OP_PROFILER
]
=
&
PsService
::
stop
_profiler
;
_service_handler_map
[
PS_ST
ART_PROFILER
]
=
&
BrpcPsService
::
start
_profiler
;
_service_handler_map
[
PS_
PUSH_GLOBAL_STEP
]
=
&
PsService
::
push_global_step
;
_service_handler_map
[
PS_
STOP_PROFILER
]
=
&
BrpcPsService
::
stop_profiler
;
// shard初始化,server启动后才可从env获取到server_list的shard信息
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info
();
initialize_shard_info
();
...
@@ -116,7 +118,7 @@ int32_t PsService::initialize() {
...
@@ -116,7 +118,7 @@ int32_t PsService::initialize() {
return -1; \
return -1; \
}
}
int32_t
PsService
::
initialize_shard_info
()
{
int32_t
Brpc
PsService
::
initialize_shard_info
()
{
if
(
!
_is_initialize_shard_info
)
{
if
(
!
_is_initialize_shard_info
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
_initialize_shard_mutex
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
_initialize_shard_mutex
);
if
(
_is_initialize_shard_info
)
{
if
(
_is_initialize_shard_info
)
{
...
@@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() {
...
@@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() {
return
0
;
return
0
;
}
}
void
PsService
::
service
(
google
::
protobuf
::
RpcController
*
cntl_base
,
void
Brpc
PsService
::
service
(
google
::
protobuf
::
RpcController
*
cntl_base
,
const
PsRequestMessage
*
request
,
const
PsRequestMessage
*
request
,
PsResponseMessage
*
response
,
PsResponseMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
brpc
::
ClosureGuard
done_guard
(
done
);
std
::
string
log_label
(
"ReceiveCmd-"
);
std
::
string
log_label
(
"ReceiveCmd-"
);
if
(
!
request
->
has_table_id
())
{
if
(
!
request
->
has_table_id
())
{
...
@@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base,
...
@@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base,
}
}
}
}
int32_t
PsService
::
pull_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
Brpc
PsService
::
pull_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_dense"
);
platform
::
RecordEvent
record_event
(
"PsService->pull_dense"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
1
)
{
if
(
request
.
params_size
()
<
1
)
{
...
@@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
...
@@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
push_dense_param
(
Table
*
table
,
int32_t
Brpc
PsService
::
push_dense_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_dense_param"
);
platform
::
RecordEvent
record_event
(
"PsService->push_dense_param"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
thread_local
std
::
string
push_buffer
;
thread_local
std
::
string
push_buffer
;
...
@@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table,
...
@@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table,
return
0
;
return
0
;
}
}
int32_t
PsService
::
push_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
Brpc
PsService
::
push_dense
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_dense"
);
platform
::
RecordEvent
record_event
(
"PsService->push_dense"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
auto
req_buffer_size
=
request
.
data
().
size
();
auto
req_buffer_size
=
request
.
data
().
size
();
...
@@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
...
@@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
barrier
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
Brpc
PsService
::
barrier
(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
1
)
{
if
(
request
.
params_size
()
<
1
)
{
...
@@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
...
@@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
push_sparse_param
(
Table
*
table
,
int32_t
Brpc
PsService
::
push_sparse_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_sparse_param"
);
platform
::
RecordEvent
record_event
(
"PsService->push_sparse_param"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
auto
&
push_data
=
request
.
data
();
auto
&
push_data
=
request
.
data
();
...
@@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table,
...
@@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table,
return
0
;
return
0
;
}
}
int32_t
PsService
::
pull_geo_param
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
pull_geo_param
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_geo_param"
);
platform
::
RecordEvent
record_event
(
"PsService->pull_geo_param"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
thread_local
std
::
string
push_sparse_request_buffer
;
thread_local
std
::
string
push_sparse_request_buffer
;
...
@@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
...
@@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
pull_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
pull_sparse
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->pull_sparse"
);
platform
::
RecordEvent
record_event
(
"PsService->pull_sparse"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
thread_local
std
::
string
push_sparse_request_buffer
;
thread_local
std
::
string
push_sparse_request_buffer
;
...
@@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
...
@@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
push_sparse
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
push_sparse
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
RecordEvent
record_event
(
"PsService->push_sparse"
);
platform
::
RecordEvent
record_event
(
"PsService->push_sparse"
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
auto
&
push_data
=
request
.
data
();
auto
&
push_data
=
request
.
data
();
...
@@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
...
@@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
print_table_stat
(
Table
*
table
,
int32_t
Brpc
PsService
::
print_table_stat
(
Table
*
table
,
const
PsRequestMessage
&
request
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
std
::
pair
<
int64_t
,
int64_t
>
ret
=
table
->
print_table_stat
();
std
::
pair
<
int64_t
,
int64_t
>
ret
=
table
->
print_table_stat
();
paddle
::
framework
::
BinaryArchive
ar
;
paddle
::
framework
::
BinaryArchive
ar
;
...
@@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table,
...
@@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table,
return
0
;
return
0
;
}
}
int32_t
PsService
::
load_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
load_one_table
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
2
)
{
if
(
request
.
params_size
()
<
2
)
{
set_response_code
(
set_response_code
(
...
@@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
...
@@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
load_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
load_all_table
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
auto
&
table_map
=
*
(
_server
->
table
());
for
(
auto
&
itr
:
table_map
)
{
for
(
auto
&
itr
:
table_map
)
{
if
(
load_one_table
(
itr
.
second
.
get
(),
request
,
response
,
cntl
)
!=
0
)
{
if
(
load_one_table
(
itr
.
second
.
get
(),
request
,
response
,
cntl
)
!=
0
)
{
...
@@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
...
@@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
save_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
save_one_table
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
if
(
request
.
params_size
()
<
2
)
{
if
(
request
.
params_size
()
<
2
)
{
set_response_code
(
set_response_code
(
...
@@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
...
@@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
return
feasign_size
;
return
feasign_size
;
}
}
int32_t
PsService
::
save_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
save_all_table
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
auto
&
table_map
=
*
(
_server
->
table
());
int32_t
all_feasign_size
=
0
;
int32_t
all_feasign_size
=
0
;
int32_t
feasign_size
=
0
;
int32_t
feasign_size
=
0
;
...
@@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
...
@@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
shrink_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
shrink_table
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
table
->
flush
();
table
->
flush
();
if
(
table
->
shrink
()
!=
0
)
{
if
(
table
->
shrink
()
!=
0
)
{
...
@@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
...
@@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
clear_one_table
(
Table
*
table
,
int32_t
Brpc
PsService
::
clear_one_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
CHECK_TABLE_EXIST
(
table
,
request
,
response
)
table
->
flush
();
table
->
flush
();
table
->
clear
();
table
->
clear
();
return
0
;
return
0
;
}
}
int32_t
PsService
::
clear_all_table
(
Table
*
table
,
int32_t
Brpc
PsService
::
clear_all_table
(
Table
*
table
,
const
PsRequestMessage
&
request
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
auto
&
table_map
=
*
(
_server
->
table
());
auto
&
table_map
=
*
(
_server
->
table
());
for
(
auto
&
itr
:
table_map
)
{
for
(
auto
&
itr
:
table_map
)
{
if
(
clear_one_table
(
itr
.
second
.
get
(),
request
,
response
,
cntl
)
!=
0
)
{
if
(
clear_one_table
(
itr
.
second
.
get
(),
request
,
response
,
cntl
)
!=
0
)
{
...
@@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table,
...
@@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table,
return
0
;
return
0
;
}
}
int32_t
PsService
::
stop_server
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
stop_server
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
auto
*
p_server
=
_server
;
auto
*
p_server
=
_server
;
std
::
thread
t_stop
([
p_server
]()
{
std
::
thread
t_stop
([
p_server
]()
{
p_server
->
stop
();
p_server
->
stop
();
...
@@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
...
@@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
return
0
;
return
0
;
}
}
int32_t
PsService
::
stop_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
stop_profiler
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
string
::
Sprintf
(
"server_%s_profile"
,
_rank
));
string
::
Sprintf
(
"server_%s_profile"
,
_rank
));
return
0
;
return
0
;
}
}
int32_t
PsService
::
start_profiler
(
Table
*
table
,
const
PsRequestMessage
&
request
,
int32_t
BrpcPsService
::
start_profiler
(
Table
*
table
,
PsResponseMessage
&
response
,
const
PsRequestMessage
&
request
,
brpc
::
Controller
*
cntl
)
{
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
return
0
;
return
0
;
}
}
int32_t
PsService
::
push_global_step
(
Table
*
table
,
int32_t
Brpc
PsService
::
push_global_step
(
Table
*
table
,
const
PsRequestMessage
&
request
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
brpc
::
Controller
*
cntl
)
{
CHECK_TABLE_EXIST
(
table
,
request
,
response
);
CHECK_TABLE_EXIST
(
table
,
request
,
response
);
auto
req_buffer_size
=
request
.
data
().
size
();
auto
req_buffer_size
=
request
.
data
().
size
();
if
(
req_buffer_size
<
1
)
{
if
(
req_buffer_size
<
1
)
{
...
...
paddle/fluid/distributed/service/brpc_ps_server.h
浏览文件 @
a97ca56a
...
@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
...
@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_pserver_channels
;
std
::
vector
<
std
::
shared_ptr
<
brpc
::
Channel
>>
_pserver_channels
;
};
};
class
PsService
;
class
Brpc
PsService
;
typedef
int32_t
(
PsService
::*
serviceHandlerFunc
)(
typedef
int32_t
(
Brpc
PsService
::*
serviceHandlerFunc
)(
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
Table
*
table
,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
);
brpc
::
Controller
*
cntl
);
class
PsService
:
public
PsBaseService
{
class
Brpc
PsService
:
public
PsBaseService
{
public:
public:
virtual
int32_t
initialize
()
override
;
virtual
int32_t
initialize
()
override
;
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
const
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
;
::
google
::
protobuf
::
Closure
*
done
)
override
;
private:
private:
...
...
paddle/fluid/distributed/service/brpc_utils.cc
浏览文件 @
a97ca56a
...
@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
...
@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
var_msg
,
const
platform
::
DeviceContext
&
ctx
,
VarMsg
*
var_msg
,
butil
::
IOBuf
*
iobuf
)
{
butil
::
IOBuf
*
iobuf
)
{
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
var_msg
->
set_type
(
::
paddle
::
LOD_TENSOR
);
var_msg
->
set_type
(
::
paddle
::
distributed
::
LOD_TENSOR
);
const
framework
::
LoD
lod
=
tensor
->
lod
();
const
framework
::
LoD
lod
=
tensor
->
lod
();
if
(
lod
.
size
()
>
0
)
{
if
(
lod
.
size
()
>
0
)
{
var_msg
->
set_lod_level
(
lod
.
size
());
var_msg
->
set_lod_level
(
lod
.
size
());
...
@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
...
@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
tensor
=
slr
->
mutable_value
();
auto
*
rows
=
slr
->
mutable_rows
();
auto
*
rows
=
slr
->
mutable_rows
();
var_msg
->
set_type
(
::
paddle
::
SELECTED_ROWS
);
var_msg
->
set_type
(
::
paddle
::
distributed
::
SELECTED_ROWS
);
var_msg
->
set_slr_height
(
slr
->
height
());
var_msg
->
set_slr_height
(
slr
->
height
());
auto
*
var_data
=
var_msg
->
mutable_data
();
auto
*
var_data
=
var_msg
->
mutable_data
();
...
@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
...
@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++
recv_var_index
)
{
++
recv_var_index
)
{
const
auto
&
msg
=
multi_msg
.
var_messages
(
recv_var_index
);
const
auto
&
msg
=
multi_msg
.
var_messages
(
recv_var_index
);
auto
*
var
=
scope
->
Var
(
msg
.
varname
());
auto
*
var
=
scope
->
Var
(
msg
.
varname
());
if
(
msg
.
type
()
==
::
paddle
::
LOD_TENSOR
)
{
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
LOD_TENSOR
)
{
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
else
if
(
msg
.
type
()
==
::
paddle
::
SELECTED_ROWS
)
{
}
else
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
SELECTED_ROWS
)
{
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
}
}
}
...
@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
...
@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE
(
var
,
nullptr
,
PADDLE_ENFORCE_NE
(
var
,
nullptr
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Not find variable %s in scope."
,
msg
.
varname
()));
"Not find variable %s in scope."
,
msg
.
varname
()));
if
(
msg
.
type
()
==
::
paddle
::
LOD_TENSOR
)
{
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
LOD_TENSOR
)
{
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
DeserializeLodTensor
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
else
if
(
msg
.
type
()
==
::
paddle
::
SELECTED_ROWS
)
{
}
else
if
(
msg
.
type
()
==
::
paddle
::
distributed
::
SELECTED_ROWS
)
{
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
DeserializeSelectedRows
(
var
,
msg
,
io_buffer_itr
,
ctx
);
}
}
}
}
...
...
paddle/fluid/distributed/service/brpc_utils.h
浏览文件 @
a97ca56a
...
@@ -44,8 +44,8 @@ class DeviceContext;
...
@@ -44,8 +44,8 @@ class DeviceContext;
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
void
SerializeToMultiVarMsgAndIOBuf
(
void
SerializeToMultiVarMsgAndIOBuf
(
const
std
::
string
&
message_name
,
const
std
::
string
&
message_name
,
...
...
paddle/fluid/distributed/service/heter_client.cc
浏览文件 @
a97ca56a
...
@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
...
@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl
.
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
cntl
.
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
distributed
::
MultiVarMsg
request
,
response
;
distributed
::
MultiVarMsg
request
,
response
;
auto
&
request_io_buffer
=
cntl
.
request_attachment
();
auto
&
request_io_buffer
=
cntl
.
request_attachment
();
::
paddle
::
PsService_Stub
stub
(
xpu_channels_
[
num
].
get
());
::
paddle
::
distributed
::
PsService_Stub
stub
(
xpu_channels_
[
num
].
get
());
distributed
::
SerializeToMultiVarMsgAndIOBuf
(
distributed
::
SerializeToMultiVarMsgAndIOBuf
(
message_name_val
,
send_var_name_val
,
recv_var_name_val
,
*
p_ctx
,
p_scope
,
message_name_val
,
send_var_name_val
,
recv_var_name_val
,
*
p_ctx
,
p_scope
,
&
request
,
&
request_io_buffer
);
&
request
,
&
request_io_buffer
);
...
@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
...
@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for
(
const
auto
&
param
:
params
)
{
for
(
const
auto
&
param
:
params
)
{
closure
->
request
(
i
)
->
add_params
(
param
);
closure
->
request
(
i
)
->
add_params
(
param
);
}
}
::
paddle
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
::
paddle
::
distributed
::
PsService_Stub
rpc_stub
(
xpu_channels_
[
i
].
get
());
closure
->
cntl
(
i
)
->
set_timeout_ms
(
closure
->
cntl
(
i
)
->
set_timeout_ms
(
FLAGS_pserver_timeout_ms
);
// cmd msg don't limit timeout for save/load
FLAGS_pserver_timeout_ms
);
// cmd msg don't limit timeout for save/load
rpc_stub
.
service
(
closure
->
cntl
(
i
),
closure
->
request
(
i
),
rpc_stub
.
service
(
closure
->
cntl
(
i
),
closure
->
request
(
i
),
...
...
paddle/fluid/distributed/service/heter_client.h
浏览文件 @
a97ca56a
...
@@ -35,8 +35,8 @@ limitations under the License. */
...
@@ -35,8 +35,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
HeterRpcCallbackFunc
;
typedef
std
::
function
<
void
(
void
*
)
>
HeterRpcCallbackFunc
;
...
...
paddle/fluid/distributed/service/heter_server.h
浏览文件 @
a97ca56a
...
@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
...
@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
class
HeterService
;
class
HeterService
;
typedef
int32_t
(
HeterService
::*
serviceHandlerFunc
)(
typedef
int32_t
(
HeterService
::*
serviceHandlerFunc
)(
...
@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
...
@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef
std
::
function
<
int
(
const
MultiVarMsg
*
,
MultiVarMsg
*
,
brpc
::
Controller
*
)
>
typedef
std
::
function
<
int
(
const
MultiVarMsg
*
,
MultiVarMsg
*
,
brpc
::
Controller
*
)
>
HeterServiceHandler
;
HeterServiceHandler
;
class
HeterService
:
public
::
paddle
::
PsService
{
class
HeterService
:
public
::
paddle
::
distributed
::
PsService
{
public:
public:
HeterService
()
{
HeterService
()
{
_service_handler_map
[
PS_STOP_SERVER
]
=
&
HeterService
::
stop_heter_worker
;
_service_handler_map
[
PS_STOP_SERVER
]
=
&
HeterService
::
stop_heter_worker
;
...
@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
...
@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual
~
HeterService
()
{}
virtual
~
HeterService
()
{}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
const
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
{
::
google
::
protobuf
::
Closure
*
done
)
{
brpc
::
ClosureGuard
done_guard
(
done
);
brpc
::
ClosureGuard
done_guard
(
done
);
std
::
string
log_label
(
"ReceiveCmd-"
);
std
::
string
log_label
(
"ReceiveCmd-"
);
...
...
paddle/fluid/distributed/service/ps_client.cc
浏览文件 @
a97ca56a
...
@@ -13,9 +13,7 @@
...
@@ -13,9 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include <map>
#include "brpc/server.h"
#include "brpc/server.h"
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
...
@@ -23,7 +21,7 @@
...
@@ -23,7 +21,7 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
REGISTER_CLASS
(
PSClient
,
BrpcPsClient
);
REGISTER_
PSCORE_
CLASS
(
PSClient
,
BrpcPsClient
);
int32_t
PSClient
::
configure
(
int32_t
PSClient
::
configure
(
const
PSParameter
&
config
,
const
PSParameter
&
config
,
...
@@ -43,7 +41,7 @@ int32_t PSClient::configure(
...
@@ -43,7 +41,7 @@ int32_t PSClient::configure(
const
auto
&
work_param
=
_config
.
worker_param
().
downpour_worker_param
();
const
auto
&
work_param
=
_config
.
worker_param
().
downpour_worker_param
();
for
(
size_t
i
=
0
;
i
<
work_param
.
downpour_table_param_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
work_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
accessor
=
CREATE_CLASS
(
auto
*
accessor
=
CREATE_
PSCORE_
CLASS
(
ValueAccessor
,
ValueAccessor
,
work_param
.
downpour_table_param
(
i
).
accessor
().
accessor_class
());
work_param
.
downpour_table_param
(
i
).
accessor
().
accessor_class
());
accessor
->
configure
(
work_param
.
downpour_table_param
(
i
).
accessor
());
accessor
->
configure
(
work_param
.
downpour_table_param
(
i
).
accessor
());
...
@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
...
@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
}
}
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
PSClient
*
client
=
CREATE_CLASS
(
PSClient
,
service_param
.
client_class
());
PSClient
*
client
=
CREATE_PSCORE_CLASS
(
PSClient
,
service_param
.
client_class
());
if
(
client
==
NULL
)
{
if
(
client
==
NULL
)
{
LOG
(
ERROR
)
<<
"client is not registered, server_name:"
LOG
(
ERROR
)
<<
"client is not registered, server_name:"
<<
service_param
.
client_class
();
<<
service_param
.
client_class
();
...
...
paddle/fluid/distributed/service/ps_client.h
浏览文件 @
a97ca56a
...
@@ -28,6 +28,9 @@
...
@@ -28,6 +28,9 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
typedef
std
::
function
<
void
(
void
*
)
>
PSClientCallBack
;
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
class
PSClientClosure
:
public
google
::
protobuf
::
Closure
{
public:
public:
...
@@ -206,7 +209,7 @@ class PSClient {
...
@@ -206,7 +209,7 @@ class PSClient {
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
std
::
unordered_map
<
int32_t
,
MsgHandlerFunc
>
_msg_handler_map
;
//处理client2client消息
_msg_handler_map
;
//处理client2client消息
};
};
REGISTER_REGISTERER
(
PSClient
);
REGISTER_
PSCORE_
REGISTERER
(
PSClient
);
class
PSClientFactory
{
class
PSClientFactory
{
public:
public:
...
...
paddle/fluid/distributed/service/sendrecv.proto
浏览文件 @
a97ca56a
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
// limitations under the License.
// limitations under the License.
syntax
=
"proto2"
;
syntax
=
"proto2"
;
package
paddle
;
package
paddle
.
distributed
;
option
cc_generic_services
=
true
;
option
cc_generic_services
=
true
;
option
cc_enable_arenas
=
true
;
option
cc_enable_arenas
=
true
;
...
...
paddle/fluid/distributed/service/server.cc
浏览文件 @
a97ca56a
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/distributed/table/table.h"
...
@@ -20,8 +21,8 @@
...
@@ -20,8 +21,8 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
REGISTER_CLASS
(
PSServer
,
BrpcPsServer
);
REGISTER_
PSCORE_
CLASS
(
PSServer
,
BrpcPsServer
);
REGISTER_
CLASS
(
PsBaseService
,
PsService
);
REGISTER_
PSCORE_CLASS
(
PsBaseService
,
Brpc
PsService
);
PSServer
*
PSServerFactory
::
create
(
const
PSParameter
&
ps_config
)
{
PSServer
*
PSServerFactory
::
create
(
const
PSParameter
&
ps_config
)
{
const
auto
&
config
=
ps_config
.
server_param
();
const
auto
&
config
=
ps_config
.
server_param
();
...
@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
...
@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
}
}
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
const
auto
&
service_param
=
config
.
downpour_server_param
().
service_param
();
PSServer
*
server
=
CREATE_CLASS
(
PSServer
,
service_param
.
server_class
());
PSServer
*
server
=
CREATE_PSCORE_CLASS
(
PSServer
,
service_param
.
server_class
());
if
(
server
==
NULL
)
{
if
(
server
==
NULL
)
{
LOG
(
ERROR
)
<<
"server is not registered, server_name:"
LOG
(
ERROR
)
<<
"server is not registered, server_name:"
<<
service_param
.
server_class
();
<<
service_param
.
server_class
();
...
@@ -70,7 +72,7 @@ int32_t PSServer::configure(
...
@@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t
global_step_table
=
UINT32_MAX
;
uint32_t
global_step_table
=
UINT32_MAX
;
for
(
size_t
i
=
0
;
i
<
downpour_param
.
downpour_table_param_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
downpour_param
.
downpour_table_param_size
();
++
i
)
{
auto
*
table
=
CREATE_CLASS
(
auto
*
table
=
CREATE_
PSCORE_
CLASS
(
Table
,
downpour_param
.
downpour_table_param
(
i
).
table_class
());
Table
,
downpour_param
.
downpour_table_param
(
i
).
table_class
());
if
(
downpour_param
.
downpour_table_param
(
i
).
table_class
()
==
if
(
downpour_param
.
downpour_table_param
(
i
).
table_class
()
==
...
...
paddle/fluid/distributed/service/server.h
浏览文件 @
a97ca56a
...
@@ -46,6 +46,8 @@ namespace paddle {
...
@@ -46,6 +46,8 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
class
Table
;
class
Table
;
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
class
PSServer
{
class
PSServer
{
public:
public:
...
@@ -107,7 +109,7 @@ class PSServer {
...
@@ -107,7 +109,7 @@ class PSServer {
platform
::
Place
place_
=
platform
::
CPUPlace
();
platform
::
Place
place_
=
platform
::
CPUPlace
();
};
};
REGISTER_REGISTERER
(
PSServer
);
REGISTER_
PSCORE_
REGISTERER
(
PSServer
);
typedef
std
::
function
<
void
(
void
*
)
>
PServerCallBack
;
typedef
std
::
function
<
void
(
void
*
)
>
PServerCallBack
;
...
@@ -141,8 +143,8 @@ class PsBaseService : public PsService {
...
@@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return
0
;
return
0
;
}
}
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
virtual
void
service
(
::
google
::
protobuf
::
RpcController
*
controller
,
const
::
paddle
::
PsRequestMessage
*
request
,
const
PsRequestMessage
*
request
,
::
paddle
::
PsResponseMessage
*
response
,
PsResponseMessage
*
response
,
::
google
::
protobuf
::
Closure
*
done
)
override
=
0
;
::
google
::
protobuf
::
Closure
*
done
)
override
=
0
;
virtual
void
set_response_code
(
PsResponseMessage
&
response
,
int
err_code
,
virtual
void
set_response_code
(
PsResponseMessage
&
response
,
int
err_code
,
...
@@ -159,7 +161,7 @@ class PsBaseService : public PsService {
...
@@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer
*
_server
;
PSServer
*
_server
;
const
ServerParameter
*
_config
;
const
ServerParameter
*
_config
;
};
};
REGISTER_REGISTERER
(
PsBaseService
);
REGISTER_
PSCORE_
REGISTERER
(
PsBaseService
);
class
PSServerFactory
{
class
PSServerFactory
{
public:
public:
...
...
paddle/fluid/distributed/service/service.h
浏览文件 @
a97ca56a
...
@@ -28,6 +28,10 @@ limitations under the License. */
...
@@ -28,6 +28,10 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
using
paddle
::
distributed
::
PsRequestMessage
;
using
paddle
::
distributed
::
PsResponseMessage
;
using
paddle
::
distributed
::
PsService
;
class
PSCore
{
class
PSCore
{
public:
public:
explicit
PSCore
()
{}
explicit
PSCore
()
{}
...
...
paddle/fluid/distributed/table/accessor.h
浏览文件 @
a97ca56a
...
@@ -165,6 +165,6 @@ class ValueAccessor {
...
@@ -165,6 +165,6 @@ class ValueAccessor {
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
std
::
unordered_map
<
int
,
std
::
shared_ptr
<
struct
DataConverter
>>
_data_coverter_map
;
_data_coverter_map
;
};
};
REGISTER_REGISTERER
(
ValueAccessor
);
REGISTER_
PSCORE_
REGISTERER
(
ValueAccessor
);
}
// namespace distributed
}
// namespace distributed
}
// namespace paddle
}
// namespace paddle
paddle/fluid/distributed/table/table.cc
浏览文件 @
a97ca56a
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h"
#include "glog/logging.h"
...
@@ -27,14 +28,14 @@
...
@@ -27,14 +28,14 @@
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
REGISTER_CLASS
(
Table
,
CommonDenseTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
CommonDenseTable
);
REGISTER_CLASS
(
Table
,
CommonSparseTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
CommonSparseTable
);
REGISTER_CLASS
(
Table
,
SparseGeoTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
SparseGeoTable
);
REGISTER_CLASS
(
Table
,
BarrierTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
BarrierTable
);
REGISTER_CLASS
(
Table
,
TensorTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
TensorTable
);
REGISTER_CLASS
(
Table
,
DenseTensorTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
DenseTensorTable
);
REGISTER_CLASS
(
Table
,
GlobalStepTable
);
REGISTER_
PSCORE_
CLASS
(
Table
,
GlobalStepTable
);
REGISTER_CLASS
(
ValueAccessor
,
CommMergeAccessor
);
REGISTER_
PSCORE_
CLASS
(
ValueAccessor
,
CommMergeAccessor
);
int32_t
TableManager
::
initialize
()
{
int32_t
TableManager
::
initialize
()
{
static
bool
initialized
=
false
;
static
bool
initialized
=
false
;
...
@@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() {
...
@@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() {
<<
_config
.
table_id
();
<<
_config
.
table_id
();
return
-
1
;
return
-
1
;
}
}
auto
*
accessor
=
auto
*
accessor
=
CREATE_PSCORE_CLASS
(
CREATE_CLASS
(
ValueAccessor
,
ValueAccessor
,
_config
.
accessor
().
accessor_class
())
if
(
accessor
==
NULL
)
{
_config
.
accessor
().
accessor_class
())
if
(
accessor
==
NULL
)
{
LOG
(
ERROR
)
<<
"accessor is unregisteg, table_id:"
<<
_config
.
table_id
()
LOG
(
ERROR
)
<<
"accessor is unregisteg, table_id:"
<<
_config
.
table_id
()
<<
", accessor_name:"
<<
_config
.
accessor
().
accessor_class
();
<<
", accessor_name:"
<<
_config
.
accessor
().
accessor_class
();
return
-
1
;
return
-
1
;
...
...
paddle/fluid/distributed/table/table.h
浏览文件 @
a97ca56a
...
@@ -127,7 +127,7 @@ class Table {
...
@@ -127,7 +127,7 @@ class Table {
float
*
_global_lr
=
nullptr
;
float
*
_global_lr
=
nullptr
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
std
::
shared_ptr
<
ValueAccessor
>
_value_accesor
;
};
};
REGISTER_REGISTERER
(
Table
);
REGISTER_
PSCORE_
REGISTERER
(
Table
);
class
TableManager
{
class
TableManager
{
public:
public:
...
...
paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc
浏览文件 @
a97ca56a
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <unistd.h>
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <condition_variable> // NOLINT
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
...
@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
...
@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto
->
mutable_downpour_server_param
();
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
server_service_proto
->
set_start_server_port
(
0
);
...
@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
...
@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto
->
mutable_downpour_server_param
();
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
server_service_proto
->
set_start_server_port
(
0
);
...
@@ -244,7 +245,8 @@ void RunBrpcPushDense() {
...
@@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int
ret
=
0
;
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_DENSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
ret
=
-
1
;
break
;
break
;
}
}
...
...
paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc
浏览文件 @
a97ca56a
...
@@ -18,6 +18,7 @@ limitations under the License. */
...
@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto(
...
@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto(
server_proto
->
mutable_downpour_server_param
();
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
server_service_proto
->
set_start_server_port
(
0
);
...
@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto(
...
@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto(
server_proto
->
mutable_downpour_server_param
();
server_proto
->
mutable_downpour_server_param
();
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
::
paddle
::
distributed
::
ServerServiceParameter
*
server_service_proto
=
downpour_server_proto
->
mutable_service_param
();
downpour_server_proto
->
mutable_service_param
();
server_service_proto
->
set_service_class
(
"PsService"
);
server_service_proto
->
set_service_class
(
"
Brpc
PsService"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_server_class
(
"BrpcPsServer"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_client_class
(
"BrpcPsClient"
);
server_service_proto
->
set_start_server_port
(
0
);
server_service_proto
->
set_start_server_port
(
0
);
...
@@ -225,7 +226,8 @@ void RunBrpcPushSparse() {
...
@@ -225,7 +226,8 @@ void RunBrpcPushSparse() {
int
ret
=
0
;
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_SPARSE_PARAM
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_SPARSE_PARAM
)
!=
0
)
{
ret
=
-
1
;
ret
=
-
1
;
break
;
break
;
}
}
...
@@ -252,7 +254,8 @@ void RunBrpcPushSparse() {
...
@@ -252,7 +254,8 @@ void RunBrpcPushSparse() {
int
ret
=
0
;
int
ret
=
0
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
auto
*
closure
=
(
paddle
::
distributed
::
DownpourBrpcClosure
*
)
done
;
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
1
;
++
i
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
PS_PUSH_SPARSE_TABLE
)
!=
0
)
{
if
(
closure
->
check_response
(
i
,
paddle
::
distributed
::
PS_PUSH_SPARSE_TABLE
)
!=
0
)
{
ret
=
-
1
;
ret
=
-
1
;
break
;
break
;
}
}
...
...
paddle/fluid/distributed/test/brpc_utils_test.cc
浏览文件 @
a97ca56a
...
@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) {
...
@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) {
auto
&
ctx
=
*
pool
.
Get
(
place
);
auto
&
ctx
=
*
pool
.
Get
(
place
);
CreateVarsOnScope
(
&
scope
,
&
place
,
ctx
);
CreateVarsOnScope
(
&
scope
,
&
place
,
ctx
);
::
paddle
::
MultiVariableMessage
multi_msg
;
::
paddle
::
distributed
::
MultiVariableMessage
multi_msg
;
std
::
string
message_name
(
"se_de_test"
);
std
::
string
message_name
(
"se_de_test"
);
std
::
vector
<
std
::
string
>
send_var_name
=
{
"x1"
,
"x2"
,
"x3"
};
std
::
vector
<
std
::
string
>
send_var_name
=
{
"x1"
,
"x2"
,
"x3"
};
std
::
vector
<
std
::
string
>
recv_var_name
=
{};
std
::
vector
<
std
::
string
>
recv_var_name
=
{};
...
@@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) {
...
@@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) {
// platform::CUDAPlace place;
// platform::CUDAPlace place;
// RunMultiVarMsg(place);
// RunMultiVarMsg(place);
// }
// }
// #endif
// #endif
\ No newline at end of file
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
a97ca56a
...
@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE)
...
@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method
sendrecvop_rpc communicator
collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
lod_rank_table feed_fetch_method collective_helper
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper data_feed_proto timer monitor
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto pslib_brpc
)
heter_service_proto pslib_brpc
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
(
)
else
if
(
WITH_PSCORE
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
heterxpu_trainer.cc
...
@@ -230,6 +230,16 @@ if(WITH_DISTRIBUTE)
...
@@ -230,6 +230,16 @@ if(WITH_DISTRIBUTE)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
multi_trainer.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
multi_trainer.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
hogwild_worker.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
hogwild_worker.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor
)
endif
()
endif
()
elseif
(
WITH_PSLIB
)
elseif
(
WITH_PSLIB
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
...
@@ -241,7 +251,6 @@ elseif(WITH_PSLIB)
...
@@ -241,7 +251,6 @@ elseif(WITH_PSLIB)
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc
)
graph_to_program_pass variable_helper timer monitor pslib_brpc
)
else
()
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
cc_library
(
executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
...
...
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
a97ca56a
...
@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
...
@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
cc_library
(
variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
reduce_op_handle.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
reduce_op_handle.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
...
...
paddle/fluid/framework/details/async_ssa_graph_executor.cc
浏览文件 @
a97ca56a
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/communicator.h"
#endif
#endif
...
@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
...
@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
"results to be fetched!"
));
"results to be fetched!"
));
// init once
// init once
if
(
run_futures_
.
size
()
==
0
&&
places_
.
size
()
>
1
)
{
if
(
run_futures_
.
size
()
==
0
&&
places_
.
size
()
>
1
)
{
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
strategy_
.
thread_barrier_
)
{
if
(
strategy_
.
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
places_
.
size
());
places_
.
size
());
...
...
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
浏览文件 @
a97ca56a
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/communicator.h"
#endif
#endif
...
@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
...
@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
void
ThreadedSSAGraphExecutor
::
ExecutionFinal
(
void
ThreadedSSAGraphExecutor
::
ExecutionFinal
(
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
std
::
vector
<
OpHandleBase
*>
*
fetch_ops
)
{
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
strategy_
.
thread_barrier_
)
{
if
(
strategy_
.
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
}
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
a97ca56a
...
@@ -19,7 +19,7 @@ limitations under the License. */
...
@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/communicator.h"
#endif
#endif
...
@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
...
@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
writer_
.
Flush
();
writer_
.
Flush
();
}
}
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
thread_barrier_
)
{
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
}
...
@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() {
...
@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars
();
PrintFetchVars
();
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
}
}
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
thread_barrier_
)
{
if
(
thread_barrier_
)
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerDecrement
();
}
}
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
a97ca56a
...
@@ -18,7 +18,7 @@ limitations under the License. */
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer.h"
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/distributed/service/communicator.h"
#include "paddle/fluid/distributed/service/communicator.h"
#endif
#endif
...
@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
if
(
trainer_desc
.
thread_barrier
())
{
if
(
trainer_desc
.
thread_barrier
())
{
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
paddle
::
distributed
::
Communicator
::
GetInstance
()
->
BarrierTriggerReset
(
thread_num_
);
thread_num_
);
...
...
paddle/fluid/inference/CMakeLists.txt
浏览文件 @
a97ca56a
...
@@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS
...
@@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS
${
mkldnn_quantizer_src_file
}
)
${
mkldnn_quantizer_src_file
}
)
# Create shared inference library defaultly
# Create shared inference library defaultly
if
(
NOT WITH_
DISTRIBUT
E
)
if
(
NOT WITH_
PSCOR
E
)
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
analysis_predictor
)
DEPS
${
fluid_modules
}
analysis_predictor
)
else
()
else
()
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
cc_library
(
paddle_fluid_shared SHARED SRCS
${
SHARED_INFERENCE_SRCS
}
DEPS
${
fluid_modules
}
analysis_predictor fleet ps_service
)
DEPS
${
fluid_modules
}
analysis_predictor fleet ps_service
)
endif
()
endif
()
get_property
(
os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES
)
get_property
(
os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES
)
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
a97ca56a
...
@@ -22,10 +22,13 @@ add_subdirectory(jit)
...
@@ -22,10 +22,13 @@ add_subdirectory(jit)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
add_subdirectory
(
pscore
)
add_subdirectory
(
collective
)
add_subdirectory
(
collective
)
endif
()
endif
()
if
(
WITH_PSCORE
)
add_subdirectory
(
pscore
)
endif
()
add_subdirectory
(
amp
)
add_subdirectory
(
amp
)
add_subdirectory
(
reader
)
add_subdirectory
(
reader
)
...
...
paddle/fluid/operators/pscore/CMakeLists.txt
浏览文件 @
a97ca56a
if
(
WITH_PSLIB
)
return
()
endif
()
include
(
operators
)
include
(
operators
)
set
(
DISTRIBUTE_DEPS
""
)
set
(
DISTRIBUTE_DEPS
""
)
...
...
paddle/fluid/operators/pscore/heter_listen_and_serv_op.h
浏览文件 @
a97ca56a
...
@@ -46,8 +46,8 @@ class DeviceContext;
...
@@ -46,8 +46,8 @@ class DeviceContext;
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
template
<
class
TKey
,
class
TValue
>
template
<
class
TKey
,
class
TValue
>
class
DoubleFindMap
:
public
std
::
unordered_map
<
TKey
,
TValue
>
{
class
DoubleFindMap
:
public
std
::
unordered_map
<
TKey
,
TValue
>
{
...
...
paddle/fluid/operators/pscore/heter_listen_and_server_test.cc
浏览文件 @
a97ca56a
...
@@ -36,8 +36,8 @@ namespace framework = paddle::framework;
...
@@ -36,8 +36,8 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
distributed
;
namespace
distributed
=
paddle
::
distributed
;
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
DECLARE_double
(
eager_delete_tensor_gb
);
DECLARE_double
(
eager_delete_tensor_gb
);
USE_OP
(
scale
);
USE_OP
(
scale
);
...
...
paddle/fluid/operators/pscore/heter_server_test.cc
浏览文件 @
a97ca56a
...
@@ -32,8 +32,8 @@ namespace framework = paddle::framework;
...
@@ -32,8 +32,8 @@ namespace framework = paddle::framework;
namespace
platform
=
paddle
::
platform
;
namespace
platform
=
paddle
::
platform
;
namespace
distributed
=
paddle
::
distributed
;
namespace
distributed
=
paddle
::
distributed
;
using
MultiVarMsg
=
::
paddle
::
MultiVariableMessage
;
using
MultiVarMsg
=
::
paddle
::
distributed
::
MultiVariableMessage
;
using
VarMsg
=
::
paddle
::
VariableMessage
;
using
VarMsg
=
::
paddle
::
distributed
::
VariableMessage
;
USE_OP
(
scale
);
USE_OP
(
scale
);
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
a97ca56a
...
@@ -49,7 +49,7 @@ if (WITH_CRYPTO)
...
@@ -49,7 +49,7 @@ if (WITH_CRYPTO)
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
crypto.cc
)
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
crypto.cc
)
endif
(
WITH_CRYPTO
)
endif
(
WITH_CRYPTO
)
if
(
WITH_
DISTRIBUT
E
)
if
(
WITH_
PSCOR
E
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result"
)
set_source_files_properties
(
fleet_py.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
fleet_py.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
list
(
APPEND PYBIND_DEPS fleet communicator
)
list
(
APPEND PYBIND_DEPS fleet communicator
)
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
a97ca56a
...
@@ -106,7 +106,7 @@ limitations under the License. */
...
@@ -106,7 +106,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/crypto.h"
#include "paddle/fluid/pybind/crypto.h"
#endif
#endif
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
#include "paddle/fluid/pybind/fleet_py.h"
#include "paddle/fluid/pybind/fleet_py.h"
#endif
#endif
...
@@ -2833,7 +2833,7 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2833,7 +2833,7 @@ All parameter, weight, gradient are variables in Paddle.
BindCrypto
(
&
m
);
BindCrypto
(
&
m
);
#endif
#endif
#if
def PADDLE_WITH_DISTRIBUT
E
#if
defined PADDLE_WITH_PSCOR
E
BindDistFleetWrapper
(
&
m
);
BindDistFleetWrapper
(
&
m
);
BindPSHost
(
&
m
);
BindPSHost
(
&
m
);
BindCommunicatorContext
(
&
m
);
BindCommunicatorContext
(
&
m
);
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
a97ca56a
...
@@ -236,7 +236,8 @@ function cmake_base() {
...
@@ -236,7 +236,8 @@ function cmake_base() {
-DPY_VERSION=
${
PY_VERSION
:-
2
.7
}
-DPY_VERSION=
${
PY_VERSION
:-
2
.7
}
-DCMAKE_INSTALL_PREFIX=
${
INSTALL_PREFIX
:-
/paddle/build
}
-DCMAKE_INSTALL_PREFIX=
${
INSTALL_PREFIX
:-
/paddle/build
}
-DWITH_GRPC=
${
grpc_flag
}
-DWITH_GRPC=
${
grpc_flag
}
-DWITH_GLOO=
${
gloo_flag
}
-DWITH_PSCORE=
${
distibuted_flag
}
-DWITH_GLOO=
${
gloo_flag
}
-DWITH_LITE=
${
WITH_LITE
:-
OFF
}
-DWITH_LITE=
${
WITH_LITE
:-
OFF
}
-DWITH_XPU=
${
WITH_XPU
:-
OFF
}
-DWITH_XPU=
${
WITH_XPU
:-
OFF
}
-DLITE_GIT_TAG=develop
-DLITE_GIT_TAG=develop
...
@@ -269,7 +270,8 @@ EOF
...
@@ -269,7 +270,8 @@ EOF
-DPY_VERSION
=
${
PY_VERSION
:-
2
.7
}
\
-DPY_VERSION
=
${
PY_VERSION
:-
2
.7
}
\
-DCMAKE_INSTALL_PREFIX
=
${
INSTALL_PREFIX
:-
/paddle/build
}
\
-DCMAKE_INSTALL_PREFIX
=
${
INSTALL_PREFIX
:-
/paddle/build
}
\
-DWITH_GRPC
=
${
grpc_flag
}
\
-DWITH_GRPC
=
${
grpc_flag
}
\
-DWITH_GLOO
=
${
gloo_flag
}
\
-DWITH_PSCORE
=
${
distibuted_flag
}
\
-DWITH_GLOO
=
${
gloo_flag
}
\
-DLITE_GIT_TAG
=
develop
\
-DLITE_GIT_TAG
=
develop
\
-DWITH_XPU
=
${
WITH_XPU
:-
OFF
}
\
-DWITH_XPU
=
${
WITH_XPU
:-
OFF
}
\
-DWITH_LITE
=
${
WITH_LITE
:-
OFF
}
;
build_error
=
$?
-DWITH_LITE
=
${
WITH_LITE
:-
OFF
}
;
build_error
=
$?
...
...
paddle/testing/paddle_gtest_main.cc
浏览文件 @
a97ca56a
...
@@ -59,7 +59,8 @@ int main(int argc, char** argv) {
...
@@ -59,7 +59,8 @@ int main(int argc, char** argv) {
std
::
vector
<
std
::
string
>
envs
;
std
::
vector
<
std
::
string
>
envs
;
std
::
vector
<
std
::
string
>
undefok
;
std
::
vector
<
std
::
string
>
undefok
;
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC)
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) && \
!defined(PADDLE_WITH_PSLIB)
std
::
string
str_max_body_size
;
std
::
string
str_max_body_size
;
if
(
google
::
GetCommandLineOption
(
"max_body_size"
,
&
str_max_body_size
))
{
if
(
google
::
GetCommandLineOption
(
"max_body_size"
,
&
str_max_body_size
))
{
setenv
(
"FLAGS_max_body_size"
,
"2147483647"
,
1
);
setenv
(
"FLAGS_max_body_size"
,
"2147483647"
,
1
);
...
...
python/paddle/distributed/fleet/runtime/the_one_ps.py
浏览文件 @
a97ca56a
...
@@ -268,7 +268,7 @@ class Service:
...
@@ -268,7 +268,7 @@ class Service:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
server_class
=
"BrpcPsServer"
self
.
server_class
=
"BrpcPsServer"
self
.
client_class
=
"BrpcPsClient"
self
.
client_class
=
"BrpcPsClient"
self
.
service_class
=
"PsService"
self
.
service_class
=
"
Brpc
PsService"
self
.
start_server_port
=
0
self
.
start_server_port
=
0
self
.
server_thread_num
=
12
self
.
server_thread_num
=
12
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录