diff --git a/.gitignore b/.gitignore index 4c39d631f8d5ceb2346373e84c837faaf9643450..d91c22d1624ea8f988ec19750f161c2c174fec5a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,9 @@ __pycache__/ # json file *.json +# log file +*.log + # Distribution / packaging /bin/ /build/ diff --git a/docs/advanced_tutorials/READER.md b/docs/advanced_tutorials/READER.md index 836ef0a5c2054e8d33b31a2293d3e1bacb929446..d809ad82a0fb30967d485f03f8e4d6cb27758710 100644 --- a/docs/advanced_tutorials/READER.md +++ b/docs/advanced_tutorials/READER.md @@ -406,7 +406,7 @@ reader = create_reader(cfg.EvalReader) # infer reader = create_reader(cfg.TestReader) # 将reader设置为DataLoader数据源 -loader.set_sample_list_generator(reader, place) +loader.set_sample_list_generator(reader) ``` 在运行程序中设置完数据处理模块后,就可以开始训练、评估与测试了,具体请参考相应运行程序python源码。 diff --git a/slim/distillation/distill.py b/slim/distillation/distill.py index dd3a86c0fcab74979baafbcc51becfa927770846..0372914bf63a75e2bde06bfb45041de3bd4a5bb2 100644 --- a/slim/distillation/distill.py +++ b/slim/distillation/distill.py @@ -160,7 +160,8 @@ def main(): start_iter = 0 train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg) - train_loader.set_sample_list_generator(train_reader, place) + # When iterable mode, set set_sample_list_generator(train_reader, place) + train_loader.set_sample_list_generator(train_reader) # get all student variables student_vars = [] @@ -183,7 +184,8 @@ def main(): eval_prog = eval_prog.clone(True) eval_reader = create_reader(cfg.EvalReader) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) # parse eval fetches extra_keys = [] diff --git a/slim/extensions/distill_pruned_model/distill_pruned_model.py b/slim/extensions/distill_pruned_model/distill_pruned_model.py index a2adc6cf7652026fb493e7b141e295454bd450e4..5733220613fde77107dd25ab99f7c5512ea22595 100644 --- a/slim/extensions/distill_pruned_model/distill_pruned_model.py +++ b/slim/extensions/distill_pruned_model/distill_pruned_model.py @@ -149,7 +149,8 @@ def main(): start_iter = 0 train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg) - train_loader.set_sample_list_generator(train_reader, place) + # When iterable mode, set set_sample_list_generator(train_reader, place) + train_loader.set_sample_list_generator(train_reader) eval_prog = fluid.Program() with fluid.program_guard(eval_prog, fluid.default_startup_program()): @@ -161,7 +162,8 @@ def main(): eval_prog = eval_prog.clone(True) eval_reader = create_reader(cfg.EvalReader) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) teacher_cfg = load_config(FLAGS.teacher_config) merge_config(FLAGS.opt) diff --git a/slim/nas/train_nas.py b/slim/nas/train_nas.py index 3c30c7ffdd1c50b2d3c24eb78de14e9e84ed3fa3..07bd2b58a0d006c7e09592799cd8ca1678ef03a4 100644 --- a/slim/nas/train_nas.py +++ b/slim/nas/train_nas.py @@ -296,7 +296,8 @@ def main(): fetches = archs(feed_vars, 'eval', cfg) eval_prog = eval_prog.clone(True) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) extra_keys = ['im_id', 'im_shape', 'gt_bbox'] eval_keys, eval_values, eval_cls = parse_fetches(fetches, eval_prog, extra_keys) @@ -324,8 +325,8 @@ def main(): exec_strategy=exec_strategy) if FLAGS.eval: compiled_eval_prog = fluid.CompiledProgram(eval_prog) - - train_loader.set_sample_list_generator(train_reader, place) + # When iterable mode, set set_sample_list_generator(train_reader, place) + train_loader.set_sample_list_generator(train_reader) train_stats = TrainingStats(cfg.log_smooth_window, train_keys) train_loader.start() diff --git a/slim/prune/eval.py b/slim/prune/eval.py index 73386d0f7026f8ec8785c57b0917773d28d80e8d..82281fea6922b1fbc954c136b382a5753701966e 100644 --- a/slim/prune/eval.py +++ b/slim/prune/eval.py @@ -78,7 +78,8 @@ def main(): exe.run(startup_prog) reader = create_reader(cfg.EvalReader) - loader.set_sample_list_generator(reader, place) + # When iterable mode, set set_sample_list_generator(reader, place) + loader.set_sample_list_generator(reader) dataset = cfg['EvalReader']['dataset'] diff --git a/slim/prune/infer.py b/slim/prune/infer.py index 7c7d825922896a970c71c6f97934639a336366e5..ed2e2a100e994f6b013b242d4a11033d2dae51c8 100644 --- a/slim/prune/infer.py +++ b/slim/prune/infer.py @@ -147,7 +147,8 @@ def main(): logger.info("pruned FLOPS: {}".format( float(base_flops - pruned_flops) / base_flops)) reader = create_reader(cfg.TestReader, devices_num=1) - loader.set_sample_list_generator(reader, place) + # When iterable mode, set set_sample_list_generator(reader, place) + loader.set_sample_list_generator(reader) exe.run(startup_prog) if cfg.weights: diff --git a/slim/prune/prune.py b/slim/prune/prune.py index 01b3488d7d80e1090fb5ddc0a3b3bc694eb94918..795905bee631d1640d0736e06e8cdc452ad0dfd7 100644 --- a/slim/prune/prune.py +++ b/slim/prune/prune.py @@ -132,7 +132,8 @@ def main(): eval_prog = eval_prog.clone(True) eval_reader = create_reader(cfg.EvalReader) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) # parse eval fetches extra_keys = [] diff --git a/slim/quantization/eval.py b/slim/quantization/eval.py index ecc77e45f85e7b6ef67714fc28c85c604485a182..cd5216a25700c2356fc206f463e2657a0249fad3 100644 --- a/slim/quantization/eval.py +++ b/slim/quantization/eval.py @@ -73,7 +73,8 @@ def main(): eval_prog = eval_prog.clone(True) reader = create_reader(cfg.EvalReader) - loader.set_sample_list_generator(reader, place) + # When iterable mode, set set_sample_list_generator(reader, place) + loader.set_sample_list_generator(reader) # eval already exists json file if FLAGS.json_eval: diff --git a/slim/quantization/infer.py b/slim/quantization/infer.py index cb16d3b95f8e6c4d51763593db97d48692476042..a27b804c2c936cfeaa028358cae171b4fcd47bc0 100644 --- a/slim/quantization/infer.py +++ b/slim/quantization/infer.py @@ -75,7 +75,8 @@ def main(): infer_prog = infer_prog.clone(True) reader = create_reader(cfg.TestReader) - loader.set_sample_list_generator(reader, place) + # When iterable mode, set set_sample_list_generator(reader, place) + loader.set_sample_list_generator(reader) not_quant_pattern = [] if FLAGS.not_quant_pattern: not_quant_pattern = FLAGS.not_quant_pattern diff --git a/slim/quantization/train.py b/slim/quantization/train.py index 63c65816b8a337c8f0c169ff260c47427d757b74..81fe62e0c9d9a182a9b26f06ea6672937b9078de 100644 --- a/slim/quantization/train.py +++ b/slim/quantization/train.py @@ -129,7 +129,8 @@ def main(): eval_prog = eval_prog.clone(True) eval_reader = create_reader(cfg.EvalReader) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) # parse eval fetches extra_keys = [] @@ -210,7 +211,8 @@ def main(): train_reader = create_reader(cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num) - train_loader.set_sample_list_generator(train_reader, place) + # When iterable mode, set set_sample_list_generator(train_reader, place) + train_loader.set_sample_list_generator(train_reader) # whether output bbox is normalized in model output layer is_bbox_normalized = False diff --git a/slim/sensitive/sensitive.py b/slim/sensitive/sensitive.py index 4ddea20150623a62f057e4645591ad9050447d27..a825f199b77277b3edaa13e6c4ad1c418855b43e 100644 --- a/slim/sensitive/sensitive.py +++ b/slim/sensitive/sensitive.py @@ -84,7 +84,8 @@ def main(): return eval_reader = create_reader(cfg.EvalReader) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) # parse eval fetches extra_keys = [] diff --git a/tools/eval.py b/tools/eval.py index 84da9b49ade12ce42bed5c1d3e2ee7053c373cb1..e7daf2bcee3c234667b0affc703ee6b02da00990 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -74,7 +74,8 @@ def main(): eval_prog = eval_prog.clone(True) reader = create_reader(cfg.EvalReader, devices_num=1) - loader.set_sample_list_generator(reader, place) + # When iterable mode, set set_sample_list_generator(reader, place) + loader.set_sample_list_generator(reader) dataset = cfg['EvalReader']['dataset'] diff --git a/tools/infer.py b/tools/infer.py index 034e33e66a4a422a70bf5d746d4c472618baa5e1..4f2ae5743f2be9d9b0c08ab2b55eaf341261dd55 100644 --- a/tools/infer.py +++ b/tools/infer.py @@ -120,7 +120,8 @@ def main(): infer_prog = infer_prog.clone(True) reader = create_reader(cfg.TestReader, devices_num=1) - loader.set_sample_list_generator(reader, place) + # When iterable mode, set set_sample_list_generator(reader, place) + loader.set_sample_list_generator(reader) exe.run(startup_prog) if cfg.weights: diff --git a/tools/train.py b/tools/train.py index dd2edbd4383524aa038347ea00997c997b980e93..0a541667dfecf122453bccc3108cf5dd04c3fba5 100644 --- a/tools/train.py +++ b/tools/train.py @@ -146,7 +146,8 @@ def main(): eval_prog = eval_prog.clone(True) eval_reader = create_reader(cfg.EvalReader, devices_num=1) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) # parse eval fetches extra_keys = [] @@ -206,7 +207,8 @@ def main(): cfg, devices_num=devices_num, num_trainers=num_trainers) - train_loader.set_sample_list_generator(train_reader, place) + # When iterable mode, set set_sample_list_generator(train_reader, place) + train_loader.set_sample_list_generator(train_reader) # whether output bbox is normalized in model output layer is_bbox_normalized = False diff --git a/tools/train_multi_machine.py b/tools/train_multi_machine.py index 2a82f0505612291fdbecae940110fe9462505825..6cfa8be7cb997dffef8b2524603a14783ce3dc50 100644 --- a/tools/train_multi_machine.py +++ b/tools/train_multi_machine.py @@ -162,7 +162,8 @@ def main(): eval_prog = eval_prog.clone(True) eval_reader = create_reader(cfg.EvalReader, devices_num=1) - eval_loader.set_sample_list_generator(eval_reader, place) + # When iterable mode, set set_sample_list_generator(eval_reader, place) + eval_loader.set_sample_list_generator(eval_reader) # parse eval fetches extra_keys = [] @@ -200,7 +201,8 @@ def main(): cfg.TrainReader, (cfg.max_iters - start_iter) * devices_num, cfg, devices_num=devices_num) - train_loader.set_sample_list_generator(train_reader, place) + # When iterable mode, set set_sample_list_generator(train_reader, place) + train_loader.set_sample_list_generator(train_reader) # whether output bbox is normalized in model output layer is_bbox_normalized = False