Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a5a373f4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a5a373f4
编写于
2月 25, 2020
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance ut to test more cases, test=develop
上级
a0d14b18
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
7 deletion
+15
-7
paddle/fluid/framework/parallel_executor.cc
paddle/fluid/framework/parallel_executor.cc
+3
-0
python/paddle/fluid/tests/unittests/test_parallel_executor_inference_feed_partial_data.py
...sts/test_parallel_executor_inference_feed_partial_data.py
+12
-7
未找到文件。
paddle/fluid/framework/parallel_executor.cc
浏览文件 @
a5a373f4
...
...
@@ -925,6 +925,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
platform
::
errors
::
InvalidArgument
(
"The feeded number of persistable variables should "
"not be less than non-persistable variables"
));
}
if
(
non_persistable_feed_len
!=
-
1UL
)
{
for
(
size_t
i
=
0
;
i
<
non_persistable_feed_len
;
++
i
)
{
member_
->
SetHasFeed
(
i
);
}
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor_inference_feed_partial_data.py
浏览文件 @
a5a373f4
...
...
@@ -23,15 +23,18 @@ class TestInferencePartialFeed(unittest.TestCase):
self
.
iterations
=
10
self
.
size
=
10
def
run_network
(
self
,
places
,
use_split
):
def
run_network
(
self
,
places
,
use_split
,
has_persistable
):
startup_prog
=
fluid
.
Program
()
main_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
None
,
self
.
size
],
dtype
=
'float32'
)
y
=
fluid
.
data
(
name
=
'y'
,
shape
=
[
None
,
self
.
size
],
dtype
=
'float32'
)
lr
=
fluid
.
data
(
name
=
'lr'
,
shape
=
[
1
],
dtype
=
'float32'
)
lr
.
persistable
=
True
if
has_persistable
:
lr
=
fluid
.
data
(
name
=
'lr'
,
shape
=
[
1
],
dtype
=
'float32'
)
lr
.
persistable
=
True
else
:
lr
=
fluid
.
data
(
name
=
'lr'
,
shape
=
[
None
],
dtype
=
'float32'
)
relu_x
=
fluid
.
layers
.
relu
(
x
)
relu_y
=
fluid
.
layers
.
relu
(
y
)
...
...
@@ -50,7 +53,7 @@ class TestInferencePartialFeed(unittest.TestCase):
for
place_num
in
six
.
moves
.
range
(
1
,
len
(
places
)
*
3
):
x_np
=
gen_random
([
place_num
,
self
.
size
])
y_np
=
gen_random
([
place_num
,
self
.
size
])
if
place_num
<=
len
(
places
):
if
not
lr
.
persistable
or
place_num
<=
len
(
places
):
lr_np
=
gen_random
([
place_num
])
else
:
lr_np
=
gen_random
([
1
])
...
...
@@ -64,7 +67,7 @@ class TestInferencePartialFeed(unittest.TestCase):
assert_result
(
x_np
,
relu_x_np
)
assert_result
(
y_np
,
relu_y_np
)
if
place_num
<=
len
(
places
):
if
not
lr
.
persistable
or
place_num
<=
len
(
places
):
assert_result
(
lr_np
,
relu_lr_np
)
else
:
expected_relu_lr_np
=
max
(
lr_np
[
0
],
0
)
...
...
@@ -113,8 +116,10 @@ class TestInferencePartialFeed(unittest.TestCase):
places
.
append
(
fluid
.
cuda_places
())
for
p
in
places
:
self
.
run_network
(
p
,
use_split
=
True
)
self
.
run_network
(
p
,
use_split
=
False
)
for
has_persistable
in
[
False
,
True
]:
for
use_split
in
[
False
,
True
]:
self
.
run_network
(
p
,
use_split
=
use_split
,
has_persistable
=
has_persistable
)
class
TestInferencePartialFeedUsingDataLoader
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录