提交 ad9e6476 编写于 作者: M minqiyang

Force object deletion on trainer in unit test

上级 5af0c60f
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
from __future__ import print_function from __future__ import print_function
import six
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy import numpy
import six
import os import os
import cifar10_small_test_set import cifar10_small_test_set
...@@ -101,23 +101,14 @@ def train(use_cuda, train_program, parallel, params_dirname): ...@@ -101,23 +101,14 @@ def train(use_cuda, train_program, parallel, params_dirname):
optimizer_func=optimizer_func, optimizer_func=optimizer_func,
parallel=parallel) parallel=parallel)
if six.PY2: trainer.train(
trainer.train( reader=train_reader,
reader=train_reader, num_epochs=1,
num_epochs=1, event_handler=event_handler,
event_handler=event_handler, feed_order=['pixel', 'label'])
feed_order=['pixel', 'label'])
else: if six.PY3:
import paddle.fluid.core as core del trainer
import paddle.compat as cpt
try:
trainer.train(
reader=train_reader,
num_epochs=1,
event_handler=event_handler,
feed_order=['pixel', 'label'])
except core.EnforceNotMet as ex:
assert ("kid scope" in cpt.get_exception_message(ex))
def infer(use_cuda, inference_program, parallel, params_dirname=None): def infer(use_cuda, inference_program, parallel, params_dirname=None):
......
...@@ -84,23 +84,14 @@ def train(use_cuda, train_program, params_dirname, parallel): ...@@ -84,23 +84,14 @@ def train(use_cuda, train_program, params_dirname, parallel):
paddle.dataset.mnist.train(), buf_size=500), paddle.dataset.mnist.train(), buf_size=500),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
if six.PY2: trainer.train(
trainer.train( num_epochs=1,
num_epochs=1, event_handler=event_handler,
event_handler=event_handler, reader=train_reader,
reader=train_reader, feed_order=['img', 'label'])
feed_order=['img', 'label'])
else: if six.PY3:
import paddle.fluid.core as core del trainer
import paddle.compat as cpt
try:
trainer.train(
num_epochs=1,
event_handler=event_handler,
reader=train_reader,
feed_order=['img', 'label'])
except core.EnforceNotMet as ex:
assert ("kid scope" in cpt.get_exception_message(ex))
def infer(use_cuda, inference_program, parallel, params_dirname=None): def infer(use_cuda, inference_program, parallel, params_dirname=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册