diff --git a/applications/EDVR/configs/edvr_L.yaml b/applications/EDVR/configs/edvr_L.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..91b05f945e5e8fcc751ce878fb67c513c752e5f4
--- /dev/null
+++ b/applications/EDVR/configs/edvr_L.yaml
@@ -0,0 +1,24 @@
+MODEL:
+    name: "EDVR"
+    format: "png"
+    num_frames: 5
+    center: 2
+    num_filters: 128 #64
+    deform_conv_groups: 8
+    front_RBs: 5
+    back_RBs: 40 #10
+    predeblur: False
+    HR_in: False
+    w_TSA: True #False
+
+INFER:
+    scale: 4
+    crop_size: 256
+    interval_list: [1]
+    random_reverse: False
+    number_frames: 5
+    batch_size: 1
+    file_root: "/workspace/color/input_frames"
+    inference_model: "/workspace/PaddleGAN/applications/EDVR/data/inference_model"
+    use_flip: False
+    use_rot: False
diff --git a/applications/EDVR/predict.py b/applications/EDVR/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45904a2698d2e38702df2f91dc96b7ed1cf0ae8
--- /dev/null
+++ b/applications/EDVR/predict.py
@@ -0,0 +1,174 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import sys
+import time
+import logging
+import argparse
+import ast
+import numpy as np
+try:
+    import cPickle as pickle
+except:
+    import pickle
+import paddle.fluid as fluid
+import cv2
+
+from utils.config_utils import *
+#import models
+from reader import get_reader
+#from metrics import get_metrics
+from utils.utility import check_cuda
+from utils.utility import check_version
+
+logging.root.handlers = []
+FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
+logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--model_name',
+        type=str,
+        default='AttentionCluster',
+        help='name of model to train.')
+    parser.add_argument(
+        '--inference_model',
+        type=str,
+        default='./data/inference_model',
+        help='path of inference_model.')
+    parser.add_argument(
+        '--config',
+        type=str,
+        default='configs/attention_cluster.txt',
+        help='path to config file of model')
+    parser.add_argument(
+        '--use_gpu',
+        type=ast.literal_eval,
+        default=True,
+        help='default use gpu.')
+    parser.add_argument(
+        '--batch_size',
+        type=int,
+        default=1,
+        help='sample number in a batch for inference.')
+    parser.add_argument(
+        '--filelist',
+        type=str,
+        default=None,
+        help='path to inference data file lists file.')
+    parser.add_argument(
+        '--log_interval',
+        type=int,
+        default=1,
+        help='mini-batch interval to log.')
+    parser.add_argument(
+        '--infer_topk',
+        type=int,
+        default=20,
+        help='topk predictions to restore.')
+    parser.add_argument(
+        '--save_dir',
+        type=str,
+        default=os.path.join('data', 'predict_results'),
+        help='directory to store results')
+    parser.add_argument(
+        '--video_path',
+        type=str,
+        default=None,
+        help='path of video to infer.')
+    args = parser.parse_args()
+    return args
+
+def get_img(pred):
+    print('pred shape', pred.shape)
+    pred = pred.squeeze()
+    pred = np.clip(pred, a_min=0., a_max=1.0)
+    pred = pred * 255
+    pred = pred.round()
+    pred = pred.astype('uint8')
+    pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
+    pred = pred[:, :, ::-1] # rgb -> bgr
+    return pred
+
+def save_img(img, framename):
+    dirname = './demo/resultpng'
+    filename = os.path.join(dirname, framename+'.png')
+    cv2.imwrite(filename, img)
+
+
+def infer(args):
+    # parse config
+    config = parse_config(args.config)
+    infer_config = merge_configs(config, 'infer', vars(args))
+    print_configs(infer_config, "Infer")
+    inference_model = args.inference_model
+    model_filename = 'EDVR_model.pdmodel'
+    params_filename = 'EDVR_params.pdparams'
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    [inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(dirname=inference_model, model_filename=model_filename, params_filename=params_filename, executor=exe)
+
+    infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)
+    #infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)
+    #infer_metrics.reset()
+
+    periods = []
+    cur_time = time.time()
+    for infer_iter, data in enumerate(infer_reader()):
+        if args.model_name == 'EDVR':
+            data_feed_in = [items[0] for items in data]
+            video_info = [items[1:] for items in data]
+            infer_outs = exe.run(inference_program,
+                                 fetch_list=fetch_list,
+                                 feed={feed_list[0]:np.array(data_feed_in)})
+            infer_result_list = [item for item in infer_outs]
+            videonames = [item[0] for item in video_info]
+            framenames = [item[1] for item in video_info]
+            for i in range(len(infer_result_list)):
+                img_i = get_img(infer_result_list[i])
+                save_img(img_i, 'img' + videonames[i] + framenames[i])
+
+
+        prev_time = cur_time
+        cur_time = time.time()
+        period = cur_time - prev_time
+        periods.append(period)
+
+        #infer_metrics.accumulate(infer_result_list)
+
+        if args.log_interval > 0 and infer_iter % args.log_interval == 0:
+            logger.info('Processed {} samples'.format(infer_iter + 1))
+
+    logger.info('[INFER] infer finished. average time: {}'.format(np.mean(periods)))
+
+    if not os.path.isdir(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    #infer_metrics.finalize_and_log_out(savedir=args.save_dir)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    # check whether the installed paddle is compiled with GPU
+    check_cuda(args.use_gpu)
+    check_version()
+    logger.info(args)
+
+    infer(args)
diff --git a/applications/EDVR/reader/__init__.py b/applications/EDVR/reader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..080f19930bb2f495ea071059d656fdd7ba4558ea
--- /dev/null
+++ b/applications/EDVR/reader/__init__.py
@@ -0,0 +1,4 @@
+from .reader_utils import regist_reader, get_reader
+from .edvr_reader import EDVRReader
+
+regist_reader("EDVR", EDVRReader)
diff --git a/applications/EDVR/reader/edvr_reader.py b/applications/EDVR/reader/edvr_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..75191254ea6394d4047963c0d51981f2f5707adb
--- /dev/null
+++ b/applications/EDVR/reader/edvr_reader.py
@@ -0,0 +1,434 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import sys
+import cv2
+import math
+import random
+import multiprocessing
+import functools
+import numpy as np
+import paddle
+import cv2
+import logging
+from .reader_utils import DataReader
+
+logger = logging.getLogger(__name__)
+python_ver = sys.version_info
+
+random.seed(0)
+np.random.seed(0)
+
+class EDVRReader(DataReader):
+    """
+    Data reader for video super resolution task fit for EDVR model.
+    This is specified for REDS dataset.
+    """
+    def __init__(self, name, mode, cfg):
+        super(EDVRReader, self).__init__(name, mode, cfg)
+        self.format = cfg.MODEL.format
+        self.crop_size = self.get_config_from_sec(mode, 'crop_size')
+        self.interval_list = self.get_config_from_sec(mode, 'interval_list')
+        self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
+        self.number_frames = self.get_config_from_sec(mode, 'number_frames')
+        # set batch size and file list
+        self.batch_size = cfg[mode.upper()]['batch_size']
+        self.fileroot = cfg[mode.upper()]['file_root']
+        self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
+        self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
+
+        self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
+        self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
+        self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
+
+        if self.mode != 'infer':
+            self.gtroot = self.get_config_from_sec(mode, 'gt_root')
+            self.scale = self.get_config_from_sec(mode, 'scale', 1)
+            self.LR_input = (self.scale > 1)
+        if self.fix_random_seed:
+            random.seed(0)
+            np.random.seed(0)
+            self.num_reader_threads = 1
+
+    def create_reader(self):
+        logger.info('initialize reader ... ')
+        self.filelist = []
+        for video_name in os.listdir(self.fileroot):
+            if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
+                continue
+            for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
+                frame_idx = frame_name.split('.')[0]
+                video_frame_idx = video_name + '_' + frame_idx
+                # for each item in self.filelist is like '010_00000015', '260_00000090'
+                self.filelist.append(video_frame_idx)
+        if self.mode == 'test' or self.mode == 'infer':
+            self.filelist.sort()
+
+        if self.num_reader_threads == 1:
+            reader_func = make_reader
+        else:
+            reader_func = make_multi_reader
+
+        if self.mode != 'infer':
+            return reader_func(filelist = self.filelist,
+                        num_threads = self.num_reader_threads,
+                        batch_size = self.batch_size,
+                        is_training = (self.mode == 'train'),
+                        number_frames = self.number_frames,
+                        interval_list = self.interval_list,
+                        random_reverse = self.random_reverse,
+                        fileroot = self.fileroot,
+                        crop_size = self.crop_size,
+                        use_flip = self.use_flip,
+                        use_rot = self.use_rot,
+                        gtroot = self.gtroot,
+                        LR_input = self.LR_input,
+                        scale = self.scale,
+                        mode = self.mode)
+        else:
+            return reader_func(filelist = self.filelist,
+                        num_threads = self.num_reader_threads,
+                        batch_size = self.batch_size,
+                        is_training = (self.mode == 'train'),
+                        number_frames = self.number_frames,
+                        interval_list = self.interval_list,
+                        random_reverse = self.random_reverse,
+                        fileroot = self.fileroot,
+                        crop_size = self.crop_size,
+                        use_flip = self.use_flip,
+                        use_rot = self.use_rot,
+                        gtroot = '',
+                        LR_input = True,
+                        scale = 4,
+                        mode = self.mode)
+
+
+def get_sample_data(item, number_frames, interval_list, random_reverse, fileroot,
+                    crop_size, use_flip, use_rot, gtroot, LR_input, scale, mode='train'):
+    video_name = item.split('_')[0]
+    frame_name = item.split('_')[1]
+    if (mode == 'train') or (mode == 'valid'):
+        ngb_frames, name_b = get_neighbor_frames(frame_name, \
+                              number_frames = number_frames, \
+                              interval_list = interval_list, \
+                              random_reverse = random_reverse)
+    elif (mode == 'test') or (mode == 'infer'):
+        ngb_frames, name_b = get_test_neighbor_frames(int(frame_name), number_frames)
+    else:
+        raise NotImplementedError('mode {} not implemented'.format(mode))
+    frame_name = name_b
+    print('key2', ngb_frames, name_b)
+    if mode != 'infer':
+        img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'), is_gt=True)
+        #print('gt_mean', np.mean(img_GT))
+    frame_list = []
+    for ngb_frm in ngb_frames:
+        ngb_name = "%04d"%ngb_frm
+        #img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
+        img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
+        frame_list.append(img)
+        #print('img_mean', np.mean(img))
+
+    H, W, C = frame_list[0].shape
+    # add random crop
+    if (mode == 'train') or (mode == 'valid'):
+        if LR_input:
+            LQ_size = crop_size // scale
+            rnd_h = random.randint(0, max(0, H - LQ_size))
+            rnd_w = random.randint(0, max(0, W - LQ_size))
+            #print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
+            frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
+            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
+            img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
+        else:
+            rnd_h = random.randint(0, max(0, H - crop_size))
+            rnd_w = random.randint(0, max(0, W - crop_size))
+            frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
+            img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
+
+    # add random flip and rotation
+    if mode != 'infer':
+        frame_list.append(img_GT)
+    if (mode == 'train') or (mode == 'valid'):
+        rlt = img_augment(frame_list, use_flip, use_rot)
+    else:
+        rlt = frame_list
+    if mode != 'infer':
+        frame_list = rlt[0:-1]
+        img_GT = rlt[-1]
+    else:
+        frame_list = rlt
+
+    # stack LQ images to NHWC, N is the frame number
+    img_LQs = np.stack(frame_list, axis=0)
+    # BGR to RGB, HWC to CHW, numpy to tensor
+    img_LQs = img_LQs[:, :, :, [2, 1, 0]]
+    img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
+    if mode != 'infer':
+        img_GT = img_GT[:, :, [2, 1, 0]]
+        img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
+
+        return img_LQs, img_GT
+    else:
+        return img_LQs
+
+def get_test_neighbor_frames(crt_i, N, max_n=100, padding='new_info'):
+    """Generate an index list for reading N frames from a sequence of images
+    Args:
+        crt_i (int): current center index
+        max_n (int): max number of the sequence of images (calculated from 1)
+        N (int): reading N frames
+        padding (str): padding mode, one of replicate | reflection | new_info | circle
+            Example: crt_i = 0, N = 5
+            replicate: [0, 0, 0, 1, 2]
+            reflection: [2, 1, 0, 1, 2]
+            new_info: [4, 3, 0, 1, 2]
+            circle: [3, 4, 0, 1, 2]
+
+    Returns:
+        return_l (list [int]): a list of indexes
+    """
+    max_n = max_n - 1
+    n_pad = N // 2
+    return_l = []
+
+    for i in range(crt_i - n_pad, crt_i + n_pad + 1):
+        if i < 0:
+            if padding == 'replicate':
+                add_idx = 0
+            elif padding == 'reflection':
+                add_idx = -i
+            elif padding == 'new_info':
+                add_idx = (crt_i + n_pad) + (-i)
+            elif padding == 'circle':
+                add_idx = N + i
+            else:
+                raise ValueError('Wrong padding mode')
+        elif i > max_n:
+            if padding == 'replicate':
+                add_idx = max_n
+            elif padding == 'reflection':
+                add_idx = max_n * 2 - i
+            elif padding == 'new_info':
+                add_idx = (crt_i - n_pad) - (i - max_n)
+            elif padding == 'circle':
+                add_idx = i - N
+            else:
+                raise ValueError('Wrong padding mode')
+        else:
+            add_idx = i
+        return_l.append(add_idx)
+    name_b = '{:08d}'.format(crt_i)
+    return return_l, name_b
+
+
+def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, max_frame=99, bordermode=False):
+    center_frame_idx = int(frame_name)
+    half_N_frames = number_frames // 2
+    #### determine the neighbor frames
+    interval = random.choice(interval_list)
+    if bordermode:
+        direction = 1 # 1: forward; 0: backward
+        if random_reverse and random.random() < 0.5:
+            direction = random.choice([0, 1])
+        if center_frame_idx + interval * (number_frames - 1) > max_frame:
+            direction = 0
+        elif center_frame_idx - interval * (number_frames - 1) < 0:
+            direction = 1
+        # get the neighbor list
+        if direction == 1:
+            neighbor_list = list(
+                range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
+        else:
+            neighbor_list = list(
+                range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
+        name_b = '{:08d}'.format(neighbor_list[0])
+    else:
+        # ensure not exceeding the borders
+        while (center_frame_idx + half_N_frames * interval >
+               max_frame) or (center_frame_idx - half_N_frames * interval < 0):
+            center_frame_idx = random.randint(0, max_frame)
+        # get the neighbor list
+        neighbor_list = list(
+            range(center_frame_idx - half_N_frames * interval,
+                  center_frame_idx + half_N_frames * interval + 1, interval))
+        if random_reverse and random.random() < 0.5:
+            neighbor_list.reverse()
+        name_b = '{:08d}'.format(neighbor_list[half_N_frames])
+    assert len(neighbor_list) == number_frames, \
+        "frames selected have length({}), but it should be ({})".format(len(neighbor_list), number_frames)
+
+    return neighbor_list, name_b
+
+
+def read_img(path, size=None, is_gt=False):
+    """read image by cv2
+    return: Numpy float32, HWC, BGR, [0,1]"""
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+    #if not is_gt:
+    #    #print(path)
+    #    img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
+    #print("path: ", path)
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    # some images have 4 channels
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+def img_augment(img_list, hflip=True, rot=True):
+    """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
+    hflip = hflip and random.random() < 0.5
+    vflip = rot and random.random() < 0.5
+    rot90 = rot and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:
+            img = img[:, ::-1, :]
+        if vflip:
+            img = img[::-1, :, :]
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    return [_augment(img) for img in img_list]
+
+
+def make_reader(filelist,
+                num_threads,
+                batch_size,
+                is_training,
+                number_frames,
+                interval_list,
+                random_reverse,
+                fileroot,
+                crop_size,
+                use_flip,
+                use_rot,
+                gtroot,
+                LR_input,
+                scale,
+                mode='train'):
+    fl = filelist
+    def reader_():
+        if is_training:
+            random.shuffle(fl)
+        batch_out = []
+        for item in fl:
+            if mode != 'infer':
+                img_LQs, img_GT = get_sample_data(item,
+                         number_frames, interval_list, random_reverse, fileroot,
+                         crop_size, use_flip, use_rot, gtroot, LR_input, scale, mode)
+            else:
+                img_LQs = get_sample_data(item,
+                         number_frames, interval_list, random_reverse, fileroot,
+                         crop_size, use_flip, use_rot, gtroot, LR_input, scale, mode)
+            videoname = item.split('_')[0]
+            framename = item.split('_')[1]
+            if (mode == 'train') or (mode == 'valid'):
+                batch_out.append((img_LQs, img_GT))
+            elif mode == 'test':
+                batch_out.append((img_LQs, img_GT, videoname, framename))
+            elif mode == 'infer':
+                batch_out.append((img_LQs, videoname, framename))
+            else:
+                raise NotImplementedError("mode {} not implemented".format(mode))
+            if len(batch_out) == batch_size:
+                yield batch_out
+                batch_out = []
+    return reader_
+
+
+def make_multi_reader(filelist,
+                      num_threads,
+                      batch_size,
+                      is_training,
+                      number_frames,
+                      interval_list,
+                      random_reverse,
+                      fileroot,
+                      crop_size,
+                      use_flip,
+                      use_rot,
+                      gtroot,
+                      LR_input,
+                      scale,
+                      mode='train'):
+    def read_into_queue(flq, queue):
+        batch_out = []
+        for item in flq:
+            if mode != 'infer':
+                img_LQs, img_GT = get_sample_data(item,
+                         number_frames, interval_list, random_reverse, fileroot,
+                         crop_size, use_flip, use_rot, gtroot, LR_input, scale, mode)
+            else:
+                img_LQs = get_sample_data(item,
+                         number_frames, interval_list, random_reverse, fileroot,
+                         crop_size, use_flip, use_rot, gtroot, LR_input, scale, mode)
+            videoname = item.split('_')[0]
+            framename = item.split('_')[1]
+            if (mode == 'train') or (mode == 'valid'):
+                batch_out.append((img_LQs, img_GT))
+            elif mode == 'test':
+                batch_out.append((img_LQs, img_GT, videoname, framename))
+            elif mode == 'infer':
+                batch_out.append((img_LQs, videoname, framename))
+            else:
+                raise NotImplementedError("mode {} not implemented".format(mode))
+            if len(batch_out) == batch_size:
+                queue.put(batch_out)
+                batch_out = []
+        queue.put(None)
+
+
+    def queue_reader():
+        fl = filelist
+        if is_training:
+            random.shuffle(fl)
+
+        n = num_threads
+        queue_size = 20
+        reader_lists = [None] * n
+        file_num = int(len(fl) // n)
+        for i in range(n):
+            if i < len(reader_lists) - 1:
+                tmp_list = fl[i * file_num:(i + 1) * file_num]
+            else:
+                tmp_list = fl[i * file_num:]
+            reader_lists[i] = tmp_list
+
+        queue = multiprocessing.Queue(queue_size)
+        p_list = [None] * len(reader_lists)
+        # for reader_list in reader_lists:
+        for i in range(len(reader_lists)):
+            reader_list = reader_lists[i]
+            p_list[i] = multiprocessing.Process(
+                target=read_into_queue, args=(reader_list, queue))
+            p_list[i].start()
+        reader_num = len(reader_lists)
+        finish_num = 0
+        while finish_num < reader_num:
+            sample = queue.get()
+            if sample is None:
+                finish_num += 1
+            else:
+                yield sample
+        for i in range(len(p_list)):
+            if p_list[i].is_alive():
+                p_list[i].join()
+
+    return queue_reader
diff --git a/applications/EDVR/reader/reader_utils.py b/applications/EDVR/reader/reader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..93e1f3a6aebe3f1812bf3b4ecb7745dd56cd8fcd
--- /dev/null
+++ b/applications/EDVR/reader/reader_utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import pickle
+import cv2
+import numpy as np
+import random
+
+
+class ReaderNotFoundError(Exception):
+    "Error: reader not found"
+
+    def __init__(self, reader_name, avail_readers):
+        super(ReaderNotFoundError, self).__init__()
+        self.reader_name = reader_name
+        self.avail_readers = avail_readers
+
+    def __str__(self):
+        msg = "Reader {} Not Found.\nAvailable readers:\n".format(
+            self.reader_name)
+        for reader in self.avail_readers:
+            msg += "  {}\n".format(reader)
+        return msg
+
+
+class DataReader(object):
+    """data reader for video input"""
+
+    def __init__(self, model_name, mode, cfg):
+        self.name = model_name
+        self.mode = mode
+        self.cfg = cfg
+
+    def create_reader(self):
+        """Not implemented"""
+        pass
+
+    def get_config_from_sec(self, sec, item, default=None):
+        if sec.upper() not in self.cfg:
+            return default
+        return self.cfg[sec.upper()].get(item, default)
+
+
+class ReaderZoo(object):
+    def __init__(self):
+        self.reader_zoo = {}
+
+    def regist(self, name, reader):
+        assert reader.__base__ == DataReader, "Unknown model type {}".format(
+            type(reader))
+        self.reader_zoo[name] = reader
+
+    def get(self, name, mode, cfg):
+        for k, v in self.reader_zoo.items():
+            if k == name:
+                return v(name, mode, cfg)
+        raise ReaderNotFoundError(name, self.reader_zoo.keys())
+
+
+# singleton reader_zoo
+reader_zoo = ReaderZoo()
+
+
+def regist_reader(name, reader):
+    reader_zoo.regist(name, reader)
+
+
+def get_reader(name, mode, cfg):
+    reader_model = reader_zoo.get(name, mode, cfg)
+    return reader_model.create_reader()
diff --git a/applications/EDVR/run.sh b/applications/EDVR/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7f05d6b9d31354a8eecfb9c8656c9d9545b0bac7
--- /dev/null
+++ b/applications/EDVR/run.sh
@@ -0,0 +1,41 @@
+# examples of running programs:
+# bash ./run.sh inference EDVR ./configs/edvr_L.yaml
+# bash ./run.sh predict EDVR ./configs/edvr_L.yaml
+
+# configs should be ./configs/xxx.yaml
+
+mode=$1
+name=$2
+configs=$3
+
+save_inference_dir="./data/inference_model"
+use_gpu=True
+fix_random_seed=False
+log_interval=1
+valid_interval=1
+
+weights="./weights/paddle_state_dict_L.npz"
+
+
+export CUDA_VISIBLE_DEVICES=4,5,6,7 #0,1,5,6 fast, 2,3,4,7 slow
+export FLAGS_fast_eager_deletion_mode=1
+export FLAGS_eager_delete_tensor_gb=0.0
+export FLAGS_fraction_of_gpu_memory_to_use=0.98
+
+if [ "$mode"x == "predict"x ]; then
+    echo $mode $name $configs $weights
+    if [ "$weights"x != ""x ]; then
+        python predict.py --model_name=$name \
+                        --config=$configs \
+                        --log_interval=$log_interval \
+                        --video_path='' \
+                        --use_gpu=$use_gpu
+    else
+        python predict.py --model_name=$name \
+                        --config=$configs \
+                        --log_interval=$log_interval \
+                        --use_gpu=$use_gpu \
+                        --video_path=''
+    fi
+fi
+
diff --git a/applications/EDVR/utils/__init__.py b/applications/EDVR/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/applications/EDVR/utils/config_utils.py b/applications/EDVR/utils/config_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1acb9d28fc0bf15c54b210f6ed11c13d803c1043
--- /dev/null
+++ b/applications/EDVR/utils/config_utils.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import yaml
+from .utility import AttrDict
+import logging
+logger = logging.getLogger(__name__)
+
+CONFIG_SECS = [
+    'train',
+    'valid',
+    'test',
+    'infer',
+]
+
+
+def parse_config(cfg_file):
+    """Load a config file into AttrDict"""
+    import yaml
+    with open(cfg_file, 'r') as fopen:
+        yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
+    create_attr_dict(yaml_config)
+    return yaml_config
+
+
+def create_attr_dict(yaml_config):
+    from ast import literal_eval
+    for key, value in yaml_config.items():
+        if type(value) is dict:
+            yaml_config[key] = value = AttrDict(value)
+        if isinstance(value, str):
+            try:
+                value = literal_eval(value)
+            except BaseException:
+                pass
+        if isinstance(value, AttrDict):
+            create_attr_dict(yaml_config[key])
+        else:
+            yaml_config[key] = value
+    return
+
+
+def merge_configs(cfg, sec, args_dict):
+    assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
+    sec_dict = getattr(cfg, sec.upper())
+    for k, v in args_dict.items():
+        if v is None:
+            continue
+        try:
+            if hasattr(sec_dict, k):
+                setattr(sec_dict, k, v)
+        except Exception:
+            pass
+    return cfg
+
+
+def print_configs(cfg, mode):
+    logger.info("---------------- {:>5} Arguments ----------------".format(
+        mode))
+    for sec, sec_items in cfg.items():
+        logger.info("{}:".format(sec))
+        for k, v in sec_items.items():
+            logger.info("    {}:{}".format(k, v))
+    logger.info("-------------------------------------------------")
diff --git a/applications/EDVR/utils/utility.py b/applications/EDVR/utils/utility.py
new file mode 100644
index 0000000000000000000000000000000000000000..ced1e7d757ff5697c0fe61f130849524491da3b0
--- /dev/null
+++ b/applications/EDVR/utils/utility.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import os
+import sys
+import signal
+import logging
+import paddle
+import paddle.fluid as fluid
+
+__all__ = ['AttrDict']
+
+logger = logging.getLogger(__name__)
+
+
+def _term(sig_num, addition):
+    print('current pid is %s, group id is %s' % (os.getpid(), os.getpgrp()))
+    os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
+
+
+signal.signal(signal.SIGTERM, _term)
+signal.signal(signal.SIGINT, _term)
+
+
+class AttrDict(dict):
+    def __getattr__(self, key):
+        return self[key]
+
+    def __setattr__(self, key, value):
+        if key in self.__dict__:
+            self.__dict__[key] = value
+        else:
+            self[key] = value
+
+def check_cuda(use_cuda, err = \
+    "\nYou can not set use_gpu = True in the model because you are using paddlepaddle-cpu.\n \
+     Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_gpu = False to run models on CPU.\n"
+    ):
+    try:
+        if use_cuda == True and fluid.is_compiled_with_cuda() == False:
+            print(err)
+            sys.exit(1)
+    except Exception as e:
+        pass
+
+
+def check_version():
+    """
+    Log error and exit when the installed version of paddlepaddle is
+    not satisfied.
+    """
+    err = "PaddlePaddle version 1.6 or higher is required, " \
+          "or a suitable develop version is satisfied as well. \n" \
+          "Please make sure the version is good with your code." \
+
+    try:
+        fluid.require_version('1.6.0')
+    except Exception as e:
+        logger.error(err)
+        sys.exit(1)