Unverified commit 6f153ddc authored by haoyuying, committed by GitHub

adapt rc for colorization and style transfer

Parent 332f3a0c
@@ -2,7 +2,5 @@ import paddle
import paddlehub as hub
if __name__ == '__main__':
model = hub.Module(name='user_guided_colorization')
state_dict = paddle.load('img_colorization_ckpt')
model.set_dict(state_dict)
result = model.predict('house.png')
model = hub.Module(name='user_guided_colorization', load_checkpoint='/PATH/TO/CHECKPOINT')
result = model.predict(images='house.png')
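Read end to end, the updated colorization demo amounts to the following sketch (the checkpoint path is the placeholder from the diff, and 'house.png' is assumed to be a local input image):

import paddlehub as hub

if __name__ == '__main__':
    # Fine-tuned weights are passed via load_checkpoint instead of calling
    # paddle.load / set_dict manually ('/PATH/TO/CHECKPOINT' is a placeholder).
    model = hub.Module(name='user_guided_colorization', load_checkpoint='/PATH/TO/CHECKPOINT')
    # predict now receives the input through the images keyword.
    result = model.predict(images='house.png')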
@@ -2,7 +2,5 @@ import paddle
import paddlehub as hub
if __name__ == '__main__':
model = hub.Module(name='msgnet')
state_dict = paddle.load('img_style_transfer_ckpt')
model.set_dict(state_dict)
model = hub.Module(name='msgnet', load_checkpoint='/PATH/TO/CHECKPOINT')
result = model.predict("venice-boat.jpg", "candy.jpg")
@@ -179,11 +179,7 @@ class UserGuidedColorization(nn.Layer):
print("load custom checkpoint success")
else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)
self.set_dict(model_dict)
print("load pretrained checkpoint success")
@@ -314,9 +314,6 @@ class MSGNet(nn.Layer):
else:
checkpoint = os.path.join(self.directory, 'style_paddle.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint)
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
@@ -164,7 +164,7 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns:
results(list[dict]) : The prediction result of each input image
'''
self.eval()
lab2rgb = T.LAB2RGB()
process = T.ColorPostprocess()
resize = T.Resize((256, 256))
@@ -239,16 +239,17 @@ class Yolov3Module(RunModule, ImageServing):
for i, out in enumerate(outputs):
anchor_mask = self.anchor_masks[i]
loss = F.yolov3_loss(x=out,
gt_box=gtbox,
gt_label=gtlabel,
gt_score=gtscore,
anchors=self.anchors,
anchor_mask=anchor_mask,
class_num=self.class_num,
ignore_thresh=self.ignore_thresh,
downsample_ratio=32,
use_label_smooth=False)
loss = F.yolov3_loss(
x=out,
gt_box=gtbox,
gt_label=gtlabel,
gt_score=gtscore,
anchors=self.anchors,
anchor_mask=anchor_mask,
class_num=self.class_num,
ignore_thresh=self.ignore_thresh,
downsample_ratio=32,
use_label_smooth=False)
losses.append(paddle.mean(loss))
self.downsample //= 2
@@ -269,6 +270,7 @@ class Yolov3Module(RunModule, ImageServing):
scores(np.ndarray): Predict score.
labels(np.ndarray): Predict labels.
'''
self.eval()
boxes = []
scores = []
self.downsample = 32
@@ -287,13 +289,14 @@ class Yolov3Module(RunModule, ImageServing):
mask_anchors.append((self.anchors[2 * m]))
mask_anchors.append(self.anchors[2 * m + 1])
box, score = F.yolo_box(x=out,
img_size=im_shape,
anchors=mask_anchors,
class_num=self.class_num,
conf_thresh=self.valid_thresh,
downsample_ratio=self.downsample,
name="yolo_box" + str(i))
box, score = F.yolo_box(
x=out,
img_size=im_shape,
anchors=mask_anchors,
class_num=self.class_num,
conf_thresh=self.valid_thresh,
downsample_ratio=self.downsample,
name="yolo_box" + str(i))
boxes.append(box)
scores.append(paddle.transpose(score, perm=[0, 2, 1]))
@@ -302,13 +305,14 @@ class Yolov3Module(RunModule, ImageServing):
yolo_boxes = paddle.concat(boxes, axis=1)
yolo_scores = paddle.concat(scores, axis=2)
pred = F.multiclass_nms(bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
pred = F.multiclass_nms(
bboxes=yolo_boxes,
scores=yolo_scores,
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
bboxes = pred.numpy()
labels = bboxes[:, 0].astype('int32')
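For reference, a hedged sketch of how the decoded NMS output is typically unpacked, assuming Paddle's usual per-row layout of [label, score, x1, y1, x2, y2] (only the label column in position 0 is confirmed by the extraction shown above):

# pred is the tensor returned by multiclass_nms; after .numpy() each row is
# assumed to hold [label, score, x1, y1, x2, y2].
bboxes = pred.numpy()
labels = bboxes[:, 0].astype('int32')  # class index of each detection
scores = bboxes[:, 1]                  # confidence of each detection
boxes = bboxes[:, 2:]                  # corner coordinates (x1, y1, x2, y2)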
@@ -388,6 +392,7 @@ class StyleTransferModule(RunModule, ImageServing):
Returns:
output(np.ndarray) : The style transformed images with bgr mode.
'''
self.eval()
content = paddle.to_tensor(self.transform(origin_path))
style = paddle.to_tensor(self.transform(style_path))
content = content.unsqueeze(0)