diff --git a/fluid/PaddleCV/face_detection/README_cn.md b/fluid/PaddleCV/face_detection/README_cn.md index 9f165f781f7c931913053fac85bf667c7b360466..80485009d24e278a00b3d21001602fbe6ef9eef6 100644 --- a/fluid/PaddleCV/face_detection/README_cn.md +++ b/fluid/PaddleCV/face_detection/README_cn.md @@ -1,8 +1,3 @@ -运行本目录下的程序示例需要使用 PaddlePaddle 最新的 develop branch 版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新 PaddlePaddle 安装版本。 - ---- - - ## Pyramidbox 人脸检测 ## Table of Contents diff --git a/fluid/PaddleCV/object_detection/.gitignore b/fluid/PaddleCV/object_detection/.gitignore index a4552fd2acc864059f0cee0d88f96c0b5bd73aa0..404af33d9659de6c2c34a755475be5d0ad5948af 100644 --- a/fluid/PaddleCV/object_detection/.gitignore +++ b/fluid/PaddleCV/object_detection/.gitignore @@ -21,3 +21,4 @@ data/pascalvoc/trainval.txt log* *.log ssd_mobilenet_v1_pascalvoc* +quant_model diff --git a/fluid/PaddleCV/object_detection/README.md b/fluid/PaddleCV/object_detection/README.md index ec93f153e085401fd9d89b257b5ba45a700db08c..651016cdffa7fe6c4fa1dc5e886b9b18e8e40b04 100644 --- a/fluid/PaddleCV/object_detection/README.md +++ b/fluid/PaddleCV/object_detection/README.md @@ -1,7 +1,3 @@ -The minimum PaddlePaddle version needed for the code sample in this directory is the latest develop branch. If you are on a version of PaddlePaddle earlier than this, [please update your installation](http://www.paddlepaddle.org/docs/develop/documentation/en/build_and_install/pip_install_en.html). - ---- - ## SSD Object Detection ## Table of Contents diff --git a/fluid/PaddleCV/object_detection/README_cn.md b/fluid/PaddleCV/object_detection/README_cn.md index 6595c05460128223296f8fdd1cddbc482812616f..99603953a9dad956bcd13e7af68c59a9ae45c9cd 100644 --- a/fluid/PaddleCV/object_detection/README_cn.md +++ b/fluid/PaddleCV/object_detection/README_cn.md @@ -1,7 +1,3 @@ -运行本目录下的程序示例需要使用 PaddlePaddle 最新的 develop branch 版本。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新 PaddlePaddle 安装版本。 - ---- - ## SSD 目标检测 ## Table of Contents diff --git a/fluid/PaddleCV/object_detection/README_quant.md b/fluid/PaddleCV/object_detection/README_quant.md new file mode 100644 index 0000000000000000000000000000000000000000..ce334179ef197334bc4473c897e3afbefa58fb61 --- /dev/null +++ b/fluid/PaddleCV/object_detection/README_quant.md @@ -0,0 +1,142 @@ +## Quantization-aware training for SSD + +### Introduction + +The quantization-aware training used in this experiments is introduced in [fixed-point quantization desigin](https://gthub.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/quantization/fixed_point_quantization.md). Since quantization-aware training is still an active area of research and experimentation, +here, we just give an simple quantization training usage in Fluid based on MobileNet-SSD model, and more other exeperiments are still needed, like how to quantization traning by considering fusing batch normalization and convolution/fully-connected layers, channel-wise quantization of weights and so on. + + +A Python transpiler is used to rewrite Fluid training program or evaluation program for quantization-aware training: + +```python + + #startup_prog = fluid.Program() + #train_prog = fluid.Program() + #loss = build_program( + # main_prog=train_prog, + # startup_prog=startup_prog, + # is_train=True) + #build_program( + # main_prog=test_prog, + # startup_prog=startup_prog, + # is_train=False) + #test_prog = test_prog.clone(for_test=True) + # above is an pseudo code + + transpiler = fluid.contrib.QuantizeTranspiler( + weight_bits=8, + activation_bits=8, + activation_quantize_type='abs_max', # or 'range_abs_max' + weight_quantize_type='abs_max') + # note, transpiler.training_transpile will rewrite train_prog + # startup_prog is needed since it needs to insert and initialize + # some state variable + transpiler.training_transpile(train_prog, startup_prog) + transpiler.training_transpile(test_prog, startup_prog) +``` + + According to above design, this transpiler inserts fake quantization and de-quantization operation for each convolution operation (including depthwise convolution operation) and fully-connected operation. These quantizations take affect on weights and activations. + + In the design, we introduce dynamic quantization and static quantization strategies for different activation quantization methods. In the expriments, when set `activation_quantize_type` to `abs_max`, it is dynamic quantization. That is to say, the quantization scale (maximum of absolute value) of activation will be calculated each mini-batch during inference. When set `activation_quantize_type` to `range_abs_max`, a quantization scale for inference period will be calculated during training. Following part will introduce how to train. + +### Quantization-aware training + + The training is fine-tuned on the well-trained MobileNet-SSD model. So download model at first: + + ```bash + wget http://paddlemodels.bj.bcebos.com/ssd_mobilenet_v1_pascalvoc.tar.gz + ``` + +- dynamic quantization: + + ```python + python main_quant.py \ + --data_dir=$PascalVOC_DIR$ \ + --mode='train' \ + --init_model=ssd_mobilenet_v1_pascalvoc \ + --act_quant_type='abs_max' \ + --epoc_num=20 \ + --learning_rate=0.0001 \ + --batch_size=64 \ + --model_save_dir=$OUTPUT_DIR$ + ``` + Since fine-tuned on a well-trained model, we use a small start learnng rate 0.0001, and train 20 epocs. + +- static quantization: + ```python + python main_quant.py \ + --data_dir=$PascalVOC_DIR$ \ + --mode='train' \ + --init_model=ssd_mobilenet_v1_pascalvoc \ + --act_quant_type='range_abs_max' \ + --epoc_num=80 \ + --learning_rate=0.001 \ + --lr_epochs=30,60 \ + --lr_decay_rates=1,0.1,0.01 \ + --batch_size=64 \ + --model_save_dir=$OUTPUT_DIR$ + ``` + Here, train 80 epocs, learning rate decays at 30 and 60 epocs by 0.1 every time. Users can adjust these hype-parameters. + +### Convert to inference model + + As described in the design documentation, the inference graph is a little different from training, the difference is the de-quantization operation is before or after conv/fc. This is equivalent in training due to linear operation of conv/fc and de-quantization and functions' commutative law. But for inference, it needs to convert the graph, `fluid.contrib.QuantizeTranspiler.freeze_program` is used to do this: + + ```python + #startup_prog = fluid.Program() + #test_prog = fluid.Program() + #test_py_reader, map_eval, nmsed_out, image = build_program( + # main_prog=test_prog, + # startup_prog=startup_prog, + # train_params=configs, + # is_train=False) + #test_prog = test_prog.clone(for_test=True) + #transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8, + # activation_bits=8, + # activation_quantize_type=act_quant_type, + # weight_quantize_type='abs_max') + #transpiler.training_transpile(test_prog, startup_prog) + #place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + #exe = fluid.Executor(place) + #exe.run(startup_prog) + + def if_exist(var): + return os.path.exists(os.path.join(init_model, var.name)) + fluid.io.load_vars(exe, init_model, main_program=test_prog, + predicate=if_exist) + # freeze the rewrited training program + # freeze after load parameters, it will quantized weights + transpiler.freeze_program(test_prog, place) + ``` + + Users can evaluate the converted model by: + + ```bash + python main_quant.py \ + --data_dir=$PascalVOC_DIR$ \ + --mode='test' \ + --init_model=$MODLE_DIR$ \ + --model_save_dir=$MobileNet_SSD_8BIT_MODEL$ + ``` + + You also can check the 8-bit model by the inference scripts + + ```bash + python main_quant.py \ + --mode='infer' \ + --init_model=$MobileNet_SSD_8BIT_MODEL$ \ + --image_path='/data/PascalVOC/VOCdevkit/VOC2007/JPEGImages/002271.jpg' + ``` + See 002271.jpg for the visualized image with bbouding boxes. + +### Results + +Results of MobileNet-v1-SSD 300x300 model on PascalVOC dataset. + +| Model | mAP | +|:---------------------------------------:|:------------------:| +|Floating point: 32bit | 73.32% | +|Fixed point: 8bit, dynamic quantization | 72.77% | +|Fixed point: 8bit, static quantization | 72.45% | + + As mentioned above, other experiments, like how to quantization traning by considering fusing batch normalization and convolution/fully-connected layers, channel-wise quantization of weights, quantizated weights type with uint8 instead of int8 and so on. diff --git a/fluid/PaddleCV/object_detection/eval.py b/fluid/PaddleCV/object_detection/eval.py index 1e8ec8860e537d0c7bd472099970d3b8c21ee78b..106fb67e073648f94934e7b17f02b964d276e5ec 100644 --- a/fluid/PaddleCV/object_detection/eval.py +++ b/fluid/PaddleCV/object_detection/eval.py @@ -77,12 +77,13 @@ def eval(args, data_args, test_list, batch_size, model_dir=None): place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(startup_prog) - # yapf: disable - if model_dir: - def if_exist(var): - return os.path.exists(os.path.join(model_dir, var.name)) - fluid.io.load_vars(exe, model_dir, main_program=test_prog, predicate=if_exist) - # yapf: enable + + def if_exist(var): + return os.path.exists(os.path.join(model_dir, var.name)) + + fluid.io.load_vars( + exe, model_dir, main_program=test_prog, predicate=if_exist) + test_reader = reader.test(data_args, test_list, batch_size=batch_size) test_py_reader.decorate_paddle_reader(test_reader) @@ -96,7 +97,7 @@ def eval(args, data_args, test_list, batch_size, model_dir=None): if batch_id % 10 == 0: print("Batch {0}, map {1}".format(batch_id, test_map)) batch_id += 1 - except fluid.core.EOFException: + except (fluid.core.EOFException, StopIteration): test_py_reader.reset() print("Test model {0}, map {1}".format(model_dir, test_map)) diff --git a/fluid/PaddleCV/object_detection/main_quant.py b/fluid/PaddleCV/object_detection/main_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2baf66ce14154ee49e47a805a67bd9867eb73b --- /dev/null +++ b/fluid/PaddleCV/object_detection/main_quant.py @@ -0,0 +1,281 @@ +import os +import time +import numpy as np +import argparse +import functools +import shutil +import math + +import paddle +import paddle.fluid as fluid +import reader +from mobilenet_ssd import mobile_net +from utility import add_arguments, print_arguments +from train import build_program +from train import train_parameters +from infer import draw_bounding_box_on_image + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('learning_rate', float, 0.0001, "Learning rate.") +add_arg('batch_size', int, 64, "Minibatch size.") +add_arg('epoc_num', int, 20, "Epoch number.") +add_arg('use_gpu', bool, True, "Whether use GPU.") +add_arg('parallel', bool, True, "Whether train in parallel on multi-devices.") +add_arg('model_save_dir', str, 'quant_model', "The path to save model.") +add_arg('init_model', str, 'ssd_mobilenet_v1_pascalvoc', "The init model path.") +add_arg('ap_version', str, '11point', "mAP version can be integral or 11point.") +add_arg('image_shape', str, '3,300,300', "Input image shape.") +add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.") +add_arg('lr_epochs', str, '30,60', "The learning decay steps.") +add_arg('lr_decay_rates', str, '1,0.1,0.01', "The learning decay rates for each step.") +add_arg('data_dir', str, 'data/pascalvoc', "Data directory") +add_arg('act_quant_type', str, 'abs_max', "Quantize type of activation, whicn can be abs_max or range_abs_max") +add_arg('image_path', str, '', "The image used to inference and visualize.") +add_arg('confs_threshold', float, 0.5, "Confidence threshold to draw bbox.") +add_arg('mode', str, 'train', "Job mode can be one of ['train', 'test', 'infer'].") +#yapf: enable + +def test(exe, test_prog, map_eval, test_py_reader): + _, accum_map = map_eval.get_map_var() + map_eval.reset(exe) + test_py_reader.start() + try: + batch = 0 + while True: + test_map, = exe.run(test_prog, fetch_list=[accum_map]) + if batch % 10 == 0: + print("Batch {0}, map {1}".format(batch, test_map)) + batch += 1 + except fluid.core.EOFException: + test_py_reader.reset() + finally: + test_py_reader.reset() + print("Test map {0}".format(test_map)) + return test_map + + +def save_model(exe, main_prog, model_save_dir, postfix): + model_path = os.path.join(model_save_dir, postfix) + if os.path.isdir(model_path): + shutil.rmtree(model_path) + fluid.io.save_persistables(exe, model_path, main_program=main_prog) + + +def train(args, + data_args, + train_params, + train_file_list, + val_file_list): + + model_save_dir = args.model_save_dir + init_model = args.init_model + epoc_num = args.epoc_num + use_gpu = args.use_gpu + parallel = args.parallel + is_shuffle = True + act_quant_type = args.act_quant_type + + if use_gpu: + devices_num = fluid.core.get_cuda_device_count() + else: + devices_num = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + batch_size = train_params['batch_size'] + batch_size_per_device = batch_size // devices_num + iters_per_epoc = train_params["train_images"] // batch_size + num_workers = 4 + + startup_prog = fluid.Program() + train_prog = fluid.Program() + test_prog = fluid.Program() + + train_py_reader, loss = build_program( + main_prog=train_prog, + startup_prog=startup_prog, + train_params=train_params, + is_train=True) + test_py_reader, map_eval, _, _ = build_program( + main_prog=test_prog, + startup_prog=startup_prog, + train_params=train_params, + is_train=False) + + test_prog = test_prog.clone(for_test=True) + + transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8, + activation_bits=8, + activation_quantize_type=act_quant_type, + weight_quantize_type='abs_max') + + transpiler.training_transpile(train_prog, startup_prog) + transpiler.training_transpile(test_prog, startup_prog) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + if init_model: + print('Load init model %s.' % init_model) + def if_exist(var): + return os.path.exists(os.path.join(init_model, var.name)) + fluid.io.load_vars(exe, init_model, main_program=train_prog, + predicate=if_exist) + else: + print('There is no init model.') + + if parallel: + train_exe = fluid.ParallelExecutor(main_program=train_prog, + use_cuda=use_gpu, loss_name=loss.name) + + train_reader = reader.train(data_args, + train_file_list, + batch_size_per_device, + shuffle=is_shuffle, + use_multiprocessing=True, + num_workers=num_workers, + max_queue=24) + test_reader = reader.test(data_args, val_file_list, batch_size) + train_py_reader.decorate_paddle_reader(train_reader) + test_py_reader.decorate_paddle_reader(test_reader) + + train_py_reader.start() + best_map = 0. + try: + for epoc in range(epoc_num): + if epoc == 0: + # test quantized model without quantization-aware training. + test_map = test(exe, test_prog, map_eval, test_py_reader) + # train + for batch in range(iters_per_epoc): + start_time = time.time() + if parallel: + outs = train_exe.run(fetch_list=[loss.name]) + else: + outs = exe.run(train_prog, fetch_list=[loss]) + end_time = time.time() + avg_loss = np.mean(np.array(outs[0])) + if batch % 20 == 0: + print("Epoc {:d}, batch {:d}, loss {:.6f}, time {:.5f}".format( + epoc , batch, avg_loss, end_time - start_time)) + end_time = time.time() + test_map = test(exe, test_prog, map_eval, test_py_reader) + save_model(exe, train_prog, model_save_dir, str(epoc)) + if test_map > best_map: + best_map = test_map + save_model(exe, train_prog, model_save_dir, 'best_map') + print("Best test map {0}".format(best_map)) + except (fluid.core.EOFException, StopIteration): + train_py_reader.reset() + + +def eval(args, data_args, configs, val_file_list): + init_model = args.init_model + use_gpu = args.use_gpu + act_quant_type = args.act_quant_type + model_save_dir = args.model_save_dir + + batch_size = configs['batch_size'] + batch_size_per_device = batch_size + + startup_prog = fluid.Program() + test_prog = fluid.Program() + test_py_reader, map_eval, nmsed_out, image = build_program( + main_prog=test_prog, + startup_prog=startup_prog, + train_params=configs, + is_train=False) + test_prog = test_prog.clone(for_test=True) + + transpiler = fluid.contrib.QuantizeTranspiler(weight_bits=8, + activation_bits=8, + activation_quantize_type=act_quant_type, + weight_quantize_type='abs_max') + transpiler.training_transpile(test_prog, startup_prog) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + def if_exist(var): + return os.path.exists(os.path.join(init_model, var.name)) + fluid.io.load_vars(exe, init_model, main_program=test_prog, + predicate=if_exist) + + # freeze after load parameters + transpiler.freeze_program(test_prog, place) + + test_reader = reader.test(data_args, val_file_list, batch_size) + test_py_reader.decorate_paddle_reader(test_reader) + + test_map = test(exe, test_prog, map_eval, test_py_reader) + print("Test model {0}, map {1}".format(init_model, test_map)) + fluid.io.save_inference_model(model_save_dir, [image.name], + [nmsed_out], exe, test_prog) + + +def infer(args, data_args): + model_dir = args.init_model + image_path = args.image_path + confs_threshold = args.confs_threshold + voc_labels = 'data/pascalvoc/label_list' + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + [inference_program, feed , fetch] = fluid.io.load_inference_model( + dirname=model_dir, + executor=exe, + model_filename='__model__') + + print(np.array(fluid.global_scope().find_var('conv2d_20.w_0').get_tensor())) + print(np.max(np.array(fluid.global_scope().find_var('conv2d_20.w_0').get_tensor()))) + infer_reader = reader.infer(data_args, image_path) + data = infer_reader() + data = data.reshape((1,) + data.shape) + outs = exe.run(inference_program, + feed={feed[0]: data}, + fetch_list=fetch, + return_numpy=False) + out = np.array(outs[0]) + draw_bounding_box_on_image(image_path, out, confs_threshold, voc_labels) + + +if __name__ == '__main__': + args = parser.parse_args() + print_arguments(args) + + # for pascalvoc + label_file = 'label_list' + train_list = 'trainval.txt' + val_list = 'test.txt' + dataset = 'pascalvoc' + + mean_BGR = [float(m) for m in args.mean_BGR.split(",")] + image_shape = [int(m) for m in args.image_shape.split(",")] + lr_epochs = [int(m) for m in args.lr_epochs.split(",")] + lr_rates = [float(m) for m in args.lr_decay_rates.split(",")] + train_parameters[dataset]['image_shape'] = image_shape + train_parameters[dataset]['batch_size'] = args.batch_size + train_parameters[dataset]['lr'] = args.learning_rate + train_parameters[dataset]['epoc_num'] = args.epoc_num + train_parameters[dataset]['ap_version'] = args.ap_version + train_parameters[dataset]['lr_epochs'] = lr_epochs + train_parameters[dataset]['lr_decay'] = lr_rates + + data_args = reader.Settings( + dataset=dataset, + data_dir=args.data_dir, + label_file=label_file, + resize_h=image_shape[1], + resize_w=image_shape[2], + mean_value=mean_BGR, + apply_distort=True, + apply_expand=True, + ap_version = args.ap_version) + if args.mode == 'train': + train(args, data_args, train_parameters[dataset], train_list, val_list) + elif args.mode == 'test': + eval(args, data_args, train_parameters[dataset], val_list) + else: + infer(args, data_args) diff --git a/fluid/PaddleCV/object_detection/train.py b/fluid/PaddleCV/object_detection/train.py index 7552c92124c0fac44f34b647f358d7e0acf3b643..2d830bcdf1d7900ca2f27055a9ec7568f75b6211 100644 --- a/fluid/PaddleCV/object_detection/train.py +++ b/fluid/PaddleCV/object_detection/train.py @@ -5,6 +5,7 @@ import argparse import functools import shutil import math +import multiprocessing import paddle import paddle.fluid as fluid @@ -16,18 +17,18 @@ parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('learning_rate', float, 0.001, "Learning rate.") -add_arg('batch_size', int, 64, "Minibatch size.") +add_arg('batch_size', int, 64, "Minibatch size of all devices.") add_arg('epoc_num', int, 120, "Epoch number.") add_arg('use_gpu', bool, True, "Whether use GPU.") -add_arg('parallel', bool, True, "Parallel.") -add_arg('dataset', str, 'pascalvoc', "coco2014, coco2017, and pascalvoc.") +add_arg('parallel', bool, True, "Whether train in parallel on multi-devices.") +add_arg('dataset', str, 'pascalvoc', "dataset can be coco2014, coco2017, and pascalvoc.") add_arg('model_save_dir', str, 'model', "The path to save model.") add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.") -add_arg('ap_version', str, '11point', "Integral, 11point.") +add_arg('ap_version', str, '11point', "mAP version can be integral or 11point.") add_arg('image_shape', str, '3,300,300', "Input image shape.") -add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.") -add_arg('data_dir', str, 'data/pascalvoc', "data directory") -add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model") +add_arg('mean_BGR', str, '127.5,127.5,127.5', "Mean value for B,G,R channel which will be subtracted.") +add_arg('data_dir', str, 'data/pascalvoc', "Data directory.") +add_arg('enable_ce', bool, False, "Whether use CE to evaluate the model.") #yapf: enable train_parameters = { @@ -81,6 +82,7 @@ def build_program(main_prog, startup_prog, train_params, is_train): image_shape = train_params['image_shape'] class_num = train_params['class_num'] ap_version = train_params['ap_version'] + outs = [] with fluid.program_guard(main_prog, startup_prog): py_reader = fluid.layers.py_reader( capacity=64, @@ -98,11 +100,12 @@ def build_program(main_prog, startup_prog, train_params, is_train): loss = fluid.layers.reduce_sum(loss) optimizer = optimizer_setting(train_params) optimizer.minimize(loss) + outs = [py_reader, loss] else: with fluid.unique_name.guard("inference"): nmsed_out = fluid.layers.detection_output( locs, confs, box, box_var, nms_threshold=0.45) - loss = fluid.evaluator.DetectionMAP( + map_eval = fluid.evaluator.DetectionMAP( nmsed_out, gt_label, gt_box, @@ -111,7 +114,9 @@ def build_program(main_prog, startup_prog, train_params, is_train): overlap_threshold=0.5, evaluate_difficult=False, ap_version=ap_version) - return py_reader, loss + # nmsed_out and image is used to save mode for inference + outs = [py_reader, map_eval, nmsed_out, image] + return outs def train(args, @@ -127,8 +132,12 @@ def train(args, enable_ce = args.enable_ce is_shuffle = True - devices = os.getenv("CUDA_VISIBLE_DEVICES") or "" - devices_num = len(devices.split(",")) + if not use_gpu: + devices_num = int(os.environ.get('CPU_NUM', + multiprocessing.cpu_count())) + else: + devices_num = fluid.core.get_cuda_device_count() + batch_size = train_params['batch_size'] epoc_num = train_params['epoc_num'] batch_size_per_device = batch_size // devices_num @@ -153,7 +162,7 @@ def train(args, startup_prog=startup_prog, train_params=train_params, is_train=True) - test_py_reader, map_eval = build_program( + test_py_reader, map_eval, _, _ = build_program( main_prog=test_prog, startup_prog=startup_prog, train_params=train_params, @@ -258,11 +267,9 @@ def train(args, print("kpis train_speed_card%s %f" % (devices_num, total_time / epoch_idx)) - except fluid.core.EOFException: - train_py_reader.reset() - except StopIteration: + except (fluid.core.EOFException, StopIteration): + train_reader().close() train_py_reader.reset() - train_py_reader.reset() if __name__ == '__main__': @@ -291,6 +298,7 @@ if __name__ == '__main__': train_parameters[dataset]['batch_size'] = args.batch_size train_parameters[dataset]['lr'] = args.learning_rate train_parameters[dataset]['epoc_num'] = args.epoc_num + train_parameters[dataset]['ap_version'] = args.ap_version data_args = reader.Settings( dataset=args.dataset,