From 99e7dd5ef9ac2c5d2c6eb8b53700b8fbda35f1bb Mon Sep 17 00:00:00 2001
From: 贾晓
Date: Fri, 24 May 2019 19:27:00 +0800
Subject: [PATCH] add yolov3 ce (#2312)
---
PaddleCV/yolov3/.run_ce.sh | 7 ++
PaddleCV/yolov3/README.md | 5 +-
PaddleCV/yolov3/README_cn.md | 3 +-
PaddleCV/yolov3/_ce.py | 48 +++++++++
PaddleCV/yolov3/box_utils.py | 62 ++++++++----
PaddleCV/yolov3/config.py | 7 +-
PaddleCV/yolov3/eval.py | 19 ++--
PaddleCV/yolov3/image_utils.py | 90 ++++++++---------
PaddleCV/yolov3/infer.py | 14 +--
PaddleCV/yolov3/models/darknet.py | 110 +++++++++++---------
PaddleCV/yolov3/models/yolov3.py | 162 ++++++++++++++++--------------
PaddleCV/yolov3/reader.py | 85 ++++++++--------
PaddleCV/yolov3/train.py | 58 +++++++----
PaddleCV/yolov3/utility.py | 7 +-
14 files changed, 397 insertions(+), 280 deletions(-)
create mode 100644 PaddleCV/yolov3/.run_ce.sh
create mode 100644 PaddleCV/yolov3/_ce.py
diff --git a/PaddleCV/yolov3/.run_ce.sh b/PaddleCV/yolov3/.run_ce.sh
new file mode 100644
index 00000000..761e0377
--- /dev/null
+++ b/PaddleCV/yolov3/.run_ce.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+# This file is only used for continuous evaluation.
+export CUDA_VISIBLE_DEVICES=0
+python train.py --enable_ce True --use_multiprocess False --snapshot_iter 100 --max_iter 200 | python _ce.py
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python train.py --enable_ce True --use_multiprocess False --snapshot_iter 100 --max_iter 200 | python _ce.py
diff --git a/PaddleCV/yolov3/README.md b/PaddleCV/yolov3/README.md
index d2747df2..3750af11 100644
--- a/PaddleCV/yolov3/README.md
+++ b/PaddleCV/yolov3/README.md
@@ -62,7 +62,7 @@ The data catalog structure is as follows:
│ ├── 000000000139.jpg
│ ├── 000000000285.jpg
| ...
-
+
```
## Training
@@ -170,7 +170,7 @@ Inference speed(Tesla P40):
| input size | 608x608 | 416x416 | 320x320 |
|:-------------:| :-----: | :-----: | :-----: |
-| infer speed | 48 ms/frame | 29 ms/frame |24 ms/frame |
+| infer speed | 48 ms/frame | 29 ms/frame |24 ms/frame |
Visualization of infer result is shown as below:
@@ -181,4 +181,3 @@ Visualization of infer result is shown as below:
YOLOv3 Visualization Examples
-
diff --git a/PaddleCV/yolov3/README_cn.md b/PaddleCV/yolov3/README_cn.md
index 247dbc7e..51d12127 100644
--- a/PaddleCV/yolov3/README_cn.md
+++ b/PaddleCV/yolov3/README_cn.md
@@ -172,7 +172,7 @@ Train Loss
| input size | 608x608 | 416x416 | 320x320 |
|:-------------:| :-----: | :-----: | :-----: |
-| infer speed | 48 ms/frame | 29 ms/frame |24 ms/frame |
+| infer speed | 48 ms/frame | 29 ms/frame |24 ms/frame |
下图为模型可视化预测结果:
@@ -182,4 +182,3 @@ Train Loss
YOLOv3 预测可视化
-
diff --git a/PaddleCV/yolov3/_ce.py b/PaddleCV/yolov3/_ce.py
new file mode 100644
index 00000000..c0ce52df
--- /dev/null
+++ b/PaddleCV/yolov3/_ce.py
@@ -0,0 +1,48 @@
+### This file is only used for continuous evaluation test!
+from __future__ import print_function
+from __future__ import division
+from __future__ import absolute_import
+import os
+import sys
+sys.path.append(os.environ['ceroot'])
+from kpi import CostKpi
+from kpi import DurationKpi
+
+train_cost_1card_kpi = CostKpi(
+ 'train_cost_1card', 0.02, 0, actived=True, desc='train cost')
+train_duration_1card_kpi = DurationKpi(
+ 'train_duration_1card', 0.1, 0, actived=True, desc='train duration')
+train_cost_8card_kpi = CostKpi(
+ 'train_cost_8card', 0.02, 0, actived=True, desc='train cost')
+train_duration_8card_kpi = DurationKpi(
+ 'train_duration_8card', 0.1, 0, actived=True, desc='train duration')
+tracking_kpis = [
+ train_cost_1card_kpi, train_duration_1card_kpi, train_cost_8card_kpi,
+ train_duration_8card_kpi
+]
+
+
+def parse_log(log):
+ for line in log.split('\n'):
+ fs = line.strip().split('\t')
+ print(fs)
+ if len(fs) == 3 and fs[0] == 'kpis':
+ print("-----%s" % fs)
+ kpi_name = fs[1]
+ kpi_value = float(fs[2])
+ yield kpi_name, kpi_value
+
+
+def log_to_ce(log):
+ kpi_tracker = {}
+ for kpi in tracking_kpis:
+ kpi_tracker[kpi.name] = kpi
+ for (kpi_name, kpi_value) in parse_log(log):
+ print(kpi_name, kpi_value)
+ kpi_tracker[kpi_name].add_record(kpi_value)
+ kpi_tracker[kpi_name].persist()
+
+
+if __name__ == '__main__':
+ log = sys.stdin.read()
+ log_to_ce(log)
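
For reference, a minimal sketch of the log format parse_log() above expects: tab-separated "kpis\t<name>\t<value>" lines, matching the kpi prints added to train.py later in this patch. The sample loss/duration values below are hypothetical.

# Illustrative only: applies the same length/prefix check as parse_log().
sample_log = "\n".join([
    "Iter 199, lr 0.000013, loss 158.2, time 0.91",  # non-kpi line, skipped
    "kpis\ttrain_cost_1card\t158.123456",            # -> ('train_cost_1card', 158.123456)
    "kpis\ttrain_duration_1card\t0.912345",          # -> ('train_duration_1card', 0.912345)
])
for line in sample_log.split('\n'):
    fs = line.strip().split('\t')
    if len(fs) == 3 and fs[0] == 'kpis':
        print(fs[1], float(fs[2]))
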
diff --git a/PaddleCV/yolov3/box_utils.py b/PaddleCV/yolov3/box_utils.py
index 37ad5d7c..b5bc4250 100644
--- a/PaddleCV/yolov3/box_utils.py
+++ b/PaddleCV/yolov3/box_utils.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -47,6 +46,7 @@ def coco_anno_box_to_center_relative(box, img_height, img_width):
return np.array([x, y, w, h])
+
def clip_relative_box_in_image(x, y, w, h):
"""Clip relative box coordinates x, y, w, h to [0, 1]"""
x1 = max(x - w / 2, 0.)
@@ -58,6 +58,7 @@ def clip_relative_box_in_image(x, y, w, h):
w = x2 - x1
h = y2 - y1
+
def box_xywh_to_xyxy(box):
shape = box.shape
assert shape[-1] == 4, "Box shape[-1] should be 4."
@@ -68,6 +69,7 @@ def box_xywh_to_xyxy(box):
box = box.reshape(shape)
return box
+
def box_iou_xywh(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
@@ -92,6 +94,7 @@ def box_iou_xywh(box1, box2):
return inter_area / (b1_area + b2_area - inter_area)
+
def box_iou_xyxy(box1, box2):
assert box1.shape[-1] == 4, "Box1 shape[-1] should be 4."
assert box2.shape[-1] == 4, "Box2 shape[-1] should be 4."
@@ -114,17 +117,21 @@ def box_iou_xyxy(box1, box2):
return inter_area / (b1_area + b2_area - inter_area)
+
def box_crop(boxes, labels, scores, crop, img_shape):
x, y, w, h = map(float, crop)
im_w, im_h = map(float, img_shape)
boxes = boxes.copy()
- boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (boxes[:, 0] + boxes[:, 2] / 2) * im_w
- boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (boxes[:, 1] + boxes[:, 3] / 2) * im_h
+ boxes[:, 0], boxes[:, 2] = (boxes[:, 0] - boxes[:, 2] / 2) * im_w, (
+ boxes[:, 0] + boxes[:, 2] / 2) * im_w
+ boxes[:, 1], boxes[:, 3] = (boxes[:, 1] - boxes[:, 3] / 2) * im_h, (
+ boxes[:, 1] + boxes[:, 3] / 2) * im_h
crop_box = np.array([x, y, x + w, y + h])
centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
- mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(axis=1)
+ mask = np.logical_and(crop_box[:2] <= centers, centers <= crop_box[2:]).all(
+ axis=1)
boxes[:, :2] = np.maximum(boxes[:, :2], crop_box[:2])
boxes[:, 2:] = np.minimum(boxes[:, 2:], crop_box[2:])
@@ -135,12 +142,20 @@ def box_crop(boxes, labels, scores, crop, img_shape):
boxes = boxes * np.expand_dims(mask.astype('float32'), axis=1)
labels = labels * mask.astype('float32')
scores = scores * mask.astype('float32')
- boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (boxes[:, 2] - boxes[:, 0]) / w
- boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (boxes[:, 3] - boxes[:, 1]) / h
+ boxes[:, 0], boxes[:, 2] = (boxes[:, 0] + boxes[:, 2]) / 2 / w, (
+ boxes[:, 2] - boxes[:, 0]) / w
+ boxes[:, 1], boxes[:, 3] = (boxes[:, 1] + boxes[:, 3]) / 2 / h, (
+ boxes[:, 3] - boxes[:, 1]) / h
return boxes, labels, scores, mask.sum()
-def draw_boxes_on_image(image_path, boxes, scores, labels, label_names, score_thresh=0.5):
+
+def draw_boxes_on_image(image_path,
+ boxes,
+ scores,
+ labels,
+ label_names,
+ score_thresh=0.5):
image = np.array(Image.open(image_path))
plt.figure()
_, ax = plt.subplots(1)
@@ -158,22 +173,33 @@ def draw_boxes_on_image(image_path, boxes, scores, labels, label_names, score_th
if label not in colors:
colors[label] = plt.get_cmap('hsv')(label / len(label_names))
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
- rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
- fill=False, linewidth=2.0,
- edgecolor=colors[label])
+ rect = plt.Rectangle(
+ (x1, y1),
+ x2 - x1,
+ y2 - y1,
+ fill=False,
+ linewidth=2.0,
+ edgecolor=colors[label])
ax.add_patch(rect)
- ax.text(x1, y1, '{} {:.4f}'.format(label_names[label], score),
- verticalalignment='bottom', horizontalalignment='left',
- bbox={'facecolor': colors[label], 'alpha': 0.5, 'pad': 0},
- fontsize=8, color='white')
- print("\t {:15s} at {:25} score: {:.5f}".format(
- label_names[int(label)], str(list(map(int, list(box)))), score))
+ ax.text(
+ x1,
+ y1,
+ '{} {:.4f}'.format(label_names[label], score),
+ verticalalignment='bottom',
+ horizontalalignment='left',
+ bbox={'facecolor': colors[label],
+ 'alpha': 0.5,
+ 'pad': 0},
+ fontsize=8,
+ color='white')
+ print("\t {:15s} at {:25} score: {:.5f}".format(label_names[int(
+ label)], str(list(map(int, list(box)))), score))
image_name = image_name.replace('jpg', 'png')
plt.axis('off')
plt.gca().xaxis.set_major_locator(plt.NullLocator())
plt.gca().yaxis.set_major_locator(plt.NullLocator())
- plt.savefig("./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
+ plt.savefig(
+ "./output/{}".format(image_name), bbox_inches='tight', pad_inches=0.0)
print("Detect result save at ./output/{}\n".format(image_name))
plt.cla()
plt.close('all')
-
diff --git a/PaddleCV/yolov3/config.py b/PaddleCV/yolov3/config.py
index b7e1eb1c..784cffed 100644
--- a/PaddleCV/yolov3/config.py
+++ b/PaddleCV/yolov3/config.py
@@ -33,7 +33,6 @@ _C.gt_min_area = -1
# max target box number in an image
_C.max_box_num = 50
-
#
# Training options
#
@@ -53,7 +52,6 @@ _C.nms_posk = 100
# score threshold for draw box in debug mode
_C.draw_thresh = 0.5
-
#
# Model options
#
@@ -65,7 +63,9 @@ _C.pixel_means = [0.485, 0.456, 0.406]
_C.pixel_stds = [0.229, 0.224, 0.225]
# anchors box weight and height
-_C.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
+_C.anchors = [
+ 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326
+]
# anchor mask of each yolo layer
_C.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
@@ -73,7 +73,6 @@ _C.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
# IoU threshold to ignore objectness loss of pred box
_C.ignore_thresh = .7
-
#
# SOLVER options
#
diff --git a/PaddleCV/yolov3/eval.py b/PaddleCV/yolov3/eval.py
index 0381ec44..7393829f 100644
--- a/PaddleCV/yolov3/eval.py
+++ b/PaddleCV/yolov3/eval.py
@@ -64,12 +64,12 @@ def eval():
w = x2 - x1 + 1
h = y2 - y1 + 1
bbox = [x1, y1, w, h]
-
+
res = {
- 'image_id': im_id,
- 'category_id': label_ids[int(label)],
- 'bbox': list(map(float, bbox)),
- 'score': float(score)
+ 'image_id': im_id,
+ 'category_id': label_ids[int(label)],
+ 'bbox': list(map(float, bbox)),
+ 'score': float(score)
}
result.append(res)
return result
@@ -79,11 +79,10 @@ def eval():
total_time = 0
for batch_id, batch_data in enumerate(test_reader()):
start_time = time.time()
- batch_outputs = exe.run(
- fetch_list=[v.name for v in fetch_list],
- feed=feeder.feed(batch_data),
- return_numpy=False,
- use_program_cache=True)
+ batch_outputs = exe.run(fetch_list=[v.name for v in fetch_list],
+ feed=feeder.feed(batch_data),
+ return_numpy=False,
+ use_program_cache=True)
lod = batch_outputs[0].lod()[0]
nmsed_boxes = np.array(batch_outputs[0])
if nmsed_boxes.shape[1] != 6:
diff --git a/PaddleCV/yolov3/image_utils.py b/PaddleCV/yolov3/image_utils.py
index 2e713525..16edd255 100644
--- a/PaddleCV/yolov3/image_utils.py
+++ b/PaddleCV/yolov3/image_utils.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -30,46 +29,41 @@ def random_distort(img):
def random_brightness(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
-
+
def random_contrast(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
-
+
def random_color(img, lower=0.5, upper=1.5):
e = np.random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
-
+
ops = [random_brightness, random_contrast, random_color]
np.random.shuffle(ops)
-
+
img = Image.fromarray(img)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
img = np.asarray(img)
-
+
return img
-def random_crop(img,
- boxes,
- labels,
- scores,
- scales=[0.3, 1.0],
- max_ratio=2.0,
- constraints=None,
+def random_crop(img,
+ boxes,
+ labels,
+ scores,
+ scales=[0.3, 1.0],
+ max_ratio=2.0,
+ constraints=None,
max_trial=50):
if len(boxes) == 0:
return img, boxes
if not constraints:
- constraints = [
- (0.1, 1.0),
- (0.3, 1.0),
- (0.5, 1.0),
- (0.7, 1.0),
- (0.9, 1.0),
- (0.0, 1.0)]
+ constraints = [(0.1, 1.0), (0.3, 1.0), (0.5, 1.0), (0.7, 1.0),
+ (0.9, 1.0), (0.0, 1.0)]
img = Image.fromarray(img)
w, h = img.size
@@ -83,12 +77,9 @@ def random_crop(img,
crop_w = int(w * scale * np.sqrt(aspect_ratio))
crop_x = random.randrange(w - crop_w)
crop_y = random.randrange(h - crop_h)
- crop_box = np.array([[
- (crop_x + crop_w / 2.0) / w,
- (crop_y + crop_h / 2.0) / h,
- crop_w / float(w),
- crop_h /float(h)
- ]])
+ crop_box = np.array([[(crop_x + crop_w / 2.0) / w,
+ (crop_y + crop_h / 2.0) / h,
+ crop_w / float(w), crop_h / float(h)]])
iou = box_utils.box_iou_xywh(crop_box, boxes)
if min_iou <= iou.min() and max_iou >= iou.max():
@@ -101,19 +92,21 @@ def random_crop(img,
box_utils.box_crop(boxes, labels, scores, crop, (w, h))
if box_num < 1:
continue
- img = img.crop((crop[0], crop[1], crop[0] + crop[2],
+ img = img.crop((crop[0], crop[1], crop[0] + crop[2],
crop[1] + crop[3])).resize(img.size, Image.LANCZOS)
img = np.asarray(img)
return img, crop_boxes, crop_labels, crop_scores
img = np.asarray(img)
return img, boxes, labels, scores
+
def random_flip(img, gtboxes, thresh=0.5):
if random.random() > thresh:
img = img[:, ::-1, :]
gtboxes[:, 0] = 1.0 - gtboxes[:, 0]
return img, gtboxes
+
def random_interp(img, size, interp=None):
interp_method = [
cv2.INTER_NEAREST,
@@ -121,28 +114,29 @@ def random_interp(img, size, interp=None):
cv2.INTER_AREA,
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
- ]
+ ]
if not interp or interp not in interp_method:
interp = interp_method[random.randint(0, len(interp_method) - 1)]
h, w, _ = img.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
- img = cv2.resize(img, None, None, fx=im_scale_x, fy=im_scale_y,
- interpolation=interp)
+ img = cv2.resize(
+ img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
return img
-def random_expand(img,
- gtboxes,
- max_ratio=4.,
- fill=None,
- keep_ratio=True,
+
+def random_expand(img,
+ gtboxes,
+ max_ratio=4.,
+ fill=None,
+ keep_ratio=True,
thresh=0.5):
if random.random() > thresh:
return img, gtboxes
if max_ratio < 1.0:
return img, gtboxes
-
+
h, w, c = img.shape
ratio_x = random.uniform(1, max_ratio)
if keep_ratio:
@@ -151,15 +145,15 @@ def random_expand(img,
ratio_y = random.uniform(1, max_ratio)
oh = int(h * ratio_y)
ow = int(w * ratio_x)
- off_x = random.randint(0, ow -w)
- off_y = random.randint(0, oh -h)
+ off_x = random.randint(0, ow - w)
+ off_y = random.randint(0, oh - h)
out_img = np.zeros((oh, ow, c))
if fill and len(fill) == c:
for i in range(c):
out_img[:, :, i] = fill[i] * 255.0
- out_img[off_y: off_y + h, off_x: off_x + w, :] = img
+ out_img[off_y:off_y + h, off_x:off_x + w, :] = img
gtboxes[:, 0] = ((gtboxes[:, 0] * w) + off_x) / float(ow)
gtboxes[:, 1] = ((gtboxes[:, 1] * h) + off_y) / float(oh)
gtboxes[:, 2] = gtboxes[:, 2] / ratio_x
@@ -167,21 +161,17 @@ def random_expand(img,
return out_img.astype('uint8'), gtboxes
+
def shuffle_gtbox(gtbox, gtlabel, gtscore):
- gt = np.concatenate([gtbox, gtlabel[:, np.newaxis],
- gtscore[:, np.newaxis]], axis=1)
+ gt = np.concatenate(
+ [gtbox, gtlabel[:, np.newaxis], gtscore[:, np.newaxis]], axis=1)
idx = np.arange(gt.shape[0])
np.random.shuffle(idx)
gt = gt[idx, :]
return gt[:, :4], gt[:, 4], gt[:, 5]
-def image_mixup(img1,
- gtboxes1,
- gtlabels1,
- gtscores1,
- img2,
- gtboxes2,
- gtlabels2,
+
+def image_mixup(img1, gtboxes1, gtlabels1, gtscores1, img2, gtboxes2, gtlabels2,
gtscores2):
factor = np.random.beta(1.5, 1.5)
factor = max(0.0, min(1.0, factor))
@@ -229,7 +219,8 @@ def image_mixup(img1,
gtscores[:gt_num] = gtscores_all[:gt_num]
return img.astype('uint8'), gtboxes, gtlabels, gtscores
-def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
+
+def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
img = random_distort(img)
img, gtboxes = random_expand(img, gtboxes, fill=means)
img, gtboxes, gtlabels, gtscores = \
@@ -240,4 +231,3 @@ def image_augment(img, gtboxes, gtlabels, gtscores, size, means=None):
return img.astype('float32'), gtboxes.astype('float32'), \
gtlabels.astype('int32'), gtscores.astype('float32')
-
diff --git a/PaddleCV/yolov3/infer.py b/PaddleCV/yolov3/infer.py
index 58615ccf..5520efff 100644
--- a/PaddleCV/yolov3/infer.py
+++ b/PaddleCV/yolov3/infer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import os
import time
import numpy as np
@@ -54,14 +53,14 @@ def infer():
if image_name.split('.')[-1] in ['jpg', 'png']:
image_names.append(image_name)
for image_name in image_names:
- infer_reader = reader.infer(input_size, os.path.join(cfg.image_path, image_name))
+ infer_reader = reader.infer(input_size,
+ os.path.join(cfg.image_path, image_name))
label_names, _ = reader.get_label_infos()
data = next(infer_reader())
im_shape = data[0][2]
- outputs = exe.run(
- fetch_list=[v.name for v in fetch_list],
- feed=feeder.feed(data),
- return_numpy=False)
+ outputs = exe.run(fetch_list=[v.name for v in fetch_list],
+ feed=feeder.feed(data),
+ return_numpy=False)
bboxes = np.array(outputs[0])
if bboxes.shape[1] != 6:
print("No object found in {}".format(image_name))
@@ -71,7 +70,8 @@ def infer():
boxes = bboxes[:, 2:].astype('float32')
path = os.path.join(cfg.image_path, image_name)
- box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names, cfg.draw_thresh)
+ box_utils.draw_boxes_on_image(path, boxes, scores, labels, label_names,
+ cfg.draw_thresh)
if __name__ == '__main__':
diff --git a/PaddleCV/yolov3/models/darknet.py b/PaddleCV/yolov3/models/darknet.py
index bfce6f3b..9b9b7dd6 100644
--- a/PaddleCV/yolov3/models/darknet.py
+++ b/PaddleCV/yolov3/models/darknet.py
@@ -17,6 +17,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from paddle.fluid.regularizer import L2Decay
+
def conv_bn_layer(input,
ch_out,
filter_size,
@@ -32,8 +33,9 @@ def conv_bn_layer(input,
stride=stride,
padding=padding,
act=None,
- param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
- name=name+".conv.weights"),
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.Normal(0., 0.02),
+ name=name + ".conv.weights"),
bias_attr=False)
bn_name = name + ".bn"
@@ -42,72 +44,88 @@ def conv_bn_layer(input,
act=None,
is_test=is_test,
param_attr=ParamAttr(
- initializer=fluid.initializer.Normal(0., 0.02),
- regularizer=L2Decay(0.),
- name=bn_name + '.scale'),
+ initializer=fluid.initializer.Normal(0., 0.02),
+ regularizer=L2Decay(0.),
+ name=bn_name + '.scale'),
bias_attr=ParamAttr(
- initializer=fluid.initializer.Constant(0.0),
- regularizer=L2Decay(0.),
- name=bn_name + '.offset'),
+ initializer=fluid.initializer.Constant(0.0),
+ regularizer=L2Decay(0.),
+ name=bn_name + '.offset'),
moving_mean_name=bn_name + '.mean',
moving_variance_name=bn_name + '.var')
if act == 'leaky':
out = fluid.layers.leaky_relu(x=out, alpha=0.1)
return out
-def downsample(input,
- ch_out,
- filter_size=3,
- stride=2,
- padding=1,
- is_test=True,
+
+def downsample(input,
+ ch_out,
+ filter_size=3,
+ stride=2,
+ padding=1,
+ is_test=True,
name=None):
- return conv_bn_layer(input,
- ch_out=ch_out,
- filter_size=filter_size,
- stride=stride,
- padding=padding,
- is_test=is_test,
- name=name)
+ return conv_bn_layer(
+ input,
+ ch_out=ch_out,
+ filter_size=filter_size,
+ stride=stride,
+ padding=padding,
+ is_test=is_test,
+ name=name)
+
def basicblock(input, ch_out, is_test=True, name=None):
- conv1 = conv_bn_layer(input, ch_out, 1, 1, 0,
- is_test=is_test, name=name+".0")
- conv2 = conv_bn_layer(conv1, ch_out*2, 3, 1, 1,
- is_test=is_test, name=name+".1")
+ conv1 = conv_bn_layer(
+ input, ch_out, 1, 1, 0, is_test=is_test, name=name + ".0")
+ conv2 = conv_bn_layer(
+ conv1, ch_out * 2, 3, 1, 1, is_test=is_test, name=name + ".1")
out = fluid.layers.elementwise_add(x=input, y=conv2, act=None)
return out
+
def layer_warp(block_func, input, ch_out, count, is_test=True, name=None):
- res_out = block_func(input, ch_out, is_test=is_test,
- name='{}.0'.format(name))
+ res_out = block_func(
+ input, ch_out, is_test=is_test, name='{}.0'.format(name))
for j in range(1, count):
- res_out = block_func(res_out, ch_out, is_test=is_test,
- name='{}.{}'.format(name, j))
+ res_out = block_func(
+ res_out, ch_out, is_test=is_test, name='{}.{}'.format(name, j))
return res_out
-DarkNet_cfg = {
- 53: ([1,2,8,8,4],basicblock)
-}
+
+DarkNet_cfg = {53: ([1, 2, 8, 8, 4], basicblock)}
+
def add_DarkNet53_conv_body(body_input, is_test=True):
stages, block_func = DarkNet_cfg[53]
stages = stages[0:5]
- conv1 = conv_bn_layer(body_input, ch_out=32, filter_size=3,
- stride=1, padding=1, is_test=is_test,
- name="yolo_input")
- downsample_ = downsample(conv1, ch_out=conv1.shape[1]*2,
- is_test=is_test,
- name="yolo_input.downsample")
+ conv1 = conv_bn_layer(
+ body_input,
+ ch_out=32,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ is_test=is_test,
+ name="yolo_input")
+ downsample_ = downsample(
+ conv1,
+ ch_out=conv1.shape[1] * 2,
+ is_test=is_test,
+ name="yolo_input.downsample")
blocks = []
for i, stage in enumerate(stages):
- block = layer_warp(block_func, downsample_, 32 *(2**i),
- stage, is_test=is_test,
- name="stage.{}".format(i))
+ block = layer_warp(
+ block_func,
+ downsample_,
+ 32 * (2**i),
+ stage,
+ is_test=is_test,
+ name="stage.{}".format(i))
blocks.append(block)
- if i < len(stages) - 1: # do not downsaple in the last stage
- downsample_ = downsample(block, ch_out=block.shape[1]*2,
- is_test=is_test,
- name="stage.{}.downsample".format(i))
+        if i < len(stages) - 1:  # do not downsample in the last stage
+ downsample_ = downsample(
+ block,
+ ch_out=block.shape[1] * 2,
+ is_test=is_test,
+ name="stage.{}.downsample".format(i))
return blocks[-1:-4:-1]
-
diff --git a/PaddleCV/yolov3/models/yolov3.py b/PaddleCV/yolov3/models/yolov3.py
index ef57abdd..fe491249 100644
--- a/PaddleCV/yolov3/models/yolov3.py
+++ b/PaddleCV/yolov3/models/yolov3.py
@@ -26,26 +26,48 @@ from config import cfg
from .darknet import add_DarkNet53_conv_body
from .darknet import conv_bn_layer
+
def yolo_detection_block(input, channel, is_test=True, name=None):
assert channel % 2 == 0, \
"channel {} cannot be divided by 2".format(channel)
conv = input
for j in range(2):
- conv = conv_bn_layer(conv, channel, filter_size=1,
- stride=1, padding=0, is_test=is_test,
- name='{}.{}.0'.format(name, j))
- conv = conv_bn_layer(conv, channel*2, filter_size=3,
- stride=1, padding=1, is_test=is_test,
- name='{}.{}.1'.format(name, j))
- route = conv_bn_layer(conv, channel, filter_size=1, stride=1,
- padding=0, is_test=is_test,
- name='{}.2'.format(name))
- tip = conv_bn_layer(route,channel*2, filter_size=3, stride=1,
- padding=1, is_test=is_test,
- name='{}.tip'.format(name))
+ conv = conv_bn_layer(
+ conv,
+ channel,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ is_test=is_test,
+ name='{}.{}.0'.format(name, j))
+ conv = conv_bn_layer(
+ conv,
+ channel * 2,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ is_test=is_test,
+ name='{}.{}.1'.format(name, j))
+ route = conv_bn_layer(
+ conv,
+ channel,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ is_test=is_test,
+ name='{}.2'.format(name))
+ tip = conv_bn_layer(
+ route,
+ channel * 2,
+ filter_size=3,
+ stride=1,
+ padding=1,
+ is_test=is_test,
+ name='{}.tip'.format(name))
return route, tip
-def upsample(input, scale=2,name=None):
+
+def upsample(input, scale=2, name=None):
# get dynamic upsample output shape
shape_nchw = fluid.layers.shape(input)
shape_hw = fluid.layers.slice(shape_nchw, axes=[0], starts=[2], ends=[4])
@@ -56,16 +78,12 @@ def upsample(input, scale=2,name=None):
    # resize by actual_shape
out = fluid.layers.resize_nearest(
- input=input,
- scale=scale,
- actual_shape=out_shape,
- name=name)
+ input=input, scale=scale, actual_shape=out_shape, name=name)
return out
+
class YOLOv3(object):
- def __init__(self,
- is_train=True,
- use_random=True):
+ def __init__(self, is_train=True, use_random=True):
self.is_train = is_train
self.use_random = use_random
self.outputs = []
@@ -77,10 +95,8 @@ class YOLOv3(object):
if self.is_train:
self.py_reader = fluid.layers.py_reader(
capacity=64,
- shapes = [[-1] + self.image_shape,
- [-1, cfg.max_box_num, 4],
- [-1, cfg.max_box_num],
- [-1, cfg.max_box_num]],
+ shapes=[[-1] + self.image_shape, [-1, cfg.max_box_num, 4],
+ [-1, cfg.max_box_num], [-1, cfg.max_box_num]],
lod_levels=[0, 0, 0, 0],
dtypes=['float32'] * 2 + ['int32'] + ['float32'],
use_double_buffer=True)
@@ -88,13 +104,12 @@ class YOLOv3(object):
fluid.layers.read_file(self.py_reader)
else:
self.image = fluid.layers.data(
- name='image', shape=self.image_shape, dtype='float32'
- )
+ name='image', shape=self.image_shape, dtype='float32')
self.im_shape = fluid.layers.data(
- name="im_shape", shape=[2], dtype='int32')
+ name="im_shape", shape=[2], dtype='int32')
self.im_id = fluid.layers.data(
- name="im_id", shape=[1], dtype='int32')
-
+ name="im_id", shape=[1], dtype='int32')
+
def feeds(self):
if not self.is_train:
return [self.image, self.im_id, self.im_shape]
@@ -110,12 +125,12 @@ class YOLOv3(object):
blocks = add_DarkNet53_conv_body(self.image, not self.is_train)
for i, block in enumerate(blocks):
if i > 0:
- block = fluid.layers.concat(
- input=[route, block],
- axis=1)
- route, tip = yolo_detection_block(block, channel=512//(2**i),
- is_test=(not self.is_train),
- name="yolo_block.{}".format(i))
+ block = fluid.layers.concat(input=[route, block], axis=1)
+ route, tip = yolo_detection_block(
+ block,
+ channel=512 // (2**i),
+ is_test=(not self.is_train),
+ name="yolo_block.{}".format(i))
# out channel number = mask_num * (5 + class_num)
num_filters = len(cfg.anchor_masks[i]) * (cfg.class_num + 5)
@@ -126,17 +141,19 @@ class YOLOv3(object):
stride=1,
padding=0,
act=None,
- param_attr=ParamAttr(initializer=fluid.initializer.Normal(0., 0.02),
- name="yolo_output.{}.conv.weights".format(i)),
- bias_attr=ParamAttr(initializer=fluid.initializer.Constant(0.0),
- regularizer=L2Decay(0.),
- name="yolo_output.{}.conv.bias".format(i)))
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.Normal(0., 0.02),
+ name="yolo_output.{}.conv.weights".format(i)),
+ bias_attr=ParamAttr(
+ initializer=fluid.initializer.Constant(0.0),
+ regularizer=L2Decay(0.),
+ name="yolo_output.{}.conv.bias".format(i)))
self.outputs.append(block_out)
if i < len(blocks) - 1:
route = conv_bn_layer(
input=route,
- ch_out=256//(2**i),
+ ch_out=256 // (2**i),
filter_size=1,
stride=1,
padding=0,
@@ -145,42 +162,42 @@ class YOLOv3(object):
# upsample
route = upsample(route)
-
for i, out in enumerate(self.outputs):
anchor_mask = cfg.anchor_masks[i]
if self.is_train:
loss = fluid.layers.yolov3_loss(
- x=out,
- gt_box=self.gtbox,
- gt_label=self.gtlabel,
- gt_score=self.gtscore,
- anchors=cfg.anchors,
- anchor_mask=anchor_mask,
- class_num=cfg.class_num,
- ignore_thresh=cfg.ignore_thresh,
- downsample_ratio=self.downsample,
- use_label_smooth=cfg.label_smooth,
- name="yolo_loss"+str(i))
+ x=out,
+ gt_box=self.gtbox,
+ gt_label=self.gtlabel,
+ gt_score=self.gtscore,
+ anchors=cfg.anchors,
+ anchor_mask=anchor_mask,
+ class_num=cfg.class_num,
+ ignore_thresh=cfg.ignore_thresh,
+ downsample_ratio=self.downsample,
+ use_label_smooth=cfg.label_smooth,
+ name="yolo_loss" + str(i))
self.losses.append(fluid.layers.reduce_mean(loss))
else:
- mask_anchors=[]
+ mask_anchors = []
for m in anchor_mask:
mask_anchors.append(cfg.anchors[2 * m])
mask_anchors.append(cfg.anchors[2 * m + 1])
boxes, scores = fluid.layers.yolo_box(
- x=out,
- img_size=self.im_shape,
- anchors=mask_anchors,
- class_num=cfg.class_num,
- conf_thresh=cfg.valid_thresh,
- downsample_ratio=self.downsample,
- name="yolo_box"+str(i))
+ x=out,
+ img_size=self.im_shape,
+ anchors=mask_anchors,
+ class_num=cfg.class_num,
+ conf_thresh=cfg.valid_thresh,
+ downsample_ratio=self.downsample,
+ name="yolo_box" + str(i))
self.boxes.append(boxes)
- self.scores.append(fluid.layers.transpose(scores, perm=[0, 2, 1]))
-
- self.downsample //= 2
+ self.scores.append(
+ fluid.layers.transpose(
+ scores, perm=[0, 2, 1]))
+ self.downsample //= 2
def loss(self):
return sum(self.losses)
@@ -189,12 +206,11 @@ class YOLOv3(object):
yolo_boxes = fluid.layers.concat(self.boxes, axis=1)
yolo_scores = fluid.layers.concat(self.scores, axis=2)
return fluid.layers.multiclass_nms(
- bboxes=yolo_boxes,
- scores=yolo_scores,
- score_threshold=cfg.valid_thresh,
- nms_top_k=cfg.nms_topk,
- keep_top_k=cfg.nms_posk,
- nms_threshold=cfg.nms_thresh,
- background_label=-1,
- name="multiclass_nms")
-
+ bboxes=yolo_boxes,
+ scores=yolo_scores,
+ score_threshold=cfg.valid_thresh,
+ nms_top_k=cfg.nms_topk,
+ keep_top_k=cfg.nms_posk,
+ nms_threshold=cfg.nms_thresh,
+ background_label=-1,
+ name="multiclass_nms")
diff --git a/PaddleCV/yolov3/reader.py b/PaddleCV/yolov3/reader.py
index 7d1f0de7..d434e2b1 100644
--- a/PaddleCV/yolov3/reader.py
+++ b/PaddleCV/yolov3/reader.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -53,21 +52,17 @@ class DataSetReader(object):
cfg.dataset))
if mode == 'train':
- cfg.train_file_list = os.path.join(cfg.data_dir,
+ cfg.train_file_list = os.path.join(cfg.data_dir,
cfg.train_file_list)
- cfg.train_data_dir = os.path.join(cfg.data_dir,
- cfg.train_data_dir)
+ cfg.train_data_dir = os.path.join(cfg.data_dir, cfg.train_data_dir)
self.COCO = COCO(cfg.train_file_list)
self.img_dir = cfg.train_data_dir
elif mode == 'test' or mode == 'infer':
- cfg.val_file_list = os.path.join(cfg.data_dir,
- cfg.val_file_list)
- cfg.val_data_dir = os.path.join(cfg.data_dir,
- cfg.val_data_dir)
+ cfg.val_file_list = os.path.join(cfg.data_dir, cfg.val_file_list)
+ cfg.val_data_dir = os.path.join(cfg.data_dir, cfg.val_data_dir)
self.COCO = COCO(cfg.val_file_list)
self.img_dir = cfg.val_data_dir
-
def _parse_dataset_catagory(self):
self.categories = self.COCO.loadCats(self.COCO.getCatIds())
self.num_category = len(self.categories)
@@ -76,10 +71,7 @@ class DataSetReader(object):
for category in self.categories:
self.label_names.append(category['name'])
self.label_ids.append(int(category['id']))
- self.category_to_id_map = {
- v: i
- for i, v in enumerate(self.label_ids)
- }
+ self.category_to_id_map = {v: i for i, v in enumerate(self.label_ids)}
print("Load in {} categories.".format(self.num_category))
self.has_parsed_categpry = True
@@ -93,7 +85,8 @@ class DataSetReader(object):
img_height = img['height']
img_width = img['width']
anno = self.COCO.loadAnns(
- self.COCO.getAnnIds(imgIds=img['id'], iscrowd=None))
+ self.COCO.getAnnIds(
+ imgIds=img['id'], iscrowd=None))
gt_index = 0
for target in anno:
if target['area'] < cfg.gt_min_area:
@@ -102,7 +95,7 @@ class DataSetReader(object):
continue
box = box_utils.coco_anno_box_to_center_relative(
- target['bbox'], img_height, img_width)
+ target['bbox'], img_height, img_width)
if box[2] <= 0 and box[3] <= 0:
continue
@@ -141,15 +134,15 @@ class DataSetReader(object):
if mode == 'infer':
return []
else:
- return self._parse_images(is_train=(mode=='train'))
-
- def get_reader(self,
- mode,
- size=416,
- batch_size=None,
- shuffle=False,
- mixup_iter=0,
- random_sizes=[],
+ return self._parse_images(is_train=(mode == 'train'))
+
+ def get_reader(self,
+ mode,
+ size=416,
+ batch_size=None,
+ shuffle=False,
+ mixup_iter=0,
+ random_sizes=[],
image=None):
        assert mode in ['train', 'test', 'infer'], "Unknown mode type!"
if mode != 'infer':
@@ -166,9 +159,13 @@ class DataSetReader(object):
h, w, _ = im.shape
im_scale_x = size / float(w)
im_scale_y = size / float(h)
- out_img = cv2.resize(im, None, None,
- fx=im_scale_x, fy=im_scale_y,
- interpolation=cv2.INTER_CUBIC)
+ out_img = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=cv2.INTER_CUBIC)
mean = np.array(mean).reshape((1, 1, -1))
std = np.array(std).reshape((1, 1, -1))
out_img = (out_img / 255.0 - mean) / std
@@ -191,12 +188,12 @@ class DataSetReader(object):
mixup_gt_labels = np.array(mixup_img['gt_labels']).copy()
mixup_gt_scores = np.ones_like(mixup_gt_labels)
im, gt_boxes, gt_labels, gt_scores = \
- image_utils.image_mixup(im, gt_boxes, gt_labels,
- gt_scores, mixup_im, mixup_gt_boxes,
+ image_utils.image_mixup(im, gt_boxes, gt_labels,
+ gt_scores, mixup_im, mixup_gt_boxes,
mixup_gt_labels, mixup_gt_scores)
im, gt_boxes, gt_labels, gt_scores = \
- image_utils.image_augment(im, gt_boxes, gt_labels,
+ image_utils.image_augment(im, gt_boxes, gt_labels,
gt_scores, size, mean)
mean = np.array(mean).reshape((1, 1, -1))
@@ -230,12 +227,13 @@ class DataSetReader(object):
img_size = get_img_size(size, random_sizes)
while True:
img = imgs[read_cnt % len(imgs)]
- mixup_img = get_mixup_img(imgs, mixup_iter, total_iter, read_cnt)
+ mixup_img = get_mixup_img(imgs, mixup_iter, total_iter,
+ read_cnt)
read_cnt += 1
if read_cnt % len(imgs) == 0 and shuffle:
np.random.shuffle(imgs)
im, gt_boxes, gt_labels, gt_scores = \
- img_reader_with_augment(img, img_size, cfg.pixel_means,
+ img_reader_with_augment(img, img_size, cfg.pixel_means,
cfg.pixel_stds, mixup_img)
batch_out.append([im, gt_boxes, gt_labels, gt_scores])
@@ -249,8 +247,7 @@ class DataSetReader(object):
imgs = self._parse_images_by_mode(mode)
batch_out = []
for img in imgs:
- im, im_id, im_shape = img_reader(img, size,
- cfg.pixel_means,
+ im, im_id, im_shape = img_reader(img, size, cfg.pixel_means,
cfg.pixel_stds)
batch_out.append((im, im_id, im_shape))
if len(batch_out) == batch_size:
@@ -262,8 +259,7 @@ class DataSetReader(object):
img = {}
img['image'] = image
img['id'] = 0
- im, im_id, im_shape = img_reader(img, size,
- cfg.pixel_means,
+ im, im_id, im_shape = img_reader(img, size, cfg.pixel_means,
cfg.pixel_stds)
batch_out = [(im, im_id, im_shape)]
yield batch_out
@@ -273,17 +269,18 @@ class DataSetReader(object):
dsr = DataSetReader()
-def train(size=416,
- batch_size=64,
- shuffle=True,
+
+def train(size=416,
+ batch_size=64,
+ shuffle=True,
total_iter=0,
mixup_iter=0,
random_sizes=[],
num_workers=8,
max_queue=32,
use_multiprocessing=True):
- generator = dsr.get_reader('train', size, batch_size, shuffle,
- int(mixup_iter/num_workers), random_sizes)
+ generator = dsr.get_reader('train', size, batch_size, shuffle,
+ int(mixup_iter / num_workers), random_sizes)
if not use_multiprocessing:
return generator
@@ -316,15 +313,17 @@ def train(size=416,
finally:
if enqueuer is not None:
enqueuer.stop()
-
+
return reader
+
def test(size=416, batch_size=1):
return dsr.get_reader('test', size, batch_size)
+
def infer(size=416, image=None):
return dsr.get_reader('infer', size, image=image)
+
def get_label_infos():
return dsr.get_label_infos()
-
diff --git a/PaddleCV/yolov3/train.py b/PaddleCV/yolov3/train.py
index 97d23d38..5ad9a774 100644
--- a/PaddleCV/yolov3/train.py
+++ b/PaddleCV/yolov3/train.py
@@ -33,12 +33,12 @@ from config import cfg
def train():
- if cfg.debug:
+ if cfg.debug or args.enable_ce:
fluid.default_startup_program().random_seed = 1000
fluid.default_main_program().random_seed = 1000
random.seed(0)
np.random.seed(0)
-
+
if not os.path.exists(cfg.model_save_dir):
os.makedirs(cfg.model_save_dir)
@@ -76,16 +76,18 @@ def train():
if cfg.pretrain:
if not os.path.exists(cfg.pretrain):
print("Pretrain weights not found: {}".format(cfg.pretrain))
+
def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrain, var.name))
+
fluid.io.load_vars(exe, cfg.pretrain, predicate=if_exist)
- build_strategy= fluid.BuildStrategy()
+ build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = True
- build_strategy.sync_batch_norm = cfg.syncbn
- compile_program = fluid.compiler.CompiledProgram(
- fluid.default_main_program()).with_data_parallel(
- loss_name=loss.name, build_strategy=build_strategy)
+ build_strategy.sync_batch_norm = cfg.syncbn
+ compile_program = fluid.compiler.CompiledProgram(fluid.default_main_program(
+ )).with_data_parallel(
+ loss_name=loss.name, build_strategy=build_strategy)
random_sizes = [cfg.input_size]
if cfg.random_shape:
@@ -93,13 +95,17 @@ def train():
total_iter = cfg.max_iter - cfg.start_iter
mixup_iter = total_iter - cfg.no_mixup_iter
- train_reader = reader.train(input_size,
- batch_size=cfg.batch_size,
- shuffle=True,
- total_iter=total_iter*devices_num,
- mixup_iter=mixup_iter*devices_num,
- random_sizes=random_sizes,
- use_multiprocessing=cfg.use_multiprocess)
+ shuffle = True
+ if args.enable_ce:
+ shuffle = False
+ train_reader = reader.train(
+ input_size,
+ batch_size=cfg.batch_size,
+ shuffle=shuffle,
+ total_iter=total_iter * devices_num,
+ mixup_iter=mixup_iter * devices_num,
+ random_sizes=random_sizes,
+ use_multiprocessing=cfg.use_multiprocess)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
@@ -121,7 +127,7 @@ def train():
for iter_id in range(cfg.start_iter, cfg.max_iter):
prev_start_time = start_time
start_time = time.time()
- losses = exe.run(compile_program,
+ losses = exe.run(compile_program,
fetch_list=[v.name for v in fetch_list])
smoothed_loss.add_value(np.mean(np.array(losses[0])))
snapshot_loss += np.mean(np.array(losses[0]))
@@ -129,17 +135,27 @@ def train():
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
- iter_id, lr[0],
- smoothed_loss.get_mean_value(),
- start_time - prev_start_time))
+ iter_id, lr[0],
+ smoothed_loss.get_mean_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
print("Snapshot {} saved, average loss: {}, \
average time: {}".format(
- iter_id + 1,
- snapshot_loss / float(cfg.snapshot_iter),
- snapshot_time / float(cfg.snapshot_iter)))
+ iter_id + 1, snapshot_loss / float(cfg.snapshot_iter),
+ snapshot_time / float(cfg.snapshot_iter)))
+ if args.enable_ce and iter_id == cfg.max_iter - 1:
+ if devices_num == 1:
+ print("kpis\ttrain_cost_1card\t%f" %
+ (snapshot_loss / float(cfg.snapshot_iter)))
+ print("kpis\ttrain_duration_1card\t%f" %
+ (snapshot_time / float(cfg.snapshot_iter)))
+ else:
+ print("kpis\ttrain_cost_8card\t%f" %
+ (snapshot_loss / float(cfg.snapshot_iter)))
+ print("kpis\ttrain_duration_8card\t%f" %
+ (snapshot_time / float(cfg.snapshot_iter)))
+
snapshot_loss = 0
snapshot_time = 0
except fluid.core.EOFException:
diff --git a/PaddleCV/yolov3/utility.py b/PaddleCV/yolov3/utility.py
index 3f5c3c6e..49a7d8fa 100644
--- a/PaddleCV/yolov3/utility.py
+++ b/PaddleCV/yolov3/utility.py
@@ -120,12 +120,13 @@ def parse_args():
add_arg('nms_posk', int, 100, "The number of boxes of NMS output.")
add_arg('debug', bool, False, "Debug mode")
# SINGLE EVAL AND DRAW
- add_arg('image_path', str, 'image',
+ add_arg('image_path', str, 'image',
"The image path used to inference and visualize.")
- add_arg('image_name', str, None,
+ add_arg('image_name', str, None,
"The single image used to inference and visualize. None to inference all images in image_path")
- add_arg('draw_thresh', float, 0.5,
+ add_arg('draw_thresh', float, 0.5,
"Confidence score threshold to draw prediction box in image in debug mode")
+ add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
# yapf: enable
args = parser.parse_args()
file_name = sys.argv[0]
--
GitLab