Commit bafe287a
Authored by LiYuRio on Apr 17, 2023; committed via GitHub on Apr 17, 2023.
cherry-pick fleet executor from 2.4 (#52896)

* cherry-pick fleet executor from 2.4
* fix test case

In short: the pick teaches the fleet executor's interceptors to ship serialized tensors between micro-batch scopes (new DATA_WITH_VARS and START_LOOP message types) and teaches the Python Executor to cache per-micro-batch scopes and fetch results from each of them.
Parent: a2aa0087
Showing 11 changed files with 219 additions and 13 deletions (+219 -13).
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc  (+94 -1)
paddle/fluid/distributed/fleet_executor/compute_interceptor.h  (+3 -0)
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc  (+33 -5)
paddle/fluid/distributed/fleet_executor/cond_interceptor.h  (+5 -0)
paddle/fluid/distributed/fleet_executor/interceptor_message.proto  (+8 -0)
paddle/fluid/distributed/fleet_executor/task_node.cc  (+10 -0)
paddle/fluid/distributed/fleet_executor/task_node.h  (+11 -0)
paddle/fluid/pybind/bind_fleet_executor.cc  (+2 -0)
python/paddle/distributed/fleet/fleet_executor_utils.py  (+8 -0)
python/paddle/fluid/executor.py  (+42 -6)
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py  (+3 -1)
paddle/fluid/distributed/fleet_executor/compute_interceptor.cc

```cpp
@@ -18,6 +18,7 @@
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/jit/serializer.h"
#include "paddle/phi/core/errors.h"

namespace paddle {

@@ -45,6 +46,65 @@ void ComputeInterceptor::PrepareDeps() {
  }
}

void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) {
  int64_t scope_id = msg.scope_idx();
  PADDLE_ENFORCE_LT(scope_id,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but received scope index %ld",
                        microbatch_scopes_.size(),
                        scope_id));
  auto* scope = microbatch_scopes_[scope_id];
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  for (const auto& var_iter : msg.vars_list()) {
    const std::string& name = var_iter.name();
    auto& dev_ctx = *pool.Get(place_);
    std::istringstream ss(var_iter.stensor());
    auto* var = scope->Var(name);
    auto* tensor = var->GetMutable<phi::DenseTensor>();
    framework::DeserializeFromStream(ss, tensor, dev_ctx);

    VLOG(3) << "Set vars " << name << " with value in scope " << scope_id
            << " with dims " << tensor->dims() << " with dtype "
            << tensor->dtype();
  }
}

InterceptorMessage ComputeInterceptor::PrepareVarsMsg() {
  PADDLE_ENFORCE_LT(cur_scope_id_,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but received scope index %ld",
                        microbatch_scopes_.size(),
                        cur_scope_id_));
  auto* scope = microbatch_scopes_[cur_scope_id_];
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_WITH_VARS);
  ready_msg.set_scope_idx(cur_scope_id_);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  for (auto iter : node_->vars_to_dtype()) {
    VarList* vars = ready_msg.add_vars_list();
    const auto& var_name = iter.first;
    vars->set_name(var_name);
    std::ostringstream ss;
    auto& dev_ctx = *pool.Get(place_);
    auto* var = scope->FindVar(var_name);
    PADDLE_ENFORCE(var,
                   platform::errors::NotFound(
                       "Variable %s does not exist in scope %ld",
                       var_name,
                       cur_scope_id_));
    const auto& tensor = var->Get<phi::DenseTensor>();
    framework::SerializeToStream(ss, tensor, dev_ctx);
    vars->set_stensor(ss.str());
    VLOG(3) << "Prepare vars msg " << var_name << " with dimension "
            << tensor.dims() << " dtype " << tensor.dtype();
  }
  return ready_msg;
}

void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) {
  auto it = in_readys_.find(up_id);
  PADDLE_ENFORCE_NE(it,

@@ -105,6 +165,16 @@ bool ComputeInterceptor::IsInputReady() {
    flag = flag && (ready_size_map.at(i) != 0);
  }
  if (flag) {
    for (auto iter : scope_id_to_finish_flag_) {
      if (iter.first == i) {
        break;
      } else if (!iter.second) {
        VLOG(3) << "The previous scope is not ready, waiting for the "
                   "previous scope "
                << iter.first;
        return false;
      }
    }
    cur_scope_id_ = i;
    return true;
  } else {

@@ -214,11 +284,20 @@ void ComputeInterceptor::RunOps() {
void ComputeInterceptor::Run() {
  while (IsInputReady() && CanWriteOutput()) {
    VLOG(3) << "id=" << GetInterceptorId()
            << " ComputeInterceptor running in scope " << cur_scope_id_;

    RunOps();

    if (!scope_id_to_finish_flag_.empty()) {
      PADDLE_ENFORCE_NE(
          scope_id_to_finish_flag_.find(cur_scope_id_),
          scope_id_to_finish_flag_.end(),
          platform::errors::NotFound(
              "Can not find scope %ld in scope_id_to_finish", cur_scope_id_));
      scope_id_to_finish_flag_.erase(cur_scope_id_);
    }

    // send to downstream and increase buff used
    SendDataReadyToDownStream();
    // reply to upstream and decrease ready data

@@ -239,6 +318,20 @@ void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
            << msg.scope_idx() << " ";
    DecreaseBuff(msg.src_id());
    Run();
  } else if (msg.message_type() == DATA_WITH_VARS) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive data_with_vars " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    DecodeMsgVars(msg);
    IncreaseReady(msg.src_id(), msg.scope_idx());
    Run();
  } else if (msg.message_type() == START_LOOP) {
    VLOG(3) << "Compute interceptor " << interceptor_id_
            << " receive start_loop " << msg.src_id() << " "
            << msg.scope_idx() << " ";
    IncreaseReady(msg.src_id(), msg.scope_idx());
    scope_id_to_finish_flag_.emplace(msg.scope_idx(), false);
    Run();
  }
}
```
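The DecodeMsgVars/PrepareVarsMsg pair is the heart of the new DATA_WITH_VARS path: the sender serializes every variable listed in the task node's vars_to_dtype map into a byte string, and the receiver materializes those bytes back into tensors in the matching micro-batch scope. A minimal plain-Python sketch of that round trip, using dicts as stand-ins for InterceptorMessage and Scope (all names here are illustrative, not Paddle APIs):

```python
import io

import numpy as np


def prepare_vars_msg(scope, vars_to_dtype, scope_id):
    """Serialize each tracked variable into bytes, like PrepareVarsMsg."""
    msg = {"message_type": "DATA_WITH_VARS", "scope_idx": scope_id, "vars_list": []}
    for name in vars_to_dtype:
        buf = io.BytesIO()
        np.save(buf, scope[name])  # stand-in for framework::SerializeToStream
        msg["vars_list"].append({"name": name, "stensor": buf.getvalue()})
    return msg


def decode_msg_vars(msg, microbatch_scopes):
    """Recreate the shipped tensors in the target scope, like DecodeMsgVars."""
    scope = microbatch_scopes[msg["scope_idx"]]
    for entry in msg["vars_list"]:
        scope[entry["name"]] = np.load(io.BytesIO(entry["stensor"]))


sender_scope = {"x": np.ones((1,), dtype="float32")}
receiver_scopes = [{}]
decode_msg_vars(prepare_vars_msg(sender_scope, {"x": "float32"}, 0), receiver_scopes)
assert np.array_equal(receiver_scopes[0]["x"], sender_scope["x"])
```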
paddle/fluid/distributed/fleet_executor/compute_interceptor.h

```cpp
@@ -47,9 +47,12 @@ class ComputeInterceptor : public Interceptor {
 private:
  void PrepareDeps();
  InterceptorMessage PrepareVarsMsg();
  void DecodeMsgVars(const InterceptorMessage& msg);

  bool IsInputReady();
  bool CanWriteOutput();

  std::map<int64_t, bool> scope_id_to_finish_flag_;
};

}  // namespace distributed
```
paddle/fluid/distributed/fleet_executor/cond_interceptor.cc

```cpp
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include <algorithm>
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"

@@ -38,6 +39,8 @@ void CondInterceptor::PrepareDeps() {
  for (const auto& up : upstream) {
    if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
      normal_in_id_.insert(up.first);
    } else if (id_to_dep_type.at(up.first) == DependType::LOOP) {
      loop_id_ = up.first;
    }
  }

@@ -90,6 +93,13 @@ void CondInterceptor::SendDataReady(int64_t down_id) {
  Send(down_id, ready_msg);
}

void CondInterceptor::SendStartLoop(int64_t down_id) {
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(START_LOOP);
  ready_msg.set_scope_idx(cur_scope_id_);
  Send(down_id, ready_msg);
}

void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
  InterceptorMessage ready_msg;
  ready_msg.set_message_type(DATA_IS_USELESS);

@@ -104,18 +114,36 @@ void CondInterceptor::Compute() {
  if (cond) {
    VLOG(3) << "Loop again in scope " << cur_scope_id_;
    for (auto& down_id : normal_out_id_) {
      SendStartLoop(down_id);
    }
    ++num_of_scopes_;
  } else {
    VLOG(3) << "Finish loop in scope " << cur_scope_id_;
    SendDataReady(stop_loop_id_);
  }
}

void CondInterceptor::Run(const InterceptorMessage& msg) {
  if (msg.message_type() == DATA_IS_READY ||
      msg.message_type() == DATA_WITH_VARS) {
    if (msg.src_id() == loop_id_) {
      --num_of_scopes_;
      VLOG(3) << "Receiving loop-again message from " << msg.src_id()
              << ", waiting for " << num_of_scopes_
              << " other scopes to be ready";
      ready_scope_id_.emplace_back(msg.scope_idx());
      if (num_of_scopes_ == 0) {
        std::sort(ready_scope_id_.begin(), ready_scope_id_.end());
        for (auto scope_id : ready_scope_id_) {
          VLOG(3) << "Start a new loop in scope " << scope_id;
          cur_scope_id_ = scope_id;
          Compute();
        }
        ready_scope_id_.clear();
      }
    } else {
      cur_scope_id_ = msg.scope_idx();
      Compute();
    }
  } else if (msg.message_type() == DATA_IS_USELESS) {
    if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
      for (auto& up_id : normal_in_id_) {
```
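CondInterceptor::Run now treats the LOOP upstream as a barrier: each scope that wants another iteration decrements num_of_scopes_, and only once every scope has reported back does the interceptor replay Compute over the collected scope ids in ascending order. A rough Python sketch of that barrier logic (hypothetical class, not the Paddle API):

```python
class LoopBarrier:
    """Collect 'loop again' reports and replay them in ascending scope order."""

    def __init__(self, num_of_scopes):
        self.num_of_scopes = num_of_scopes
        self.ready_scope_id = []

    def on_loop_again(self, scope_id, compute):
        self.num_of_scopes -= 1
        self.ready_scope_id.append(scope_id)
        if self.num_of_scopes == 0:
            # mirrors the std::sort + replay in CondInterceptor::Run
            for sid in sorted(self.ready_scope_id):
                compute(sid)
            self.ready_scope_id.clear()


barrier = LoopBarrier(num_of_scopes=3)
replay_order = []
for sid in (2, 0, 1):  # reports may arrive out of order
    barrier.on_loop_again(sid, replay_order.append)
assert replay_order == [0, 1, 2]
```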
paddle/fluid/distributed/fleet_executor/cond_interceptor.h

```cpp
@@ -14,6 +14,7 @@
#pragma once
#include <iomanip>
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

@@ -37,6 +38,7 @@ class CondInterceptor final : public Interceptor {
  void Compute();
  bool GetCondResult();
  void SendDataReady(int64_t down_id);
  void SendStartLoop(int64_t down_id);
  void ReplyDataIsUseless(int64_t up_id);

  int64_t cur_scope_id_;

@@ -44,6 +46,9 @@ class CondInterceptor final : public Interceptor {
  std::set<int64_t> normal_in_id_;
  std::set<int64_t> normal_out_id_;
  int64_t stop_loop_id_;
  int64_t loop_id_;
  int64_t num_of_scopes_{0};
  std::vector<int64_t> ready_scope_id_;
};

}  // namespace distributed
```
paddle/fluid/distributed/fleet_executor/interceptor_message.proto

```protobuf
@@ -24,6 +24,13 @@ enum MessageType {
  ERR = 4;    // current Interceptor encounters error
  RESET = 5;  // reset the status
  START = 6;
  DATA_WITH_VARS = 7;
  START_LOOP = 8;
}

message VarList {
  required string name = 1;
  required string stensor = 2;
}

message InterceptorMessage {

@@ -32,6 +39,7 @@ message InterceptorMessage {
  optional MessageType message_type = 3 [ default = RESET ];
  optional bool ctrl_message = 4 [ default = false ];
  optional int64 scope_idx = 5 [ default = 0 ];
  repeated VarList vars_list = 6;
}

message InterceptorResponse { optional bool rst = 1 [ default = false ]; }
```
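On the wire, VarList pairs a variable name with its stream-serialized tensor bytes (stensor), and InterceptorMessage gains a repeated vars_list field to carry any number of them. A dataclass sketch of the message shape (illustrative only; the real types are protobuf-generated C++ classes, and the proto stores stensor as a string):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class VarList:
    name: str       # variable name in the destination scope
    stensor: bytes  # tensor bytes produced by SerializeToStream


@dataclass
class InterceptorMessage:
    message_type: str = "RESET"  # proto default = RESET
    ctrl_message: bool = False
    scope_idx: int = 0
    vars_list: List[VarList] = field(default_factory=list)


msg = InterceptorMessage(message_type="DATA_WITH_VARS", scope_idx=1)
msg.vars_list.append(VarList(name="x", stensor=b"\x00\x01"))
```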
paddle/fluid/distributed/fleet_executor/task_node.cc

```cpp
@@ -45,6 +45,16 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
          << ". And the TaskNode's max_run_time and max_slot_num will be set to 1.";
}

void TaskNode::SetVarsToDtype(
    const std::map<std::string, std::string>& vars_to_dtype) {
  vars_to_dtype_ = vars_to_dtype;
}

void TaskNode::SetVarsToShape(
    const std::map<std::string, std::vector<int64_t>>& vars_to_shape) {
  vars_to_shape_ = vars_to_shape;
}

void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) {
  program_ = program;
}
```
paddle/fluid/distributed/fleet_executor/task_node.h

```cpp
@@ -116,6 +116,15 @@ class TaskNode final {
                   int64_t buff_size = 1,
                   DependType type = DependType::NORMAL);
  std::string DebugString() const;
  void SetVarsToDtype(const std::map<std::string, std::string>& vars_to_dtype);
  const std::map<std::string, std::vector<int64_t>>& vars_to_shape() const {
    return vars_to_shape_;
  }
  const std::map<std::string, std::string>& vars_to_dtype() const {
    return vars_to_dtype_;
  }
  void SetVarsToShape(
      const std::map<std::string, std::vector<int64_t>>& vars_to_shape);

 private:
  DISABLE_COPY_AND_ASSIGN(TaskNode);

@@ -148,6 +157,8 @@ class TaskNode final {
  int64_t send_down_per_steps_{1};
  std::string type_;
  std::map<std::string, std::string> vars_to_dtype_;
  std::map<std::string, std::vector<int64_t>> vars_to_shape_;
};

}  // namespace distributed
```
paddle/fluid/pybind/bind_fleet_executor.cc

```cpp
@@ -184,6 +184,8 @@ void BindFleetExecutor(py::module* m) {
      .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
      .def("set_type", &TaskNode::SetType)
      .def("set_cond_var_name", &TaskNode::SetCondVarName)
      .def("set_vars_to_shape", &TaskNode::SetVarsToShape)
      .def("set_vars_to_dtype", &TaskNode::SetVarsToDtype)
      .def("role", &TaskNode::role)
      .def("init", [](TaskNode& self) { self.Init(); })
      .def("set_program", &TaskNode::SetProgram);
```
python/paddle/distributed/fleet/fleet_executor_utils.py

```python
@@ -33,6 +33,8 @@ class TaskNode:
        program=None,
        lazy_initialize=False,
        cond_var_name=None,
        vars_to_dtype=None,
        vars_to_shape=None,
    ):
        """
        :param rank (int): Current rank of the task node.

@@ -58,6 +60,8 @@ class TaskNode:
        self.program = program
        self.lazy_initialize = lazy_initialize
        self.cond_var_name = cond_var_name
        self.vars_to_dtype = vars_to_dtype
        self.vars_to_shape = vars_to_shape
        self.run_pre_steps = None
        self.run_at_offset = None
        self.node = None

@@ -101,6 +105,10 @@ class TaskNode:
            self.node.set_run_at_offset(self.run_at_offset)
        if self.cond_var_name:
            self.node.set_cond_var_name(self.cond_var_name)
        if self.vars_to_shape:
            self.node.set_vars_to_shape(self.vars_to_shape)
        if self.vars_to_dtype:
            self.node.set_vars_to_dtype(self.vars_to_dtype)
        for up in self.upstreams:
            self.node.add_upstream_task(up[0], up[1], up[2])
        for down in self.downstreams:
```
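With the two new keyword arguments, a pipeline task can declare up front which variables (plus their dtypes and shapes) should be shipped between micro-batch scopes; lazy_initialize defers building the underlying C++ TaskNode until these are applied. A usage sketch mirroring the updated unit test (assumes a Paddle build with fleet-executor support; the rank value is hypothetical and the exact set of required constructor arguments may differ):

```python
import paddle
from paddle.distributed.fleet.fleet_executor_utils import TaskNode

paddle.enable_static()
task = TaskNode(
    rank=0,  # illustrative value, not taken from the commit
    node_type="Compute",
    task_id=3,
    program=paddle.static.Program(),
    vars_to_dtype={'x': 'float32', 'tmp_1': 'int64'},
    vars_to_shape={'x': (1,), 'tmp_1': (1,)},
    lazy_initialize=True,
)
```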
python/paddle/fluid/executor.py

```python
@@ -963,6 +963,7 @@ class Executor:
        self.ctx_caches = dict()
        self.trainer_caches = dict()
        self.scope_caches = dict()
        self.micro_scope_cache = dict()
        self.var_caches = dict()
        self.pruned_program_caches = dict()
        p = core.Place()

@@ -1032,6 +1033,12 @@ class Executor:
    def _add_scope_cache(self, scope_cache_key, scope):
        self.scope_caches[scope_cache_key] = scope

    def _add_micro_scopes_cache(self, program_cache_key, micro_scopes: list):
        self.micro_scope_cache[program_cache_key] = micro_scopes

    def _get_micro_scopes_cache(self, program_cache_key):
        return self.micro_scope_cache.get(program_cache_key, None)

    # just for testing, will be removed later
    @lru_cache()
    def _log_force_set_program_cache(self, use_program_cache):

@@ -1467,6 +1474,7 @@ class Executor:
                feed=feed,
                fetch_list=fetch_list,
                with_standalone_executor=self._fleet_executor_with_standalone,
                return_numpy=return_numpy,
            )
            if "startup_program" in program._pipeline_opt:
                program = program._pipeline_opt["startup_program"]

@@ -2340,13 +2348,25 @@ class Executor:
        fetch_var_name="fetch",
        fetch_list=None,
        with_standalone_executor=False,
        return_numpy=True,
    ):
        cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
        cached_program = self._get_program_cache(cache_key)
        cached_scope = self._get_scope_cache(cache_key)
        micro_cached_scopes = self._get_micro_scopes_cache(cache_key)
        fleet_opt = program._pipeline_opt["fleet_opt"]
        if cached_scope is None:
            cached_scope = global_scope()
            self._add_scope_cache(cache_key, cached_scope)
        if micro_cached_scopes is None:
            micro_cached_scopes = []
            if (
                "inference_generation" in fleet_opt
                and fleet_opt["inference_generation"]
            ):
                for _ in range(int(fleet_opt["num_micro_batches"])):
                    micro_cached_scopes.append(cached_scope.new_scope())
                self._add_micro_scopes_cache(cache_key, micro_cached_scopes)
        if cached_program is None:
            assert (
                program._pipeline_opt

@@ -2424,7 +2444,7 @@ class Executor:
                program=cached_program,
                scope=cached_scope,
                fleet_opt=fleet_opt,
                micro_scope_list=micro_cached_scopes,
                with_standalone_executor=with_standalone_executor,
            )

@@ -2448,17 +2468,33 @@ class Executor:
                tensor.set(data, self.place)

        self._fleet_executor.run(cache_key)
        if "fetch_var" in fleet_opt:
            # When we speed up generation during evaluation, multiple queries
            # are generated at the same time. Each query runs in its own scope
            # so that they do not get mixed up, which means the final results
            # live in multiple scopes and each one has to be fetched.
            result_list = []
            for scope in micro_cached_scopes:
                scope_result_list = []
                for varname in fleet_opt["fetch_var"]:
                    tensor = None
                    try:
                        tensor = core.get_variable_tensor(scope, varname)
                        if return_numpy:
                            tensor = as_numpy(tensor)
                    except:
                        var = scope.find_var(varname)
                        tensor = var.get_lod_tensor_array()
                        if return_numpy:
                            tensor = as_numpy(tensor)
                        else:
                            tensor = [t for t in tensor]

                    if tensor:
                        scope_result_list.append(tensor)

                if scope_result_list:
                    result_list.append(scope_result_list)
            return result_list

        if fetch_list:
```
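The executor-side change is a memoized scope hierarchy: one root scope per program cache key, plus one child scope per micro batch when fleet_opt requests inference generation, all created once and reused across run() calls. A stripped-down sketch of that caching pattern (plain-Python stand-ins, not the Executor API):

```python
class Scope:
    """Minimal stand-in for a Paddle scope that can spawn child scopes."""

    def new_scope(self):
        return Scope()


class MicroScopeCache:
    def __init__(self):
        self.micro_scope_cache = {}

    def get_or_create(self, cache_key, root_scope, fleet_opt):
        micro_scopes = self.micro_scope_cache.get(cache_key)
        if micro_scopes is None:
            micro_scopes = []
            if fleet_opt.get("inference_generation"):
                # one fresh child scope per micro batch, created exactly once
                micro_scopes = [
                    root_scope.new_scope()
                    for _ in range(int(fleet_opt["num_micro_batches"]))
                ]
                self.micro_scope_cache[cache_key] = micro_scopes
        return micro_scopes


cache = MicroScopeCache()
opt = {"inference_generation": True, "num_micro_batches": 4}
first = cache.get_or_create("prog-0", Scope(), opt)
assert len(first) == 4 and cache.get_or_create("prog-0", Scope(), opt) is first
```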
python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py

```python
@@ -154,6 +154,8 @@ class TestFleetExecutor(unittest.TestCase):
            node_type="Compute",
            task_id=3,
            program=paddle.static.Program(),
            vars_to_dtype={'x': 'float32', 'tmp_1': 'int64'},
            vars_to_shape={'x': (1,), 'tmp_1': (1,)},
            lazy_initialize=True,
        )
        task_e = TaskNode(

@@ -205,7 +207,7 @@ class TestFleetExecutor(unittest.TestCase):
        exe = paddle.static.Executor(place)
        loader.start()
        res = exe.run(main_program)
        ref_res = np.full([1, 1], 10, dtype="float32")
        for data in res:
            np.testing.assert_allclose(data, ref_res, rtol=1e-05)
            ref_res = ref_res + 1
```