提交 465eaf3e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!986 clear yolov3 and ssd script pylint warning

Merge pull request !986 from chengxb7532/master
......@@ -137,7 +137,7 @@ def ssd_bboxes_encode(boxes):
num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
return bboxes, t_label.astype(np.int32), num_match_num
def ssd_bboxes_decode(boxes, index, image_shape):
def ssd_bboxes_decode(boxes, index):
"""Decode predict boxes to [x, y, w, h]"""
boxes_t = boxes[index]
default_boxes_t = default_boxes[index]
......
......@@ -110,14 +110,12 @@ def metrics(pred_data):
pred_boxes = sample['boxes']
boxes_scores = sample['box_scores']
annotation = sample['annotation']
image_shape = sample['image_shape']
annotation = np.squeeze(annotation, axis=0)
image_shape = np.squeeze(image_shape, axis=0)
pred_labels = np.argmax(boxes_scores, axis=-1)
index = np.nonzero(pred_labels)
pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape)
pred_boxes = ssd_bboxes_decode(pred_boxes, index)
pred_boxes = pred_boxes.clip(0, 1)
boxes_scores = np.max(boxes_scores, axis=-1)
......
......@@ -60,7 +60,7 @@ def init_net_param(net, init='ones'):
p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype()))
if __name__ == '__main__':
def main():
parser = argparse.ArgumentParser(description="YOLOv3 train")
parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create "
"Mindrecord, default is false.")
......@@ -153,3 +153,6 @@ if __name__ == '__main__':
dataset_sink_mode = True
print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.")
model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册