From 6f153ddce7d70ea8694ba89957765f20a610f00c Mon Sep 17 00:00:00 2001 From: haoyuying <35907364+haoyuying@users.noreply.github.com> Date: Fri, 30 Oct 2020 15:25:58 +0800 Subject: [PATCH] adapt rc for colorization and style transfer --- demo/colorization/predict.py | 6 +- demo/style_transfer/predict.py | 4 +- .../user_guided_colorization/module.py | 4 -- modules/image/style_transfer/msgnet/module.py | 3 - paddlehub/module/cv_module.py | 55 ++++++++++--------- 5 files changed, 33 insertions(+), 39 deletions(-) diff --git a/demo/colorization/predict.py b/demo/colorization/predict.py index e72bafcc..f24dab79 100644 --- a/demo/colorization/predict.py +++ b/demo/colorization/predict.py @@ -2,7 +2,5 @@ import paddle import paddlehub as hub if __name__ == '__main__': - model = hub.Module(name='user_guided_colorization') - state_dict = paddle.load('img_colorization_ckpt') - model.set_dict(state_dict) - result = model.predict('house.png') + model = hub.Module(name='user_guided_colorization', load_checkpoint='/PATH/TO/CHECKPOINT') + result = model.predict(images='house.png') diff --git a/demo/style_transfer/predict.py b/demo/style_transfer/predict.py index 9f44868f..61918dc8 100644 --- a/demo/style_transfer/predict.py +++ b/demo/style_transfer/predict.py @@ -2,7 +2,5 @@ import paddle import paddlehub as hub if __name__ == '__main__': - model = hub.Module(name='msgnet') - state_dict = paddle.load('img_style_transfer_ckpt') - model.set_dict(state_dict) + model = hub.Module(name='msgnet', load_checkpoint='/PATH/TO/CHECKPOINT') result = model.predict("venice-boat.jpg", "candy.jpg") diff --git a/modules/image/colorization/user_guided_colorization/module.py b/modules/image/colorization/user_guided_colorization/module.py index e31017aa..8983d28b 100644 --- a/modules/image/colorization/user_guided_colorization/module.py +++ b/modules/image/colorization/user_guided_colorization/module.py @@ -179,11 +179,7 @@ class UserGuidedColorization(nn.Layer): print("load custom checkpoint success") else: checkpoint = os.path.join(self.directory, 'user_guided.pdparams') - if not os.path.exists(checkpoint): - os.system('wget https://paddlehub.bj.bcebos.com/dygraph/image_colorization/user_guided.pdparams -O ' + - checkpoint) model_dict = paddle.load(checkpoint) - self.set_dict(model_dict) print("load pretrained checkpoint success") diff --git a/modules/image/style_transfer/msgnet/module.py b/modules/image/style_transfer/msgnet/module.py index 5c126124..ef1bc202 100644 --- a/modules/image/style_transfer/msgnet/module.py +++ b/modules/image/style_transfer/msgnet/module.py @@ -314,9 +314,6 @@ class MSGNet(nn.Layer): else: checkpoint = os.path.join(self.directory, 'style_paddle.pdparams') - if not os.path.exists(checkpoint): - os.system('wget https://bj.bcebos.com/paddlehub/model/image/image_editing/style_paddle.pdparams -O ' + - checkpoint) model_dict = paddle.load(checkpoint) model_dict_clone = model_dict.copy() for key, value in model_dict_clone.items(): diff --git a/paddlehub/module/cv_module.py b/paddlehub/module/cv_module.py index 431302a6..ad1d625c 100644 --- a/paddlehub/module/cv_module.py +++ b/paddlehub/module/cv_module.py @@ -164,7 +164,7 @@ class ImageColorizeModule(RunModule, ImageServing): Returns: results(list[dict]) : The prediction result of each input image ''' - + self.eval() lab2rgb = T.LAB2RGB() process = T.ColorPostprocess() resize = T.Resize((256, 256)) @@ -239,16 +239,17 @@ class Yolov3Module(RunModule, ImageServing): for i, out in enumerate(outputs): anchor_mask = self.anchor_masks[i] - loss = F.yolov3_loss(x=out, - gt_box=gtbox, - gt_label=gtlabel, - gt_score=gtscore, - anchors=self.anchors, - anchor_mask=anchor_mask, - class_num=self.class_num, - ignore_thresh=self.ignore_thresh, - downsample_ratio=32, - use_label_smooth=False) + loss = F.yolov3_loss( + x=out, + gt_box=gtbox, + gt_label=gtlabel, + gt_score=gtscore, + anchors=self.anchors, + anchor_mask=anchor_mask, + class_num=self.class_num, + ignore_thresh=self.ignore_thresh, + downsample_ratio=32, + use_label_smooth=False) losses.append(paddle.mean(loss)) self.downsample //= 2 @@ -269,6 +270,7 @@ class Yolov3Module(RunModule, ImageServing): scores(np.ndarray): Predict score. labels(np.ndarray): Predict labels. ''' + self.eval() boxes = [] scores = [] self.downsample = 32 @@ -287,13 +289,14 @@ class Yolov3Module(RunModule, ImageServing): mask_anchors.append((self.anchors[2 * m])) mask_anchors.append(self.anchors[2 * m + 1]) - box, score = F.yolo_box(x=out, - img_size=im_shape, - anchors=mask_anchors, - class_num=self.class_num, - conf_thresh=self.valid_thresh, - downsample_ratio=self.downsample, - name="yolo_box" + str(i)) + box, score = F.yolo_box( + x=out, + img_size=im_shape, + anchors=mask_anchors, + class_num=self.class_num, + conf_thresh=self.valid_thresh, + downsample_ratio=self.downsample, + name="yolo_box" + str(i)) boxes.append(box) scores.append(paddle.transpose(score, perm=[0, 2, 1])) @@ -302,13 +305,14 @@ class Yolov3Module(RunModule, ImageServing): yolo_boxes = paddle.concat(boxes, axis=1) yolo_scores = paddle.concat(scores, axis=2) - pred = F.multiclass_nms(bboxes=yolo_boxes, - scores=yolo_scores, - score_threshold=self.valid_thresh, - nms_top_k=self.nms_topk, - keep_top_k=self.nms_posk, - nms_threshold=self.nms_thresh, - background_label=-1) + pred = F.multiclass_nms( + bboxes=yolo_boxes, + scores=yolo_scores, + score_threshold=self.valid_thresh, + nms_top_k=self.nms_topk, + keep_top_k=self.nms_posk, + nms_threshold=self.nms_thresh, + background_label=-1) bboxes = pred.numpy() labels = bboxes[:, 0].astype('int32') @@ -388,6 +392,7 @@ class StyleTransferModule(RunModule, ImageServing): Returns: output(np.ndarray) : The style transformed images with bgr mode. ''' + self.eval() content = paddle.to_tensor(self.transform(origin_path)) style = paddle.to_tensor(self.transform(style_path)) content = content.unsqueeze(0) -- GitLab