Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f81957a7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f81957a7
编写于
12月 13, 2018
作者:
H
heqiaozhi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine cmake for pslib & pre_define
上级
2912d531
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
73 addition
and
22 deletion
+73
-22
CMakeLists.txt
CMakeLists.txt
+1
-1
cmake/configure.cmake
cmake/configure.cmake
+4
-0
cmake/external/libmct.cmake
cmake/external/libmct.cmake
+7
-6
cmake/external/pslib_brpc.cmake
cmake/external/pslib_brpc.cmake
+8
-7
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+6
-1
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+14
-0
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+7
-4
paddle/fluid/framework/executor_thread_worker.cc
paddle/fluid/framework/executor_thread_worker.cc
+5
-1
paddle/fluid/framework/executor_thread_worker.h
paddle/fluid/framework/executor_thread_worker.h
+10
-2
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+11
-0
未找到文件。
CMakeLists.txt
浏览文件 @
f81957a7
...
@@ -222,7 +222,7 @@ if(WITH_PSLIB)
...
@@ -222,7 +222,7 @@ if(WITH_PSLIB)
include
(
external/libmct
)
include
(
external/libmct
)
include
(
external/pslib_brpc
)
include
(
external/pslib_brpc
)
include
(
external/pslib
)
include
(
external/pslib
)
endif
()
endif
(
WITH_PSLIB
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_DISTRIBUTE
)
if
(
WITH_GRPC
)
if
(
WITH_GRPC
)
...
...
cmake/configure.cmake
浏览文件 @
f81957a7
...
@@ -84,6 +84,10 @@ if(NOT WITH_GOLANG)
...
@@ -84,6 +84,10 @@ if(NOT WITH_GOLANG)
add_definitions
(
-DPADDLE_WITHOUT_GOLANG
)
add_definitions
(
-DPADDLE_WITHOUT_GOLANG
)
endif
(
NOT WITH_GOLANG
)
endif
(
NOT WITH_GOLANG
)
if
(
WITH_PSLIB
)
add_definitions
(
-DPADDLE_WITH_PSLIB
)
endif
()
if
(
WITH_GPU
)
if
(
WITH_GPU
)
add_definitions
(
-DPADDLE_WITH_CUDA
)
add_definitions
(
-DPADDLE_WITH_CUDA
)
...
...
cmake/external/libmct.cmake
浏览文件 @
f81957a7
...
@@ -29,10 +29,11 @@ INCLUDE(ExternalProject)
...
@@ -29,10 +29,11 @@ INCLUDE(ExternalProject)
SET
(
LIBMCT_PROJECT
"extern_libmct"
)
SET
(
LIBMCT_PROJECT
"extern_libmct"
)
IF
((
NOT DEFINED LIBMCT_VER
)
OR
(
NOT DEFINED LIBMCT_URL
))
IF
((
NOT DEFINED LIBMCT_VER
)
OR
(
NOT DEFINED LIBMCT_URL
))
MESSAGE
(
STATUS
"use pre defined download url"
)
MESSAGE
(
STATUS
"use pre defined download url"
)
SET
(
LIBMCT_VER
"libmct"
CACHE STRING
""
FORCE
)
#todo libmct version
SET
(
LIBMCT_VER
"0.1.0"
CACHE STRING
""
FORCE
)
SET
(
LIBMCT_URL
"http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/
${
LIBMCT_VER
}
.tar.gz"
CACHE STRING
""
FORCE
)
#todo libmct url
SET
(
LIBMCT_NAME
"libmct"
CACHE STRING
""
FORCE
)
SET
(
LIBMCT_URL
"https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/
${
LIBMCT_VER
}
/
${
LIBMCT_NAME
}
.tar.gz"
CACHE STRING
""
FORCE
)
ENDIF
()
ENDIF
()
MESSAGE
(
STATUS
"LIBMCT_
VER:
${
LIBMCT_VER
}
, LIBMCT_URL:
${
LIBMCT_URL
}
"
)
MESSAGE
(
STATUS
"LIBMCT_
NAME:
${
LIBMCT_NAME
}
, LIBMCT_URL:
${
LIBMCT_URL
}
"
)
SET
(
LIBMCT_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/libmct"
)
SET
(
LIBMCT_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/libmct"
)
SET
(
LIBMCT_DOWNLOAD_DIR
"
${
LIBMCT_SOURCE_DIR
}
/src/
${
LIBMCT_PROJECT
}
"
)
SET
(
LIBMCT_DOWNLOAD_DIR
"
${
LIBMCT_SOURCE_DIR
}
/src/
${
LIBMCT_PROJECT
}
"
)
SET
(
LIBMCT_DST_DIR
"libmct"
)
SET
(
LIBMCT_DST_DIR
"libmct"
)
...
@@ -47,7 +48,7 @@ INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})
...
@@ -47,7 +48,7 @@ INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})
FILE
(
WRITE
${
LIBMCT_DOWNLOAD_DIR
}
/CMakeLists.txt
FILE
(
WRITE
${
LIBMCT_DOWNLOAD_DIR
}
/CMakeLists.txt
"PROJECT(LIBMCT)
\n
"
"PROJECT(LIBMCT)
\n
"
"cmake_minimum_required(VERSION 3.0)
\n
"
"cmake_minimum_required(VERSION 3.0)
\n
"
"install(DIRECTORY
${
LIBMCT_
VER
}
/include
${
LIBMCT_VER
}
/lib
\n
"
"install(DIRECTORY
${
LIBMCT_
NAME
}
/include
${
LIBMCT_NAME
}
/lib
\n
"
" DESTINATION
${
LIBMCT_DST_DIR
}
)
\n
"
)
" DESTINATION
${
LIBMCT_DST_DIR
}
)
\n
"
)
ExternalProject_Add
(
ExternalProject_Add
(
...
@@ -55,8 +56,8 @@ ExternalProject_Add(
...
@@ -55,8 +56,8 @@ ExternalProject_Add(
${
EXTERNAL_PROJECT_LOG_ARGS
}
${
EXTERNAL_PROJECT_LOG_ARGS
}
PREFIX
${
LIBMCT_SOURCE_DIR
}
PREFIX
${
LIBMCT_SOURCE_DIR
}
DOWNLOAD_DIR
${
LIBMCT_DOWNLOAD_DIR
}
DOWNLOAD_DIR
${
LIBMCT_DOWNLOAD_DIR
}
DOWNLOAD_COMMAND wget --no-check-certificate
${
LIBMCT_URL
}
-c -q -O
${
LIBMCT_
VER
}
.tar.gz
DOWNLOAD_COMMAND wget --no-check-certificate
${
LIBMCT_URL
}
-c -q -O
${
LIBMCT_
NAME
}
.tar.gz
&& tar zxvf
${
LIBMCT_
VER
}
.tar.gz
&& tar zxvf
${
LIBMCT_
NAME
}
.tar.gz
DOWNLOAD_NO_PROGRESS 1
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND
""
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
LIBMCT_INSTALL_ROOT
}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
LIBMCT_INSTALL_ROOT
}
...
...
cmake/external/pslib_brpc.cmake
浏览文件 @
f81957a7
...
@@ -27,12 +27,13 @@ ENDIF()
...
@@ -27,12 +27,13 @@ ENDIF()
INCLUDE
(
ExternalProject
)
INCLUDE
(
ExternalProject
)
SET
(
PSLIB_BRPC_PROJECT
"extern_pslib_brpc"
)
SET
(
PSLIB_BRPC_PROJECT
"extern_pslib_brpc"
)
IF
((
NOT DEFINED PSLIB_BRPC_
VER
)
OR
(
NOT DEFINED PSLIB_BRPC_URL
))
IF
((
NOT DEFINED PSLIB_BRPC_
NAME
)
OR
(
NOT DEFINED PSLIB_BRPC_URL
))
MESSAGE
(
STATUS
"use pre defined download url"
)
MESSAGE
(
STATUS
"use pre defined download url"
)
SET
(
PSLIB_BRPC_VER
"pslib_brpc"
CACHE STRING
""
FORCE
)
#todo pslib version
SET
(
PSLIB_BRPC_VER
"0.1.0"
CACHE STRING
""
FORCE
)
SET
(
PSLIB_BRPC_URL
"http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/
${
PSLIB_BRPC_VER
}
.tar.gz"
CACHE STRING
""
FORCE
)
#todo pslib_brpc url
SET
(
PSLIB_BRPC_NAME
"pslib_brpc"
CACHE STRING
""
FORCE
)
SET
(
PSLIB_BRPC_URL
"https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/
${
PSLIB_BRPC_VER
}
/
${
PSLIB_BRPC_NAME
}
.tar.gz"
CACHE STRING
""
FORCE
)
ENDIF
()
ENDIF
()
MESSAGE
(
STATUS
"PSLIB_BRPC_
VER:
${
PSLIB_BRPC_VER
}
, PSLIB_BRPC_URL:
${
PSLIB_BRPC_URL
}
"
)
MESSAGE
(
STATUS
"PSLIB_BRPC_
NAME:
${
PSLIB_BRPC_NAME
}
, PSLIB_BRPC_URL:
${
PSLIB_BRPC_URL
}
"
)
SET
(
PSLIB_BRPC_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/pslib_brpc"
)
SET
(
PSLIB_BRPC_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/pslib_brpc"
)
SET
(
PSLIB_BRPC_DOWNLOAD_DIR
"
${
PSLIB_BRPC_SOURCE_DIR
}
/src/
${
PSLIB_BRPC_PROJECT
}
"
)
SET
(
PSLIB_BRPC_DOWNLOAD_DIR
"
${
PSLIB_BRPC_SOURCE_DIR
}
/src/
${
PSLIB_BRPC_PROJECT
}
"
)
SET
(
PSLIB_BRPC_DST_DIR
"pslib_brpc"
)
SET
(
PSLIB_BRPC_DST_DIR
"pslib_brpc"
)
...
@@ -50,7 +51,7 @@ INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})
...
@@ -50,7 +51,7 @@ INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})
FILE
(
WRITE
${
PSLIB_BRPC_DOWNLOAD_DIR
}
/CMakeLists.txt
FILE
(
WRITE
${
PSLIB_BRPC_DOWNLOAD_DIR
}
/CMakeLists.txt
"PROJECT(PSLIB_BRPC)
\n
"
"PROJECT(PSLIB_BRPC)
\n
"
"cmake_minimum_required(VERSION 3.0)
\n
"
"cmake_minimum_required(VERSION 3.0)
\n
"
"install(DIRECTORY
${
PSLIB_BRPC_
VER
}
/include
${
PSLIB_BRPC_VER
}
/lib
\n
"
"install(DIRECTORY
${
PSLIB_BRPC_
NAME
}
/include
${
PSLIB_BRPC_NAME
}
/lib
\n
"
" DESTINATION
${
PSLIB_BRPC_DST_DIR
}
)
\n
"
)
" DESTINATION
${
PSLIB_BRPC_DST_DIR
}
)
\n
"
)
ExternalProject_Add
(
ExternalProject_Add
(
...
@@ -58,8 +59,8 @@ ExternalProject_Add(
...
@@ -58,8 +59,8 @@ ExternalProject_Add(
${
EXTERNAL_PROJECT_LOG_ARGS
}
${
EXTERNAL_PROJECT_LOG_ARGS
}
PREFIX
${
PSLIB_BRPC_SOURCE_DIR
}
PREFIX
${
PSLIB_BRPC_SOURCE_DIR
}
DOWNLOAD_DIR
${
PSLIB_BRPC_DOWNLOAD_DIR
}
DOWNLOAD_DIR
${
PSLIB_BRPC_DOWNLOAD_DIR
}
DOWNLOAD_COMMAND wget --no-check-certificate
${
PSLIB_BRPC_URL
}
-c -q -O
${
PSLIB_BRPC_
VER
}
.tar.gz
DOWNLOAD_COMMAND wget --no-check-certificate
${
PSLIB_BRPC_URL
}
-c -q -O
${
PSLIB_BRPC_
NAME
}
.tar.gz
&& tar zxvf
${
PSLIB_BRPC_
VER
}
.tar.gz
&& tar zxvf
${
PSLIB_BRPC_
NAME
}
.tar.gz
DOWNLOAD_NO_PROGRESS 1
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND
""
UPDATE_COMMAND
""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
PSLIB_BRPC_INSTALL_ROOT
}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=
${
PSLIB_BRPC_INSTALL_ROOT
}
...
...
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
f81957a7
...
@@ -180,7 +180,12 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
...
@@ -180,7 +180,12 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper
)
fast_threaded_ssa_graph_executor variable_helper
)
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib
)
if
(
WITH_PSLIB
)
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib
)
else
()
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper
)
endif
(
WITH_PSLIB
)
cc_test
(
data_feed_test SRCS data_feed_test.cc DEPS async_executor
)
cc_test
(
data_feed_test SRCS data_feed_test.cc DEPS async_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
f81957a7
...
@@ -29,7 +29,9 @@ limitations under the License. */
...
@@ -29,7 +29,9 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include "pslib.h"
#include "pslib.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -48,9 +50,11 @@ void AsyncExecutor::CreateThreads(
...
@@ -48,9 +50,11 @@ void AsyncExecutor::CreateThreads(
worker
->
SetDataFeed
(
reader
);
worker
->
SetDataFeed
(
reader
);
worker
->
SetFetchVarNames
(
fetch_var_names
);
worker
->
SetFetchVarNames
(
fetch_var_names
);
worker
->
BindingDataFeedMemory
();
worker
->
BindingDataFeedMemory
();
#ifdef PADDLE_WITH_PSLIB
worker
->
SetPSlibPtr
(
_pslib_ptr
);
worker
->
SetPSlibPtr
(
_pslib_ptr
);
worker
->
SetPullDenseThread
(
_pull_dense_thread
);
worker
->
SetPullDenseThread
(
_pull_dense_thread
);
worker
->
SetParamConfig
(
&
_param_config
);
worker
->
SetParamConfig
(
&
_param_config
);
#endif
}
}
void
PrepareReaders
(
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>>&
readers
,
// NOLINT
void
PrepareReaders
(
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>>&
readers
,
// NOLINT
...
@@ -64,6 +68,7 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
...
@@ -64,6 +68,7 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
readers
[
0
]
->
SetFileList
(
filelist
);
readers
[
0
]
->
SetFileList
(
filelist
);
}
}
#ifdef PADDLE_WITH_PSLIB
void
AsyncExecutor
::
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
)
{
void
AsyncExecutor
::
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
)
{
_pslib_ptr
=
_pslib_ptr
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
...
@@ -231,6 +236,7 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
...
@@ -231,6 +236,7 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
_pull_dense_thread
->
start
();
_pull_dense_thread
->
start
();
}
}
}
}
#endif
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
const
std
::
string
&
data_feed_desc_str
,
const
std
::
string
&
data_feed_desc_str
,
...
@@ -279,15 +285,21 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -279,15 +285,21 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed
// todo: should be factory method for creating datafeed
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>>
readers
;
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>>
readers
;
PrepareReaders
(
readers
,
actual_thread_num
,
data_feed_desc
,
filelist
);
PrepareReaders
(
readers
,
actual_thread_num
,
data_feed_desc
,
filelist
);
#ifdef PADDLE_WITH_PSLIB
PrepareDenseThread
(
mode
);
PrepareDenseThread
(
mode
);
#endif
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>>
workers
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>>
workers
;
workers
.
resize
(
actual_thread_num
);
workers
.
resize
(
actual_thread_num
);
for
(
auto
&
worker
:
workers
)
{
for
(
auto
&
worker
:
workers
)
{
#ifdef PADDLE_WITH_PSLIB
if
(
mode
==
"mpi"
)
{
if
(
mode
==
"mpi"
)
{
worker
.
reset
(
new
AsyncExecutorThreadWorker
);
worker
.
reset
(
new
AsyncExecutorThreadWorker
);
}
else
{
}
else
{
worker
.
reset
(
new
ExecutorThreadWorker
);
worker
.
reset
(
new
ExecutorThreadWorker
);
}
}
#else
worker
.
reset
(
new
ExecutorThreadWorker
);
#endif
}
}
// prepare thread resource here
// prepare thread resource here
...
@@ -306,9 +318,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -306,9 +318,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for
(
auto
&
th
:
threads
)
{
for
(
auto
&
th
:
threads
)
{
th
.
join
();
th
.
join
();
}
}
#ifdef PADDLE_WITH_PSLIB
if
(
mode
==
"mpi"
)
{
if
(
mode
==
"mpi"
)
{
_pull_dense_thread
->
stop
();
_pull_dense_thread
->
stop
();
}
}
#endif
root_scope_
->
DropKids
();
root_scope_
->
DropKids
();
return
;
return
;
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
f81957a7
...
@@ -64,6 +64,7 @@ class AsyncExecutor {
...
@@ -64,6 +64,7 @@ class AsyncExecutor {
const
std
::
vector
<
std
::
string
>&
fetch_names
,
const
std
::
vector
<
std
::
string
>&
fetch_names
,
const
std
::
string
&
mode
,
const
std
::
string
&
mode
,
const
bool
debug
=
false
);
const
bool
debug
=
false
);
#ifdef PADDLE_WITH_PSLIB
void
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
);
void
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
);
void
InitWorker
(
void
InitWorker
(
const
std
::
string
&
dist_desc
,
const
std
::
string
&
dist_desc
,
...
@@ -75,7 +76,7 @@ class AsyncExecutor {
...
@@ -75,7 +76,7 @@ class AsyncExecutor {
void
InitModel
();
void
InitModel
();
void
SaveModel
(
const
std
::
string
&
path
);
void
SaveModel
(
const
std
::
string
&
path
);
void
InitParamConfig
();
void
InitParamConfig
();
#endif
private:
private:
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
const
ProgramDesc
&
main_program
,
...
@@ -83,16 +84,18 @@ class AsyncExecutor {
...
@@ -83,16 +84,18 @@ class AsyncExecutor {
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
*
root_scope
,
const
int
thread_index
,
Scope
*
root_scope
,
const
int
thread_index
,
const
bool
debug
);
const
bool
debug
);
#ifdef PADDLE_WITH_PSLIB
void
PrepareDenseThread
(
const
std
::
string
&
mode
);
void
PrepareDenseThread
(
const
std
::
string
&
mode
);
#endif
public:
public:
#ifdef PADDLE_WITH_PSLIB
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
_pslib_ptr
;
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
_pslib_ptr
;
std
::
shared_ptr
<
DensePullThread
>
_pull_dense_thread
;
std
::
shared_ptr
<
DensePullThread
>
_pull_dense_thread
;
AsyncWorkerParamConfig
_param_config
;
#endif
Scope
*
root_scope_
;
Scope
*
root_scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
AsyncWorkerParamConfig
_param_config
;
private:
private:
int
actual_thread_num
;
int
actual_thread_num
;
...
...
paddle/fluid/framework/executor_thread_worker.cc
浏览文件 @
f81957a7
...
@@ -31,6 +31,7 @@ limitations under the License. */
...
@@ -31,6 +31,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
#ifdef PADDLE_WITH_PSLIB
int
DensePullThread
::
start
()
{
int
DensePullThread
::
start
()
{
_running
=
true
;
_running
=
true
;
_t
=
std
::
thread
(
&
DensePullThread
::
run
,
this
);
_t
=
std
::
thread
(
&
DensePullThread
::
run
,
this
);
...
@@ -112,6 +113,7 @@ void DensePullThread::increase_thread_version(
...
@@ -112,6 +113,7 @@ void DensePullThread::increase_thread_version(
std
::
lock_guard
<
std
::
mutex
>
lock
(
_mutex_for_version
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
_mutex_for_version
);
_training_versions
[
table_id
][
thread_id
]
++
;
_training_versions
[
table_id
][
thread_id
]
++
;
}
}
#endif
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
auto
&
block
=
program
.
Block
(
0
);
...
@@ -302,6 +304,7 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
...
@@ -302,6 +304,7 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_
=
g_scope
;
root_scope_
=
g_scope
;
}
}
#ifdef PADDLE_WITH_PSLIB
// AsyncExecutor
// AsyncExecutor
void
AsyncExecutorThreadWorker
::
TrainFiles
()
{
void
AsyncExecutorThreadWorker
::
TrainFiles
()
{
SetDevice
();
SetDevice
();
...
@@ -659,6 +662,7 @@ void AsyncExecutorThreadWorker::check_pull_push_memory(
...
@@ -659,6 +662,7 @@ void AsyncExecutorThreadWorker::check_pull_push_memory(
}
}
}
}
}
}
#endif
}
// einit_modelnd namespace framework
}
// einit_modelnd namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/executor_thread_worker.h
浏览文件 @
f81957a7
...
@@ -25,14 +25,16 @@ limitations under the License. */
...
@@ -25,14 +25,16 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_PSLIB
#include "pslib.h"
#include "pslib.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
const
static
uint32_t
MAX_FEASIGN_NUM
=
1000
*
100
*
100
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
#ifdef PADDLE_WITH_PSLIB
const
static
uint32_t
MAX_FEASIGN_NUM
=
1000
*
100
*
100
;
struct
AsyncWorkerParamConfig
{
struct
AsyncWorkerParamConfig
{
int
slot_dim
;
int
slot_dim
;
...
@@ -130,6 +132,8 @@ class DensePullThread {
...
@@ -130,6 +132,8 @@ class DensePullThread {
float
_total_batch_num
=
0
;
float
_total_batch_num
=
0
;
};
};
#endif
class
ExecutorThreadWorker
{
class
ExecutorThreadWorker
{
public:
public:
ExecutorThreadWorker
()
ExecutorThreadWorker
()
...
@@ -154,12 +158,14 @@ class ExecutorThreadWorker {
...
@@ -154,12 +158,14 @@ class ExecutorThreadWorker {
virtual
void
TrainFiles
();
virtual
void
TrainFiles
();
// set fetch variable names from python interface assigned by users
// set fetch variable names from python interface assigned by users
void
SetFetchVarNames
(
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
void
SetFetchVarNames
(
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
#ifdef PADDLE_WITH_PSLIB
virtual
void
SetPSlibPtr
(
virtual
void
SetPSlibPtr
(
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
pslib_ptr
)
{};
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
pslib_ptr
)
{};
virtual
void
SetPullDenseThread
(
virtual
void
SetPullDenseThread
(
std
::
shared_ptr
<
DensePullThread
>
dpt
)
{}
std
::
shared_ptr
<
DensePullThread
>
dpt
)
{}
virtual
void
SetParamConfig
(
virtual
void
SetParamConfig
(
AsyncWorkerParamConfig
*
param_config
)
{}
AsyncWorkerParamConfig
*
param_config
)
{}
#endif
private:
private:
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
...
@@ -188,6 +194,7 @@ class ExecutorThreadWorker {
...
@@ -188,6 +194,7 @@ class ExecutorThreadWorker {
bool
debug_
;
bool
debug_
;
};
};
#ifdef PADDLE_WITH_PSLIB
class
AsyncExecutorThreadWorker
:
public
ExecutorThreadWorker
{
class
AsyncExecutorThreadWorker
:
public
ExecutorThreadWorker
{
public:
public:
AsyncExecutorThreadWorker
()
{}
AsyncExecutorThreadWorker
()
{}
...
@@ -238,6 +245,7 @@ class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
...
@@ -238,6 +245,7 @@ class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
AsyncWorkerParamConfig
*
_param_config
;
AsyncWorkerParamConfig
*
_param_config
;
};
};
#endif
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
f81957a7
...
@@ -41,6 +41,7 @@ namespace pd = paddle::framework;
...
@@ -41,6 +41,7 @@ namespace pd = paddle::framework;
namespace
paddle
{
namespace
paddle
{
namespace
pybind
{
namespace
pybind
{
using
set_name_func
=
void
(
pd
::
DataFeedDesc
::*
)(
const
std
::
string
&
);
using
set_name_func
=
void
(
pd
::
DataFeedDesc
::*
)(
const
std
::
string
&
);
#ifdef PADDLE_WITH_PSLIB
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
.
def
(
py
::
init
([](
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
)
{
.
def
(
py
::
init
([](
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
)
{
...
@@ -56,5 +57,15 @@ void BindAsyncExecutor(py::module* m) {
...
@@ -56,5 +57,15 @@ void BindAsyncExecutor(py::module* m) {
.
def
(
"init_model"
,
&
framework
::
AsyncExecutor
::
InitModel
)
.
def
(
"init_model"
,
&
framework
::
AsyncExecutor
::
InitModel
)
.
def
(
"save_model"
,
&
framework
::
AsyncExecutor
::
SaveModel
);
.
def
(
"save_model"
,
&
framework
::
AsyncExecutor
::
SaveModel
);
}
// end BindAsyncExecutor
}
// end BindAsyncExecutor
#else
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
.
def
(
py
::
init
([](
framework
::
Scope
*
scope
,
const
platform
::
Place
&
place
)
{
return
std
::
unique_ptr
<
framework
::
AsyncExecutor
>
(
new
framework
::
AsyncExecutor
(
scope
,
place
));
}))
.
def
(
"run_from_files"
,
&
framework
::
AsyncExecutor
::
RunFromFile
)
}
// end BindAsyncExecutor
#endif
}
// end namespace pybind
}
// end namespace pybind
}
// end namespace paddle
}
// end namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录