提交 a5a373f4 编写于 作者: S sneaxiy

enhance ut to test more cases, test=develop

上级 a0d14b18
......@@ -925,6 +925,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
platform::errors::InvalidArgument(
"The feeded number of persistable variables should "
"not be less than non-persistable variables"));
}
if (non_persistable_feed_len != -1UL) {
for (size_t i = 0; i < non_persistable_feed_len; ++i) {
member_->SetHasFeed(i);
}
......
......@@ -23,15 +23,18 @@ class TestInferencePartialFeed(unittest.TestCase):
self.iterations = 10
self.size = 10
def run_network(self, places, use_split):
def run_network(self, places, use_split, has_persistable):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
if has_persistable:
lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
else:
lr = fluid.data(name='lr', shape=[None], dtype='float32')
relu_x = fluid.layers.relu(x)
relu_y = fluid.layers.relu(y)
......@@ -50,7 +53,7 @@ class TestInferencePartialFeed(unittest.TestCase):
for place_num in six.moves.range(1, len(places) * 3):
x_np = gen_random([place_num, self.size])
y_np = gen_random([place_num, self.size])
if place_num <= len(places):
if not lr.persistable or place_num <= len(places):
lr_np = gen_random([place_num])
else:
lr_np = gen_random([1])
......@@ -64,7 +67,7 @@ class TestInferencePartialFeed(unittest.TestCase):
assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np)
if place_num <= len(places):
if not lr.persistable or place_num <= len(places):
assert_result(lr_np, relu_lr_np)
else:
expected_relu_lr_np = max(lr_np[0], 0)
......@@ -113,8 +116,10 @@ class TestInferencePartialFeed(unittest.TestCase):
places.append(fluid.cuda_places())
for p in places:
self.run_network(p, use_split=True)
self.run_network(p, use_split=False)
for has_persistable in [False, True]:
for use_split in [False, True]:
self.run_network(
p, use_split=use_split, has_persistable=has_persistable)
class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册