提交 8627ef33 编写于 作者: C chenweihang

refactor: simplify unittest function

上级 964d631c
...@@ -19,7 +19,7 @@ import contextlib ...@@ -19,7 +19,7 @@ import contextlib
import unittest import unittest
def train_simulator(use_cuda, test_batch_size=10): def train_simulator(test_batch_size=10):
if test_batch_size <= 0: if test_batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, " raise ValueError("batch_size should be a positive integeral value, "
"but got batch_size={}".format(test_batch_size)) "but got batch_size={}".format(test_batch_size))
...@@ -34,14 +34,7 @@ def train_simulator(use_cuda, test_batch_size=10): ...@@ -34,14 +34,7 @@ def train_simulator(use_cuda, test_batch_size=10):
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
train_reader = paddle.batch( # Calculate memory usage in current network config
paddle.reader.shuffle(
paddle.dataset.uci_housing.train(), buf_size=500),
batch_size=test_batch_size)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
lower_usage, upper_usage, unit = fluid.contrib.memory_usage( lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
fluid.default_main_program(), batch_size=test_batch_size) fluid.default_main_program(), batch_size=test_batch_size)
...@@ -50,21 +43,17 @@ def train_simulator(use_cuda, test_batch_size=10): ...@@ -50,21 +43,17 @@ def train_simulator(use_cuda, test_batch_size=10):
class TestMemoryUsage(unittest.TestCase): class TestMemoryUsage(unittest.TestCase):
def test_cpu(self): def test_with_unit_B(self):
with self.program_scope_guard():
train_simulator(use_cuda=False)
def test_cpu_with_unit_KB(self):
with self.program_scope_guard(): with self.program_scope_guard():
train_simulator(use_cuda=False, test_batch_size=1000) train_simulator()
def test_cpu_with_unit_MB(self): def test_with_unit_KB(self):
with self.program_scope_guard(): with self.program_scope_guard():
train_simulator(use_cuda=False, test_batch_size=100000) train_simulator(test_batch_size=1000)
def test_cuda(self): def test_with_unit_MB(self):
with self.program_scope_guard(): with self.program_scope_guard():
train_simulator(use_cuda=True) train_simulator(test_batch_size=100000)
@contextlib.contextmanager @contextlib.contextmanager
def program_scope_guard(self): def program_scope_guard(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册