提交 31323f79 编写于 作者: Q qijun

add test

上级 9fbd9426
...@@ -35,7 +35,7 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -35,7 +35,7 @@ class GetPlacesOp : public framework::OperatorBase {
out_var_name); out_var_name);
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>()); auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
places.reserve(trainer_count); places.resize(trainer_count);
if (use_gpu) { if (use_gpu) {
for (int i = 0; i < trainer_count; i++) { for (int i = 0; i < trainer_count; i++) {
places.emplace_back(platform::GPUPlace(i)); places.emplace_back(platform::GPUPlace(i));
......
...@@ -170,6 +170,12 @@ class TestBook(unittest.TestCase): ...@@ -170,6 +170,12 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(layers.sequence_expand(x=x, y=y)) self.assertIsNotNone(layers.sequence_expand(x=x, y=y))
print(str(program)) print(str(program))
def test_get_places(self):
program = Program()
with program_guard(program):
x = layers.get_places(use_gpu=True, trainer_count=4)
print(str(program))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册