diff --git a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py index dbd8e5a8818b9b60e6b52307e01602f6c5b427a4..548ebd6710fa85ecbba2cc9c64904ffcfcaa95a5 100644 --- a/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/high-level-api/image_classification/test_image_classification_vgg.py @@ -107,8 +107,7 @@ def train(use_cuda, train_program, parallel, params_dirname): event_handler=event_handler, feed_order=['pixel', 'label']) - if six.PY3: - del trainer + return trainer def infer(use_cuda, inference_program, parallel, params_dirname=None): @@ -132,12 +131,15 @@ def main(use_cuda, parallel): save_path = "image_classification_vgg.inference.model" os.environ['CPU_NUM'] = str(4) - train( + trainer = train( use_cuda=use_cuda, train_program=train_network, params_dirname=save_path, parallel=parallel) + if six.PY3: + del trainer + # FIXME(zcd): in the inference stage, the number of # input data is one, it is not appropriate to use parallel. if parallel and use_cuda: diff --git a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py index 2546fdbb7196295f7972e99ff29d7645dada9a58..1e1069d5f6ffca28352954100ca7afa1183dd3de 100644 --- a/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py +++ b/python/paddle/fluid/tests/book/high-level-api/recognize_digits/test_recognize_digits_mlp.py @@ -90,8 +90,7 @@ def train(use_cuda, train_program, params_dirname, parallel): reader=train_reader, feed_order=['img', 'label']) - if six.PY3: - del trainer + return trainer def infer(use_cuda, inference_program, parallel, params_dirname=None): @@ -117,12 +116,15 @@ def main(use_cuda, parallel): # call train() with is_local argument to run distributed train os.environ['CPU_NUM'] = str(4) - train( + trainer = train( use_cuda=use_cuda, train_program=train_program, params_dirname=params_dirname, parallel=parallel) + if six.PY3: + del trainer + # FIXME(zcd): in the inference stage, the number of # input data is one, it is not appropriate to use parallel. if parallel and use_cuda: