diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py index 548ebd6710fa85ecbba2cc9c64904ffcfcaa95a5..2767e8b5d948da4aee1010b99f1a58435f81fae2 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py @@ -107,7 +107,11 @@ def train(use_cuda, train_program, parallel, params_dirname): event_handler=event_handler, feed_order=['pixel', 'label']) - return trainer + def _del_trainer(trainer): + del trainer + + if six.PY3: + _del_trainer(trainer) def infer(use_cuda, inference_program, parallel, params_dirname=None): @@ -131,15 +135,12 @@ def main(use_cuda, parallel): save_path = "image_classification_vgg.inference.model" os.environ['CPU_NUM'] = str(4) - trainer = train( + train( use_cuda=use_cuda, train_program=train_network, params_dirname=save_path, parallel=parallel) - if six.PY3: - del trainer - # FIXME(zcd): in the inference stage, the number of # input data is one, it is not appropriate to use parallel. if parallel and use_cuda: diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py index 1e1069d5f6ffca28352954100ca7afa1183dd3de..b7846574664f9190ce5c1fd52771f8b572425aab 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py @@ -90,7 +90,11 @@ def train(use_cuda, train_program, params_dirname, parallel): reader=train_reader, feed_order=['img', 'label']) - return trainer + def _del_trainer(trainer): + del trainer + + if six.PY3: + _del_trainer(trainer) def infer(use_cuda, inference_program, parallel, params_dirname=None): @@ -116,15 +120,12 @@ def main(use_cuda, parallel): # call train() with is_local argument to run distributed train os.environ['CPU_NUM'] = str(4) - trainer = train( + train( use_cuda=use_cuda, train_program=train_program, params_dirname=params_dirname, parallel=parallel) - if six.PY3: - del trainer - # FIXME(zcd): in the inference stage, the number of # input data is one, it is not appropriate to use parallel. if parallel and use_cuda: