Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1366832a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1366832a
编写于
7月 01, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add dist pass barrier
上级
5988d0c0
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
128 addition
and
47 deletion
+128
-47
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+14
-4
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+7
-2
paddle/fluid/operators/distributed/grpc_client.cc
paddle/fluid/operators/distributed/grpc_client.cc
+25
-4
paddle/fluid/operators/distributed/grpc_client.h
paddle/fluid/operators/distributed/grpc_client.h
+15
-9
paddle/fluid/operators/distributed/request_handler.h
paddle/fluid/operators/distributed/request_handler.h
+3
-0
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+18
-18
paddle/fluid/operators/distributed/rpc_client.cc
paddle/fluid/operators/distributed/rpc_client.cc
+1
-1
paddle/fluid/operators/distributed/rpc_client.h
paddle/fluid/operators/distributed/rpc_client.h
+11
-4
paddle/fluid/operators/distributed/rpc_server.cc
paddle/fluid/operators/distributed/rpc_server.cc
+19
-3
paddle/fluid/operators/distributed/rpc_server.h
paddle/fluid/operators/distributed/rpc_server.h
+7
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-1
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+6
-0
未找到文件。
paddle/fluid/framework/executor.cc
浏览文件 @
1366832a
...
@@ -48,10 +48,20 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
...
@@ -48,10 +48,20 @@ ExecutorPrepareContext::~ExecutorPrepareContext() {
Executor
::
Executor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
Executor
::
Executor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_DISTRIBUTE
void
Executor
::
Complete
()
{
void
Executor
::
BeginPass
()
{
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
auto
client
=
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
GRPCClient
>
()
::
paddle
::
operators
::
distributed
::
GRPCClient
>
();
->
SendComplete
();
client
->
SendBeginPass
();
client
->
Wait
();
}
void
Executor
::
EndPass
()
{
auto
client
=
::
paddle
::
operators
::
distributed
::
RPCClient
::
GetInstance
<
::
paddle
::
operators
::
distributed
::
GRPCClient
>
();
client
->
SendEndPass
();
client
->
Wait
();
}
}
#endif
#endif
...
...
paddle/fluid/framework/executor.h
浏览文件 @
1366832a
...
@@ -46,9 +46,14 @@ class Executor {
...
@@ -46,9 +46,14 @@ class Executor {
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_DISTRIBUTE
/*
/*
* Sending signal to pserver to mark current
trainer stop
.
* Sending signal to pserver to mark current
pass started
.
*/
*/
void
Complete
();
void
BeginPass
();
/*
* Sending signal to pserver to mark current pass finished.
*/
void
EndPass
();
#endif
#endif
/* @Brief
/* @Brief
...
...
paddle/fluid/operators/distributed/grpc_client.cc
浏览文件 @
1366832a
...
@@ -35,9 +35,17 @@ void GRPCClient::InitEventLoop() {
...
@@ -35,9 +35,17 @@ void GRPCClient::InitEventLoop() {
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
GRPCClient
::
Proceed
,
this
)));
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
GRPCClient
::
Proceed
,
this
)));
}
}
void
GRPCClient
::
Send
Complete
()
{
void
GRPCClient
::
Send
BeginPass
()
{
for
(
auto
&
it
:
channels_
)
{
for
(
auto
&
it
:
channels_
)
{
this
->
AsyncSendComplete
(
it
.
first
);
VLOG
(
3
)
<<
"send begin pass to: "
it
.
first
;
this
->
AsyncSendBeginPass
(
it
.
first
);
}
}
void
GRPCClient
::
SendEndPass
()
{
for
(
auto
&
it
:
channels_
)
{
VLOG
(
3
)
<<
"send end pass to "
<<
it
.
first
;
this
->
AsyncSendEndPass
(
it
.
first
);
}
}
}
}
...
@@ -226,19 +234,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
...
@@ -226,19 +234,32 @@ void GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
req_count_
++
;
req_count_
++
;
}
}
void
GRPCClient
::
AsyncSend
Complete
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
void
GRPCClient
::
AsyncSend
BeginPass
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
const
auto
ch
=
GetChannel
(
ep
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
BatchBarrierProcessor
*
s
=
new
BatchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
s
->
Prepare
(
time_out
);
sendrecv
::
VariableMessage
req
;
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
COMPLETE
_MESSAGE
);
req
.
set_varname
(
BEGIN_PASS
_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
req_count_
++
;
}
}
void
GRPCClient
::
AsyncSendEndPass
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
const
auto
ch
=
GetChannel
(
ep
);
FetchBarrierProcessor
*
s
=
new
FetchBarrierProcessor
(
ch
);
s
->
Prepare
(
time_out
);
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
END_PASS_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
reinterpret_cast
<
void
*>
(
s
));
req_count_
++
;
}
void
GRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
void
GRPCClient
::
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
const
std
::
string
&
dir
,
int64_t
time_out
)
{
int64_t
time_out
)
{
...
...
paddle/fluid/operators/distributed/grpc_client.h
浏览文件 @
1366832a
...
@@ -77,12 +77,13 @@ class BaseProcessor {
...
@@ -77,12 +77,13 @@ class BaseProcessor {
context_
.
reset
(
new
grpc
::
ClientContext
());
context_
.
reset
(
new
grpc
::
ClientContext
());
var_h_
=
var_info
;
var_h_
=
var_info
;
context_
->
set_wait_for_ready
(
true
);
context_
->
set_wait_for_ready
(
true
);
if
(
time_out
)
{
std
::
chrono
::
system_clock
::
time_point
deadline
=
std
::
chrono
::
system_clock
::
time_point
deadline
=
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
milliseconds
(
time_out
);
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
milliseconds
(
time_out
);
context_
->
set_deadline
(
deadline
);
context_
->
set_deadline
(
deadline
);
}
}
}
virtual
void
Prepare
(
int64_t
time_out
)
{
virtual
void
Prepare
(
int64_t
time_out
)
{
context_
.
reset
(
new
grpc
::
ClientContext
());
context_
.
reset
(
new
grpc
::
ClientContext
());
...
@@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
...
@@ -214,9 +215,17 @@ class GRPCClient : public RPCClient {
void
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
void
AsyncCheckpointNotify
(
const
std
::
string
&
ep
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncSendBeginPass
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
AsyncSendEndPass
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
void
Wait
()
override
;
void
Wait
()
override
;
void
SendComplete
()
override
;
void
SendBeginPass
()
override
;
void
SendEndPass
()
override
;
protected:
protected:
void
InitImpl
()
override
;
void
InitImpl
()
override
;
...
@@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
...
@@ -227,9 +236,6 @@ class GRPCClient : public RPCClient {
void
Proceed
();
void
Proceed
();
void
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
private:
private:
...
...
paddle/fluid/operators/distributed/request_handler.h
浏览文件 @
1366832a
...
@@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
...
@@ -37,11 +37,14 @@ constexpr char kRequestSend[] = "RequestSend";
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestGet
[]
=
"RequestGet"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
constexpr
char
kRequestPrefetch
[]
=
"RequestPrefetch"
;
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
constexpr
char
kRequestCheckpoint
[]
=
"RequestCheckpoint"
;
constexpr
char
kRequestPassBarrier
[]
=
"RequestPassBarrier"
;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define BEGIN_PASS_MESSAGE "BEGIN_PASS@RECV"
#define END_PASS_MESSAGE "END_PASS@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
1366832a
...
@@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -55,14 +55,14 @@ bool RequestSendHandler::Handle(const std::string& varname,
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv batch barrier message"
;
VLOG
(
3
)
<<
"sync: recv batch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
COMPLETE_MESSAGE
)
{
}
else
if
(
varname
==
BEGIN_PASS_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv complete message"
;
VLOG
(
3
)
<<
"sync: recv begin pass message"
;
rpc_server_
->
DecreaseClientNum
();
rpc_server_
->
WaitCond
(
kRequestSend
);
rpc_server_
->
BeginPass
();
}
else
{
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
if
(
sync_mode_
)
{
rpc_server_
->
WaitCond
(
kRequestSend
);
rpc_server_
->
WaitCond
(
kRequestSend
);
}
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
if
(
invar
==
nullptr
)
{
if
(
invar
==
nullptr
)
{
LOG
(
ERROR
)
<<
"sync: Can not find server side var: "
<<
varname
;
LOG
(
ERROR
)
<<
"sync: Can not find server side var: "
<<
varname
;
...
@@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
...
@@ -91,21 +91,21 @@ bool RequestGetHandler::Handle(const std::string& varname,
framework
::
Variable
**
outvar
,
framework
::
Variable
**
outvar
,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
)
{
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
VLOG
(
4
)
<<
"RequestGetHandler:"
<<
varname
;
if
(
varname
!=
FETCH_BARRIER_MESSAGE
)
{
if
(
sync_mode_
)
{
if
(
sync_mode_
)
{
if
(
varname
==
FETCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
else
if
(
varname
==
END_PASS_MESSAGE
)
{
rpc_server_
->
EndPass
();
}
else
{
rpc_server_
->
WaitCond
(
kRequestGet
);
rpc_server_
->
WaitCond
(
kRequestGet
);
*
outvar
=
scope_
->
FindVar
(
varname
);
}
}
}
else
{
if
(
varname
!=
FETCH_BARRIER_MESSAGE
&&
varname
!=
END_PASS_MESSAGE
)
{
*
outvar
=
scope_
->
FindVar
(
varname
);
*
outvar
=
scope_
->
FindVar
(
varname
);
return
true
;
}
}
// FETCH_BARRIER_MESSAGE
if
(
sync_mode_
)
{
VLOG
(
3
)
<<
"sync: recv fetch barrier message"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestGet
);
}
}
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/distributed/rpc_client.cc
浏览文件 @
1366832a
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
#include "gflags/gflags.h"
#include "gflags/gflags.h"
// default to 3min to avoid temprary network failures.
// default to 3min to avoid temprary network failures.
DEFINE_int32
(
rpc_deadline
,
18
0000
,
"deadline timeouts for rpc"
);
DEFINE_int32
(
rpc_deadline
,
3
0000
,
"deadline timeouts for rpc"
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/operators/distributed/rpc_client.h
浏览文件 @
1366832a
...
@@ -60,10 +60,17 @@ class RPCClient {
...
@@ -60,10 +60,17 @@ class RPCClient {
const
std
::
string
&
dir
,
const
std
::
string
&
dir
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
// SendComplete tells all the server that current trainer have no more data
virtual
void
AsyncSendBeginPass
(
const
std
::
string
&
ep
,
// to train, so that the pserver can reduce it's barrier count, and continue
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
// to train with other trainers.
virtual
void
SendComplete
()
=
0
;
virtual
void
AsyncSendEndPass
(
const
std
::
string
&
ep
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
=
0
;
// BeginePass/EndPass tells all the pserver that start/end a pass, so that
// the pserver can increase/reduce it's barrier count, and continue to train
// with other trainers.
virtual
void
SendBeginPass
()
=
0
;
virtual
void
SendEndPass
()
=
0
;
virtual
void
Wait
()
=
0
;
virtual
void
Wait
()
=
0
;
...
...
paddle/fluid/operators/distributed/rpc_server.cc
浏览文件 @
1366832a
...
@@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
...
@@ -44,7 +44,8 @@ void RPCServer::SavePort() const {
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
this
,
&
rpc_name
]
{
barrier_cond_
.
wait
(
lock
,
[
this
,
&
rpc_name
]
{
return
(
barrier_counter_
[
rpc_name
]
>=
client_num_
||
exit_flag_
.
load
());
return
((
barrier_counter_
[
rpc_name
]
==
client_num_
&&
client_num_
!=
0
)
||
exit_flag_
.
load
());
});
});
VLOG
(
3
)
<<
"batch_barrier_: "
<<
rpc_name
<<
" "
VLOG
(
3
)
<<
"batch_barrier_: "
<<
rpc_name
<<
" "
...
@@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
...
@@ -63,10 +64,25 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
}
}
}
}
void
RPCServer
::
DecreaseClientNum
()
{
void
RPCServer
::
BeginPass
()
{
VLOG
(
4
)
<<
"RPCServer begin increase pass barrier"
;
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
locl
(
mutex_
);
client_num_
++
;
VLOG
(
4
)
<<
"increase client_num to: "
<<
client_num_
;
}
barrier_cond_
.
notify_all
();
}
void
RPCServer
::
EndPass
()
{
VLOG
(
4
)
<<
"RPCServer begin increase pass barrier"
;
{
std
::
unique_lock
<
std
::
mutex
>
locl
(
mutex_
);
client_num_
--
;
client_num_
--
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
barrier_counter_
[
kRequestGet
]
--
;
}
}
}
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
}
}
...
...
paddle/fluid/operators/distributed/rpc_server.h
浏览文件 @
1366832a
...
@@ -43,6 +43,9 @@ class RPCServer {
...
@@ -43,6 +43,9 @@ class RPCServer {
bool
IsExit
()
{
return
exit_flag_
.
load
();
}
bool
IsExit
()
{
return
exit_flag_
.
load
();
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
int
GetSelectedPort
()
const
{
return
selected_port_
;
}
int
GetClientNum
()
const
;
void
SavePort
()
const
;
void
SavePort
()
const
;
// RegisterRPC, register the rpc method name to a handler
// RegisterRPC, register the rpc method name to a handler
...
@@ -60,7 +63,10 @@ class RPCServer {
...
@@ -60,7 +63,10 @@ class RPCServer {
void
SetCond
(
const
std
::
string
&
rpc_name
);
void
SetCond
(
const
std
::
string
&
rpc_name
);
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
WaitCond
(
const
std
::
string
&
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
);
void
DecreaseClientNum
();
void
BeginPass
();
void
EndPass
();
void
ResetBarrierCounter
();
void
ResetBarrierCounter
();
protected:
protected:
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
1366832a
...
@@ -493,7 +493,8 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -493,7 +493,8 @@ All parameter, weight, gradient are variables in Paddle.
py
::
class_
<
framework
::
Executor
>
(
m
,
"Executor"
)
py
::
class_
<
framework
::
Executor
>
(
m
,
"Executor"
)
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
#ifdef PADDLE_WITH_DISTRIBUTE
#ifdef PADDLE_WITH_DISTRIBUTE
.
def
(
"complete"
,
&
Executor
::
Complete
)
.
def
(
"begin_pass"
,
&
Executor
::
BeginPass
)
.
def
(
"end_pass"
,
&
Executor
::
EndPass
)
#endif
#endif
.
def
(
"run"
,
[](
Executor
&
self
,
const
ProgramDesc
&
prog
,
Scope
*
scope
,
.
def
(
"run"
,
[](
Executor
&
self
,
const
ProgramDesc
&
prog
,
Scope
*
scope
,
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
int
block_id
,
bool
create_local_scope
,
bool
create_vars
)
{
...
...
python/paddle/fluid/executor.py
浏览文件 @
1366832a
...
@@ -348,6 +348,12 @@ class Executor(object):
...
@@ -348,6 +348,12 @@ class Executor(object):
]
]
return
outs
return
outs
def
begin_pass
(
self
):
self
.
executor
.
begin_pass
()
def
end_pass
(
self
):
self
.
executor
.
end_pass
()
def
run
(
self
,
def
run
(
self
,
program
=
None
,
program
=
None
,
feed
=
None
,
feed
=
None
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录