未验证 提交 9e87942a 编写于 作者: X xiaoting 提交者: GitHub

polish the yolov3 dygraph (#4343)

* remove num_workers

* fix devices_num

* fix data loader

* remove args syncbn

* fix typo
上级 a027cb59
......@@ -139,7 +139,6 @@ Train Loss
模型评估是指对训练完毕的模型评估各类性能指标。本示例采用[COCO官方评估](http://cocodataset.org/#detections-eval)
sh ./weights/download.sh
`eval.py`是评估模块的主要执行程序,调用示例如下:
......@@ -151,6 +150,8 @@ Train Loss
- 通过设置`export CUDA_VISIBLE_DEVICES=0`指定单卡GPU评估。
## 进阶使用
### 背景介绍
......
......@@ -197,6 +197,7 @@ class DataSetReader(object):
gt_scores, mixup_im, mixup_gt_boxes,
mixup_gt_labels, mixup_gt_scores)
im, gt_boxes, gt_labels, gt_scores = \
image_utils.image_augment(im, gt_boxes, gt_labels,
gt_scores, size, mean)
......@@ -325,10 +326,11 @@ def train(size=416,
place = fluid.CPUPlace()
data_loader.set_sample_list_generator(infinite_reader,places=place)
generator_out = []
for data in data_loader():
for i in data:
generator_out.append(i.numpy()[0])
yield [generator_out]
for data in data_loader():
im, gt_boxes, gt_labels, gt_scores = data[0].numpy(),data[1].numpy(),data[2].numpy(),data[3].numpy()
for i in range(batch_size):
generator_out.append([im[i],gt_boxes[i],gt_labels[i],gt_scores[i]])
yield generator_out
generator_out = []
cnt += 1
if cnt >= total_iter:
......
......@@ -67,9 +67,10 @@ def train():
os.makedirs(cfg.model_save_dir)
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
if not cfg.use_gpu:
palce = fluid.CPUPlace()
if cfg.use_gpu:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) if cfg.use_data_parallel else fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
if args.use_data_parallel:
strategy = fluid.dygraph.parallel.prepare_context()
......
......@@ -133,7 +133,6 @@ def parse_args():
add_arg('no_mixup_iter', int, 40000, "Disable mixup in last N iter.")
# TRAIN TEST INFER
add_arg('input_size', int, 608, "Image input size of YOLOv3.")
add_arg('syncbn', bool, True, "Whether to use synchronized batch normalization.")
add_arg('random_shape', bool, True, "Resize to random shape for train reader.")
add_arg('valid_thresh', float, 0.005, "Valid confidence score for NMS.")
add_arg('nms_thresh', float, 0.45, "NMS threshold.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册