未验证 提交 6f153ddc 编写于 作者: H haoyuying 提交者: GitHub

adapt rc for colorization and style transfer

上级 332f3a0c
...@@ -2,7 +2,5 @@ import paddle ...@@ -2,7 +2,5 @@ import paddle
import paddlehub as hub import paddlehub as hub
if __name__ == '__main__': if __name__ == '__main__':
model = hub.Module(name='user_guided_colorization') model = hub.Module(name='user_guided_colorization', load_checkpoint='/PATH/TO/CHECKPOINT')
state_dict = paddle.load('img_colorization_ckpt') result = model.predict(images='house.png')
model.set_dict(state_dict)
result = model.predict('house.png')
...@@ -2,7 +2,5 @@ import paddle ...@@ -2,7 +2,5 @@ import paddle
import paddlehub as hub import paddlehub as hub
if __name__ == '__main__': if __name__ == '__main__':
model = hub.Module(name='msgnet') model = hub.Module(name='msgnet', load_checkpoint='/PATH/TO/CHECKPOINT')
state_dict = paddle.load('img_style_transfer_ckpt')
model.set_dict(state_dict)
result = model.predict("venice-boat.jpg", "candy.jpg") result = model.predict("venice-boat.jpg", "candy.jpg")
...@@ -179,11 +179,7 @@ class UserGuidedColorization(nn.Layer): ...@@ -179,11 +179,7 @@ class UserGuidedColorization(nn.Layer):
print("load custom checkpoint success") print("load custom checkpoint success")
else: else:
checkpoint = os.path.join(self.directory, 'user_guided.pdparams') checkpoint = os.path.join(self.directory, 'user_guided.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint) model_dict = paddle.load(checkpoint)
self.set_dict(model_dict) self.set_dict(model_dict)
print("load pretrained checkpoint success") print("load pretrained checkpoint success")
......
...@@ -314,9 +314,6 @@ class MSGNet(nn.Layer): ...@@ -314,9 +314,6 @@ class MSGNet(nn.Layer):
else: else:
checkpoint = os.path.join(self.directory, 'style_paddle.pdparams') checkpoint = os.path.join(self.directory, 'style_paddle.pdparams')
if not os.path.exists(checkpoint):
os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O ' +
checkpoint)
model_dict = paddle.load(checkpoint) model_dict = paddle.load(checkpoint)
model_dict_clone = model_dict.copy() model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items(): for key, value in model_dict_clone.items():
......
...@@ -164,7 +164,7 @@ class ImageColorizeModule(RunModule, ImageServing): ...@@ -164,7 +164,7 @@ class ImageColorizeModule(RunModule, ImageServing):
Returns: Returns:
results(list[dict]) : The prediction result of each input image results(list[dict]) : The prediction result of each input image
''' '''
self.eval()
lab2rgb = T.LAB2RGB() lab2rgb = T.LAB2RGB()
process = T.ColorPostprocess() process = T.ColorPostprocess()
resize = T.Resize((256, 256)) resize = T.Resize((256, 256))
...@@ -239,16 +239,17 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -239,16 +239,17 @@ class Yolov3Module(RunModule, ImageServing):
for i, out in enumerate(outputs): for i, out in enumerate(outputs):
anchor_mask = self.anchor_masks[i] anchor_mask = self.anchor_masks[i]
loss = F.yolov3_loss(x=out, loss = F.yolov3_loss(
gt_box=gtbox, x=out,
gt_label=gtlabel, gt_box=gtbox,
gt_score=gtscore, gt_label=gtlabel,
anchors=self.anchors, gt_score=gtscore,
anchor_mask=anchor_mask, anchors=self.anchors,
class_num=self.class_num, anchor_mask=anchor_mask,
ignore_thresh=self.ignore_thresh, class_num=self.class_num,
downsample_ratio=32, ignore_thresh=self.ignore_thresh,
use_label_smooth=False) downsample_ratio=32,
use_label_smooth=False)
losses.append(paddle.mean(loss)) losses.append(paddle.mean(loss))
self.downsample //= 2 self.downsample //= 2
...@@ -269,6 +270,7 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -269,6 +270,7 @@ class Yolov3Module(RunModule, ImageServing):
scores(np.ndarray): Predict score. scores(np.ndarray): Predict score.
labels(np.ndarray): Predict labels. labels(np.ndarray): Predict labels.
''' '''
self.eval()
boxes = [] boxes = []
scores = [] scores = []
self.downsample = 32 self.downsample = 32
...@@ -287,13 +289,14 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -287,13 +289,14 @@ class Yolov3Module(RunModule, ImageServing):
mask_anchors.append((self.anchors[2 * m])) mask_anchors.append((self.anchors[2 * m]))
mask_anchors.append(self.anchors[2 * m + 1]) mask_anchors.append(self.anchors[2 * m + 1])
box, score = F.yolo_box(x=out, box, score = F.yolo_box(
img_size=im_shape, x=out,
anchors=mask_anchors, img_size=im_shape,
class_num=self.class_num, anchors=mask_anchors,
conf_thresh=self.valid_thresh, class_num=self.class_num,
downsample_ratio=self.downsample, conf_thresh=self.valid_thresh,
name="yolo_box" + str(i)) downsample_ratio=self.downsample,
name="yolo_box" + str(i))
boxes.append(box) boxes.append(box)
scores.append(paddle.transpose(score, perm=[0, 2, 1])) scores.append(paddle.transpose(score, perm=[0, 2, 1]))
...@@ -302,13 +305,14 @@ class Yolov3Module(RunModule, ImageServing): ...@@ -302,13 +305,14 @@ class Yolov3Module(RunModule, ImageServing):
yolo_boxes = paddle.concat(boxes, axis=1) yolo_boxes = paddle.concat(boxes, axis=1)
yolo_scores = paddle.concat(scores, axis=2) yolo_scores = paddle.concat(scores, axis=2)
pred = F.multiclass_nms(bboxes=yolo_boxes, pred = F.multiclass_nms(
scores=yolo_scores, bboxes=yolo_boxes,
score_threshold=self.valid_thresh, scores=yolo_scores,
nms_top_k=self.nms_topk, score_threshold=self.valid_thresh,
keep_top_k=self.nms_posk, nms_top_k=self.nms_topk,
nms_threshold=self.nms_thresh, keep_top_k=self.nms_posk,
background_label=-1) nms_threshold=self.nms_thresh,
background_label=-1)
bboxes = pred.numpy() bboxes = pred.numpy()
labels = bboxes[:, 0].astype('int32') labels = bboxes[:, 0].astype('int32')
...@@ -388,6 +392,7 @@ class StyleTransferModule(RunModule, ImageServing): ...@@ -388,6 +392,7 @@ class StyleTransferModule(RunModule, ImageServing):
Returns: Returns:
output(np.ndarray) : The style transformed images with bgr mode. output(np.ndarray) : The style transformed images with bgr mode.
''' '''
self.eval()
content = paddle.to_tensor(self.transform(origin_path)) content = paddle.to_tensor(self.transform(origin_path))
style = paddle.to_tensor(self.transform(style_path)) style = paddle.to_tensor(self.transform(style_path))
content = content.unsqueeze(0) content = content.unsqueeze(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册