From 2bb0ac927b006ed375322fbd5fe4b0cbc72f39fa Mon Sep 17 00:00:00 2001
From: minqiyang
Date: Fri, 7 Sep 2018 09:55:17 +0800
Subject: [PATCH] Polish code

---
 .../image_classification/test_image_classification_vgg.py | 8 +++++---
 .../recognize_digits/test_recognize_digits_mlp.py         | 8 +++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py
index dbd8e5a8818..548ebd6710f 100644
--- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py
+++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py
@@ -107,8 +107,7 @@ def train(use_cuda, train_program, parallel, params_dirname):
         event_handler=event_handler,
         feed_order=['pixel', 'label'])

-    if six.PY3:
-        del trainer
+    return trainer


 def infer(use_cuda, inference_program, parallel, params_dirname=None):
@@ -132,12 +131,15 @@ def main(use_cuda, parallel):
     save_path = "image_classification_vgg.inference.model"

     os.environ['CPU_NUM'] = str(4)
-    train(
+    trainer = train(
         use_cuda=use_cuda,
         train_program=train_network,
         params_dirname=save_path,
         parallel=parallel)

+    if six.PY3:
+        del trainer
+
     # FIXME(zcd): in the inference stage, the number of
     # input data is one, it is not appropriate to use parallel.
     if parallel and use_cuda:
diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py
index 2546fdbb719..1e1069d5f6f 100644
--- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py
+++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py
@@ -90,8 +90,7 @@ def train(use_cuda, train_program, params_dirname, parallel):
         reader=train_reader,
         feed_order=['img', 'label'])

-    if six.PY3:
-        del trainer
+    return trainer


 def infer(use_cuda, inference_program, parallel, params_dirname=None):
@@ -117,12 +116,15 @@ def main(use_cuda, parallel):
     # call train() with is_local argument to run distributed train
     os.environ['CPU_NUM'] = str(4)
-    train(
+    trainer = train(
         use_cuda=use_cuda,
         train_program=train_program,
         params_dirname=params_dirname,
         parallel=parallel)

+    if six.PY3:
+        del trainer
+
     # FIXME(zcd): in the inference stage, the number of
     # input data is one, it is not appropriate to use parallel.
     if parallel and use_cuda:
--
GitLab
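
Note on the pattern the patch applies: train() no longer deletes the trainer itself; it returns the trainer object, and the caller (main) decides when to drop it, doing an explicit `del` under Python 3 before the inference stage runs. The sketch below illustrates that ownership change in isolation; Trainer, train, infer, and main are stand-ins with placeholder bodies, not the real PaddlePaddle fluid high-level API.

import six  # six.PY3 is True when running under Python 3


class Trainer(object):
    """Stand-in for the high-level-API trainer object used in the tests."""

    def train(self):
        print("training...")


def train():
    trainer = Trainer()
    trainer.train()
    # Return the trainer instead of deleting it here; the caller now
    # controls when the object is released.
    return trainer


def infer():
    print("running inference...")


def main():
    trainer = train()

    # Explicitly drop the trainer before inference, as the patch does
    # under Python 3, so its resources are released before the
    # inference stage starts.
    if six.PY3:
        del trainer

    infer()


if __name__ == '__main__':
    main()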