From 8627ef3330e7d0740ab1267ca5fdc0ee9268135d Mon Sep 17 00:00:00 2001
From: chenweihang
Date: Wed, 1 Aug 2018 03:19:57 +0000
Subject: [PATCH] refactor: simplify unittest function

---
 .../tests/unittests/test_memory_usage.py | 27 ++++++-------------
 1 file changed, 8 insertions(+), 19 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_memory_usage.py b/python/paddle/fluid/tests/unittests/test_memory_usage.py
index c1e286d3275..f9daf83652e 100644
--- a/python/paddle/fluid/tests/unittests/test_memory_usage.py
+++ b/python/paddle/fluid/tests/unittests/test_memory_usage.py
@@ -19,7 +19,7 @@ import contextlib
 import unittest
 
 
-def train_simulator(use_cuda, test_batch_size=10):
+def train_simulator(test_batch_size=10):
     if test_batch_size <= 0:
         raise ValueError("batch_size should be a positive integeral value, "
                          "but got batch_size={}".format(test_batch_size))
@@ -34,14 +34,7 @@ def train_simulator(use_cuda, test_batch_size=10):
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
     sgd_optimizer.minimize(avg_cost)
 
-    train_reader = paddle.batch(
-        paddle.reader.shuffle(
-            paddle.dataset.uci_housing.train(), buf_size=500),
-        batch_size=test_batch_size)
-
-    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-
+    # Calculate memory usage in current network config
     lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
         fluid.default_main_program(), batch_size=test_batch_size)
 
@@ -50,21 +43,17 @@
 
 
 class TestMemoryUsage(unittest.TestCase):
-    def test_cpu(self):
-        with self.program_scope_guard():
-            train_simulator(use_cuda=False)
-
-    def test_cpu_with_unit_KB(self):
+    def test_with_unit_B(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=False, test_batch_size=1000)
+            train_simulator()
 
-    def test_cpu_with_unit_MB(self):
+    def test_with_unit_KB(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=False, test_batch_size=100000)
+            train_simulator(test_batch_size=1000)
 
-    def test_cuda(self):
+    def test_with_unit_MB(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=True)
+            train_simulator(test_batch_size=100000)
 
     @contextlib.contextmanager
     def program_scope_guard(self):
-- 
GitLab
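
For context: the fluid.contrib.memory_usage call kept by this patch estimates the memory footprint of a Fluid program at graph-construction time, without launching an executor, which is why the refactor can drop the reader, place, and executor setup. The sketch below shows how the estimator can be queried on a small network, assuming the 2018-era paddle.fluid layers API; the estimate_memory helper and the 13-feature regression shapes are illustrative assumptions, not taken from the patch.

import paddle.fluid as fluid


def estimate_memory(batch_size=10):
    # Build a small regression network; the layer shapes are illustrative assumptions.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1, act=None)
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)
    fluid.optimizer.SGD(learning_rate=0.001).minimize(avg_cost)

    # Ask the estimator for the expected footprint of the default main program
    # at the given batch size; it returns a (lower, upper, unit) tuple.
    lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
        fluid.default_main_program(), batch_size=batch_size)
    print("memory usage is about %.3f - %.3f %s" %
          (lower_usage, upper_usage, unit))


if __name__ == '__main__':
    estimate_memory(batch_size=10)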