From aa07814df30e1bf2959dcab543b08cee282e9ff0 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 4 Apr 2019 10:26:30 +0800 Subject: [PATCH] Add 3 uts test=develop --- .../fluid/tests/unittests/test_layers.py | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index fb40109bdc1..850f4df0a1c 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1909,39 +1909,34 @@ class TestBook(LayerTest): return (out) def test_kldiv_loss(self): - program = Program() - with program_guard(program): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): x = layers.data(name='x', shape=[32, 128, 128], dtype="float32") target = layers.data( name='target', shape=[32, 128, 128], dtype="float32") loss = layers.kldiv_loss(x=x, target=target, reduction='batchmean') - self.assertIsNotNone(loss) - - print(str(program)) + return (loss) def test_temporal_shift(self): - program = Program() - with program_guard(program): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") out = layers.temporal_shift(x, seg_num=4, shift_ratio=0.2) - self.assertIsNotNone(out) - print(str(program)) + return (out) def test_shuffle_channel(self): - program = Program() - with program_guard(program): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): x = layers.data(name="X", shape=[16, 4, 4], dtype="float32") out = layers.shuffle_channel(x, group=4) - self.assertIsNotNone(out) - print(str(program)) + return (out) def test_pixel_shuffle(self): - program = Program() - with program_guard(program): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): x = layers.data(name="X", shape=[9, 4, 4], dtype="float32") out = layers.pixel_shuffle(x, upscale_factor=3) - self.assertIsNotNone(out) - print(str(program)) + return (out) if __name__ == '__main__': -- GitLab