提交 d5b9c2ea 编写于 作者: L LielinJiang

fix url name

上级 93d8fca1
...@@ -18,7 +18,6 @@ from paddle.utils.download import get_path_from_url ...@@ -18,7 +18,6 @@ from paddle.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames from ppgan.utils.video import frames2video, video2frames
from ppgan.models.generators.deoldify import build_model from ppgan.models.generators.deoldify import build_model
parser = argparse.ArgumentParser(description='DeOldify') parser = argparse.ArgumentParser(description='DeOldify')
parser.add_argument('--input', type=str, default='none', help='Input video') parser.add_argument('--input', type=str, default='none', help='Input video')
parser.add_argument('--output', type=str, default='output', help='output dir') parser.add_argument('--output', type=str, default='output', help='output dir')
...@@ -31,7 +30,7 @@ parser.add_argument('--weight_path', ...@@ -31,7 +30,7 @@ parser.add_argument('--weight_path',
default=None, default=None,
help='Path to the reference image directory') help='Path to the reference image directory')
DeOldify_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams' DEOLDIFY_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DeOldify_stable.pdparams'
class DeOldifyPredictor(): class DeOldifyPredictor():
...@@ -46,7 +45,7 @@ class DeOldifyPredictor(): ...@@ -46,7 +45,7 @@ class DeOldifyPredictor():
self.render_factor = render_factor self.render_factor = render_factor
self.model = build_model() self.model = build_model()
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(DeOldify_weight_url, cur_path) weight_path = get_path_from_url(DEOLDIFY_WEIGHT_URL, cur_path)
state_dict, _ = paddle.load(weight_path) state_dict, _ = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
...@@ -127,8 +126,7 @@ class DeOldifyPredictor(): ...@@ -127,8 +126,7 @@ class DeOldifyPredictor():
vid_out_path = os.path.join(output_path, vid_out_path = os.path.join(output_path,
'{}_deoldify_out.mp4'.format(base_name)) '{}_deoldify_out.mp4'.format(base_name))
frames2video(frame_pattern_combined, vid_out_path, frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
......
...@@ -17,7 +17,7 @@ import utils ...@@ -17,7 +17,7 @@ import utils
from ppgan.models.generators.remaster import NetworkR, NetworkC from ppgan.models.generators.remaster import NetworkR, NetworkC
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
DeepRemaster_weight_url = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams' DEEPREMASTER_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/deep_remaster.pdparams'
parser = argparse.ArgumentParser(description='Remastering') parser = argparse.ArgumentParser(description='Remastering')
parser.add_argument('--input', type=str, default=None, help='Input video') parser.add_argument('--input', type=str, default=None, help='Input video')
...@@ -51,7 +51,7 @@ class DeepReasterPredictor: ...@@ -51,7 +51,7 @@ class DeepReasterPredictor:
self.mindim = mindim self.mindim = mindim
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(DeepRemaster_weight_url, cur_path) weight_path = get_path_from_url(DEEPREMASTER_WEIGHT_URL, cur_path)
state_dict, _ = paddle.load(weight_path) state_dict, _ = paddle.load(weight_path)
......
...@@ -32,7 +32,7 @@ from data import EDVRDataset ...@@ -32,7 +32,7 @@ from data import EDVRDataset
from paddle.utils.download import get_path_from_url from paddle.utils.download import get_path_from_url
from ppgan.utils.video import frames2video, video2frames from ppgan.utils.video import frames2video, video2frames
EDVR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar' EDVR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/edvr_infer_model.tar'
def parse_args(): def parse_args():
...@@ -82,7 +82,7 @@ class EDVRPredictor: ...@@ -82,7 +82,7 @@ class EDVRPredictor:
self.exe = fluid.Executor(place) self.exe = fluid.Executor(place)
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(EDVR_weight_url, cur_path) weight_path = get_path_from_url(EDVR_WEIGHT_URL, cur_path)
model_filename = 'EDVR_model.pdmodel' model_filename = 'EDVR_model.pdmodel'
params_filename = 'EDVR_params.pdparams' params_filename = 'EDVR_params.pdparams'
...@@ -141,8 +141,7 @@ class EDVRPredictor: ...@@ -141,8 +141,7 @@ class EDVRPredictor:
frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png') frame_pattern_combined = os.path.join(pred_frame_path, '%08d.png')
vid_out_path = os.path.join(self.output, vid_out_path = os.path.join(self.output,
'{}_edvr_out.mp4'.format(base_name)) '{}_edvr_out.mp4'.format(base_name))
frames2video(frame_pattern_combined, vid_out_path, frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
......
...@@ -26,7 +26,7 @@ parser.add_argument('--weight_path', ...@@ -26,7 +26,7 @@ parser.add_argument('--weight_path',
default=None, default=None,
help='Path to the reference image directory') help='Path to the reference image directory')
RealSR_weight_url = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams' REALSR_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/applications/DF2K_JPEG.pdparams'
class RealSRPredictor(): class RealSRPredictor():
...@@ -35,7 +35,7 @@ class RealSRPredictor(): ...@@ -35,7 +35,7 @@ class RealSRPredictor():
self.output = os.path.join(output, 'RealSR') self.output = os.path.join(output, 'RealSR')
self.model = RRDBNet(3, 3, 64, 23) self.model = RRDBNet(3, 3, 64, 23)
if weight_path is None: if weight_path is None:
weight_path = get_path_from_url(RealSR_weight_url, cur_path) weight_path = get_path_from_url(REALSR_WEIGHT_URL, cur_path)
state_dict, _ = paddle.load(weight_path) state_dict, _ = paddle.load(weight_path)
self.model.load_dict(state_dict) self.model.load_dict(state_dict)
...@@ -88,8 +88,7 @@ class RealSRPredictor(): ...@@ -88,8 +88,7 @@ class RealSRPredictor():
vid_out_path = os.path.join(output_path, vid_out_path = os.path.join(output_path,
'{}_realsr_out.mp4'.format(base_name)) '{}_realsr_out.mp4'.format(base_name))
frames2video(frame_pattern_combined, vid_out_path, frames2video(frame_pattern_combined, vid_out_path, str(int(fps)))
str(int(fps)))
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册