提交 2bf8d446 编写于 作者: H haoyuying

revise get_label_infos from yolov3

上级 2f8ff8e6
...@@ -5,7 +5,6 @@ from paddlehub.finetune.trainer import Trainer ...@@ -5,7 +5,6 @@ from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.pascalvoc import DetectionData from paddlehub.datasets.pascalvoc import DetectionData
import paddlehub.process.detect_transforms as T import paddlehub.process.detect_transforms as T
if __name__ == "__main__": if __name__ == "__main__":
place = paddle.CUDAPlace(0)
paddle.disable_static() paddle.disable_static()
transform = T.Compose([ transform = T.Compose([
...@@ -20,7 +19,6 @@ if __name__ == "__main__": ...@@ -20,7 +19,6 @@ if __name__ == "__main__":
train_reader = DetectionData(transform) train_reader = DetectionData(transform)
model = hub.Module(name='yolov3_darknet53_pascalvoc') model = hub.Module(name='yolov3_darknet53_pascalvoc')
model.train()
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters()) optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_det')
trainer.train(train_reader, epochs=5, batch_size=4, eval_dataset=train_reader, log_interval=1, save_interval=1) trainer.train(train_reader, epochs=5, batch_size=4, eval_dataset=train_reader, log_interval=1, save_interval=1)
...@@ -5,7 +5,6 @@ import paddle.nn as nn ...@@ -5,7 +5,6 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant from paddle.nn.initializer import Normal, Constant
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from pycocotools.coco import COCO
from paddlehub.module.cv_module import Yolov3Module from paddlehub.module.cv_module import Yolov3Module
import paddlehub.process.detect_transforms as T import paddlehub.process.detect_transforms as T
from paddlehub.module.module import moduleinfo from paddlehub.module.module import moduleinfo
...@@ -274,10 +273,10 @@ class YOLOv3(nn.Layer): ...@@ -274,10 +273,10 @@ class YOLOv3(nn.Layer):
print("load custom checkpoint success") print("load custom checkpoint success")
else: else:
checkpoint = os.path.join(self.directory, 'yolov3_70000.pdparams') checkpoint = os.path.join(self.directory, 'yolov3_darknet53_voc.pdparams')
if not os.path.exists(checkpoint): if not os.path.exists(checkpoint):
os.system( os.system(
'wget https://bj.bcebos.com/paddlehub/model/image/object_detection/yolov3_70000.pdparams -O ' \ 'wget https://paddlehub.bj.bcebos.com/dygraph/detection/yolov3_darknet53_voc.pdparams -O ' \
+ checkpoint) + checkpoint)
model_dict = paddle.load(checkpoint)[0] model_dict = paddle.load(checkpoint)[0]
self.set_dict(model_dict) self.set_dict(model_dict)
...@@ -302,14 +301,6 @@ class YOLOv3(nn.Layer): ...@@ -302,14 +301,6 @@ class YOLOv3(nn.Layer):
return transform(img) return transform(img)
def get_label_infos(self, file_list: str):
self.COCO = COCO(file_list)
label_names = []
categories = self.COCO.loadCats(self.COCO.getCatIds())
for category in categories:
label_names.append(category['name'])
return label_names
def forward(self, inputs: paddle.Tensor): def forward(self, inputs: paddle.Tensor):
outputs = [] outputs = []
blocks = self.block(inputs) blocks = self.block(inputs)
......
...@@ -265,7 +265,7 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -265,7 +265,7 @@ class Yolov3Module(RunModule, ImageServing):
im = self.transform(imgpath) im = self.transform(imgpath)
h, w, c = Func.img_shape(imgpath) h, w, c = Func.img_shape(imgpath)
im_shape = paddle.to_tensor(np.array([[h, w]]).astype('int32')) im_shape = paddle.to_tensor(np.array([[h, w]]).astype('int32'))
label_names = self.get_label_infos(filelist) label_names = Func.get_label_infos(filelist)
img_data = paddle.to_tensor(np.array([im]).astype('float32')) img_data = paddle.to_tensor(np.array([im]).astype('float32'))
outputs = self(img_data) outputs = self(img_data)
......
...@@ -17,6 +17,7 @@ import cv2 ...@@ -17,6 +17,7 @@ import cv2
import paddle import paddle
import matplotlib import matplotlib
import numpy as np import numpy as np
from pycocotools.coco import COCO
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
...@@ -235,3 +236,13 @@ def img_shape(img_path: str): ...@@ -235,3 +236,13 @@ def img_shape(img_path: str):
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
h, w, c = im.shape h, w, c = im.shape
return h, w, c return h, w, c
def get_label_infos(file_list: str):
"""Get label names by corresponding category ids."""
map_label = COCO(file_list)
label_names = []
categories = map_label.loadCats(map_label.getCatIds())
for category in categories:
label_names.append(category['name'])
return label_names
...@@ -20,7 +20,6 @@ from collections import OrderedDict ...@@ -20,7 +20,6 @@ from collections import OrderedDict
import cv2 import cv2
import numpy as np import numpy as np
import matplotlib
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
from paddlehub.process.functional import * from paddlehub.process.functional import *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册