提交 2bf8d446 编写于 作者: H haoyuying

revise get_label_infos from yolov3

上级 2f8ff8e6
......@@ -5,7 +5,6 @@ from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.pascalvoc import DetectionData
import paddlehub.process.detect_transforms as T
if __name__ == "__main__":
place = paddle.CUDAPlace(0)
paddle.disable_static()
transform = T.Compose([
......@@ -20,7 +19,6 @@ if __name__ == "__main__":
train_reader = DetectionData(transform)
model = hub.Module(name='yolov3_darknet53_pascalvoc')
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_det')
trainer.train(train_reader, epochs=5, batch_size=4, eval_dataset=train_reader, log_interval=1, save_interval=1)
......@@ -5,7 +5,6 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant
from paddle.regularizer import L2Decay
from pycocotools.coco import COCO
from paddlehub.module.cv_module import Yolov3Module
import paddlehub.process.detect_transforms as T
from paddlehub.module.module import moduleinfo
......@@ -274,10 +273,10 @@ class YOLOv3(nn.Layer):
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'yolov3_70000.pdparams')
checkpoint = os.path.join(self.directory, 'yolov3_darknet53_voc.pdparams')
if not os.path.exists(checkpoint):
os.system(
'wget https://bj.bcebos.com/paddlehub/model/image/object_detection/yolov3_70000.pdparams -O ' \
'wget https://paddlehub.bj.bcebos.com/dygraph/detection/yolov3_darknet53_voc.pdparams -O ' \
+ checkpoint)
model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict)
......@@ -302,14 +301,6 @@ class YOLOv3(nn.Layer):
return transform(img)
def get_label_infos(self, file_list: str):
self.COCO = COCO(file_list)
label_names = []
categories = self.COCO.loadCats(self.COCO.getCatIds())
for category in categories:
label_names.append(category['name'])
return label_names
def forward(self, inputs: paddle.Tensor):
outputs = []
blocks = self.block(inputs)
......
......@@ -265,7 +265,7 @@ class Yolov3Module(RunModule, ImageServing):
im = self.transform(imgpath)
h, w, c = Func.img_shape(imgpath)
im_shape = paddle.to_tensor(np.array([[h, w]]).astype('int32'))
label_names = self.get_label_infos(filelist)
label_names = Func.get_label_infos(filelist)
img_data = paddle.to_tensor(np.array([im]).astype('float32'))
outputs = self(img_data)
......
......@@ -17,6 +17,7 @@ import cv2
import paddle
import matplotlib
import numpy as np
from pycocotools.coco import COCO
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
......@@ -235,3 +236,13 @@ def img_shape(img_path: str):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
h, w, c = im.shape
return h, w, c
def get_label_infos(file_list: str):
"""Get label names by corresponding category ids."""
map_label = COCO(file_list)
label_names = []
categories = map_label.loadCats(map_label.getCatIds())
for category in categories:
label_names.append(category['name'])
return label_names
......@@ -20,7 +20,6 @@ from collections import OrderedDict
import cv2
import numpy as np
import matplotlib
from PIL import Image, ImageEnhance
from paddlehub.process.functional import *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册