提交 a5a373f4 编写于 作者: S sneaxiy

enhance ut to test more cases, test=develop

上级 a0d14b18
...@@ -925,6 +925,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -925,6 +925,9 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The feeded number of persistable variables should " "The feeded number of persistable variables should "
"not be less than non-persistable variables")); "not be less than non-persistable variables"));
}
if (non_persistable_feed_len != -1UL) {
for (size_t i = 0; i < non_persistable_feed_len; ++i) { for (size_t i = 0; i < non_persistable_feed_len; ++i) {
member_->SetHasFeed(i); member_->SetHasFeed(i);
} }
......
...@@ -23,15 +23,18 @@ class TestInferencePartialFeed(unittest.TestCase): ...@@ -23,15 +23,18 @@ class TestInferencePartialFeed(unittest.TestCase):
self.iterations = 10 self.iterations = 10
self.size = 10 self.size = 10
def run_network(self, places, use_split): def run_network(self, places, use_split, has_persistable):
startup_prog = fluid.Program() startup_prog = fluid.Program()
main_prog = fluid.Program() main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
x = fluid.data(name='x', shape=[None, self.size], dtype='float32') x = fluid.data(name='x', shape=[None, self.size], dtype='float32')
y = fluid.data(name='y', shape=[None, self.size], dtype='float32') y = fluid.data(name='y', shape=[None, self.size], dtype='float32')
lr = fluid.data(name='lr', shape=[1], dtype='float32') if has_persistable:
lr.persistable = True lr = fluid.data(name='lr', shape=[1], dtype='float32')
lr.persistable = True
else:
lr = fluid.data(name='lr', shape=[None], dtype='float32')
relu_x = fluid.layers.relu(x) relu_x = fluid.layers.relu(x)
relu_y = fluid.layers.relu(y) relu_y = fluid.layers.relu(y)
...@@ -50,7 +53,7 @@ class TestInferencePartialFeed(unittest.TestCase): ...@@ -50,7 +53,7 @@ class TestInferencePartialFeed(unittest.TestCase):
for place_num in six.moves.range(1, len(places) * 3): for place_num in six.moves.range(1, len(places) * 3):
x_np = gen_random([place_num, self.size]) x_np = gen_random([place_num, self.size])
y_np = gen_random([place_num, self.size]) y_np = gen_random([place_num, self.size])
if place_num <= len(places): if not lr.persistable or place_num <= len(places):
lr_np = gen_random([place_num]) lr_np = gen_random([place_num])
else: else:
lr_np = gen_random([1]) lr_np = gen_random([1])
...@@ -64,7 +67,7 @@ class TestInferencePartialFeed(unittest.TestCase): ...@@ -64,7 +67,7 @@ class TestInferencePartialFeed(unittest.TestCase):
assert_result(x_np, relu_x_np) assert_result(x_np, relu_x_np)
assert_result(y_np, relu_y_np) assert_result(y_np, relu_y_np)
if place_num <= len(places): if not lr.persistable or place_num <= len(places):
assert_result(lr_np, relu_lr_np) assert_result(lr_np, relu_lr_np)
else: else:
expected_relu_lr_np = max(lr_np[0], 0) expected_relu_lr_np = max(lr_np[0], 0)
...@@ -113,8 +116,10 @@ class TestInferencePartialFeed(unittest.TestCase): ...@@ -113,8 +116,10 @@ class TestInferencePartialFeed(unittest.TestCase):
places.append(fluid.cuda_places()) places.append(fluid.cuda_places())
for p in places: for p in places:
self.run_network(p, use_split=True) for has_persistable in [False, True]:
self.run_network(p, use_split=False) for use_split in [False, True]:
self.run_network(
p, use_split=use_split, has_persistable=has_persistable)
class TestInferencePartialFeedUsingDataLoader(unittest.TestCase): class TestInferencePartialFeedUsingDataLoader(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册