Commit 8627ef33 authored by: chenweihang

refactor: simplify unittest function

Parent 964d631c
@@ -19,7 +19,7 @@ import contextlib
 import unittest
 
 
-def train_simulator(use_cuda, test_batch_size=10):
+def train_simulator(test_batch_size=10):
     if test_batch_size <= 0:
         raise ValueError("batch_size should be a positive integeral value, "
                          "but got batch_size={}".format(test_batch_size))
@@ -34,14 +34,7 @@ def train_simulator(use_cuda, test_batch_size=10):
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
     sgd_optimizer.minimize(avg_cost)
 
-    train_reader = paddle.batch(
-        paddle.reader.shuffle(
-            paddle.dataset.uci_housing.train(), buf_size=500),
-        batch_size=test_batch_size)
-
-    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-    exe = fluid.Executor(place)
 
     # Calculate memory usage in current network config
     lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
         fluid.default_main_program(), batch_size=test_batch_size)
@@ -50,21 +43,17 @@ def train_simulator(use_cuda, test_batch_size=10):
 
 class TestMemoryUsage(unittest.TestCase):
-    def test_cpu(self):
-        with self.program_scope_guard():
-            train_simulator(use_cuda=False)
-
-    def test_cpu_with_unit_KB(self):
+    def test_with_unit_B(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=False, test_batch_size=1000)
+            train_simulator()
 
-    def test_cpu_with_unit_MB(self):
+    def test_with_unit_KB(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=False, test_batch_size=100000)
+            train_simulator(test_batch_size=1000)
 
-    def test_cuda(self):
+    def test_with_unit_MB(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=True)
+            train_simulator(test_batch_size=100000)
 
     @contextlib.contextmanager
     def program_scope_guard(self):
......
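The removed reader, place, and Executor lines are the heart of the simplification: fluid.contrib.memory_usage estimates memory from the program description alone, so the test never needs to place or run the network, and the use_cuda flag has nothing left to control. Below is a minimal sketch of the pattern the simplified test now follows; the small fc regression network is an assumption standing in for the parts of train_simulator not shown in this diff, written against the legacy paddle.fluid API.

```python
# Sketch only: the fc / square_error_cost network here is assumed, standing in
# for the body of train_simulator() that this diff does not show.
import paddle.fluid as fluid


def estimate_memory(batch_size=10):
    # Build a small regression network in the default program
    # (13 input features, as in the uci_housing data the old reader loaded).
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1)
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)
    fluid.optimizer.SGD(learning_rate=0.001).minimize(avg_cost)

    # memory_usage inspects the program itself and returns a lower bound,
    # an upper bound, and a unit string; no Executor, reader, or
    # CUDAPlace/CPUPlace is ever created, which is why use_cuda was dropped.
    lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
        fluid.default_main_program(), batch_size=batch_size)
    print("Estimated memory usage: {} - {} {}".format(
        lower_usage, upper_usage, unit))


# Larger batch sizes push the estimate from B to KB to MB, which appears to be
# what the renamed test_with_unit_B/KB/MB cases exercise.
estimate_memory(batch_size=1000)
```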