From 31323f7911bf9cb87fb1875b036622e4d2704d80 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 19 Dec 2017 13:29:06 +0800 Subject: [PATCH] add test --- paddle/operators/get_places_op.cc | 2 +- python/paddle/v2/fluid/tests/test_layers.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/paddle/operators/get_places_op.cc b/paddle/operators/get_places_op.cc index dd937488f4..96a019ac79 100644 --- a/paddle/operators/get_places_op.cc +++ b/paddle/operators/get_places_op.cc @@ -35,7 +35,7 @@ class GetPlacesOp : public framework::OperatorBase { out_var_name); auto &places = *(out_var->GetMutable>()); - places.reserve(trainer_count); + places.resize(trainer_count); if (use_gpu) { for (int i = 0; i < trainer_count; i++) { places.emplace_back(platform::GPUPlace(i)); diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index 2286e94a90..4a03ca68ff 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -170,6 +170,12 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.sequence_expand(x=x, y=y)) print(str(program)) + def test_get_places(self): + program = Program() + with program_guard(program): + x = layers.get_places(use_gpu=True, trainer_count=4) + print(str(program)) + if __name__ == '__main__': unittest.main() -- GitLab