Commit ed2d8b38 authored by lijianshe02

Add EDVR inference code

Parent 38701e31
MODEL:
name: "EDVR"
format: "png"
num_frames: 5
center: 2
num_filters: 64
deform_conv_groups: 8
front_RBs: 5
back_RBs: 10
predeblur: False
HR_in: False
w_TSA: False
TRAIN:
epoch: 45
use_gpu: True
num_gpus: 8
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 32
#file_root: "/mnt/sungaofeng/edvr/paddle/video/data/dataset/edvr/REDS/train_sharp_bicubic/X4"
file_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS/train_sharp_bicubic/X4"
#gt_root: "/mnt/sungaofeng/edvr/paddle/video/data/dataset/edvr/REDS/train_sharp"
gt_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS/train_sharp"
use_flip: True
use_rot: True
base_learning_rate: 0.0004
l2_weight_decay: 0.0
T_periods: [150000, 150000, 150000, 150000] # for cosine annealing restart
restarts: [150000, 300000, 450000] # for cosine annealing restart
weights: [1, 1, 1] # for cosine annealing restart
eta_min: 1e-7 # for cosine annealing restart
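# Note (based on the get_lr schedule in the EDVR model code below): restarts
# are cumulative step boundaries at which the learning rate is reset to
# base_learning_rate * weight; T_periods[i] is the cosine annealing period
# (in steps) of segment i; eta_min is the value the rate decays towards.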
num_reader_threads: 8
buf_size: 1024
fix_random_seed: False
VALID:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 32 #256
#file_root: "/mnt/sungaofeng/edvr/paddle/video/data/dataset/edvr/REDS4/GT"
file_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/sharp_bicubic"
#gt_root: "/mnt/sungaofeng/edvr/paddle/video/data/dataset/edvr/REDS4/sharp_bicubic"
gt_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/GT"
use_flip: False
use_rot: False
TEST:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
#file_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/sharp_bicubic"
#gt_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/GT"
file_root: "/mnt/sungaofeng/edvr/paddle/video/data/dataset/edvr/REDS4/sharp_bicubic"
gt_root: "/mnt/sungaofeng/edvr/paddle/video/data/dataset/edvr/REDS4/GT"
use_flip: False
use_rot: False
INFER:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
file_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/sharp_bicubic"
gt_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/GT"
use_flip: False
use_rot: False
MODEL:
name: "EDVR"
format: "png"
num_frames: 5
center: 2
num_filters: 128 #64
deform_conv_groups: 8
front_RBs: 5
back_RBs: 40 #10
predeblur: False
HR_in: False
w_TSA: True #False
TRAIN:
epoch: 45
use_gpu: True
num_gpus: 4 #8
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 32
file_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp_bicubic/X4"
#file_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS/train_sharp_bicubic/X4"
gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS/train_sharp"
#gt_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS/train_sharp"
use_flip: True
use_rot: True
base_learning_rate: 0.0004
l2_weight_decay: 0.0
TSA_only: False
T_periods: [50000, 100000, 150000, 150000, 150000] # for cosine annealing restart
restarts: [50000, 150000, 300000, 450000] # for cosine annealing restart
weights: [1, 1, 1, 1] # for cosine annealing restart
eta_min: 1e-7 # for cosine annealing restart
num_reader_threads: 8
buf_size: 1024
fix_random_seed: False
VALID:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 32 #256
file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
gt_root: "/workspace/video_test/video/data/dataset/edvr/edvr/REDS4/GT"
use_flip: False
use_rot: False
TEST:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
#file_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/sharp_bicubic"
#gt_root: "/ssd1/vis/sungaofeng/edvr/mmsr/datasets/REDS4/GT"
file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
use_flip: False
use_rot: False
INFER:
scale: 4
crop_size: 256
interval_list: [1]
random_reverse: False
number_frames: 5
batch_size: 1
#file_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/sharp_bicubic"
#gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
file_root: "/workspace/color/input_frames"
gt_root: "/workspace/video_test/video/data/dataset/edvr/REDS4/GT"
use_flip: False
use_rot: False
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import time
import logging
import argparse
import ast
import numpy as np
import paddle.fluid as fluid
from utils.config_utils import *
import models
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
from utils.utility import check_version
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_name',
type=str,
default='AttentionCluster',
help='name of model to evaluate.')
parser.add_argument(
'--config',
type=str,
default='configs/attention_cluster.txt',
help='path to config file of model')
parser.add_argument(
'--batch_size',
type=int,
default=None,
help='test batch size. None to use config file setting.')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='whether to use GPU.')
parser.add_argument(
'--weights',
type=str,
default=None,
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--save_dir',
type=str,
default=os.path.join('data', 'evaluate_results'),
help='output dir path, default to use ./data/evaluate_results')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
args = parser.parse_args()
return args
def test(args):
# parse config
config = parse_config(args.config)
test_config = merge_configs(config, 'test', vars(args))
print_configs(test_config, "Test")
# build model
test_model = models.get_model(args.model_name, test_config, mode='test')
test_model.build_input(use_dataloader=False)
test_model.build_model()
test_feeds = test_model.feeds()
test_fetch_list = test_model.fetches()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if args.weights:
assert os.path.exists(
args.weights), "Given weight dir {} not exist.".format(args.weights)
weights = args.weights or test_model.get_weights()
logger.info('load test weights from {}'.format(weights))
test_model.load_test_weights(exe, weights,
fluid.default_main_program(), place)
# get reader and metrics
test_reader = get_reader(args.model_name.upper(), 'test', test_config)
test_metrics = get_metrics(args.model_name.upper(), 'test', test_config)
test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
epoch_period = []
for test_iter, data in enumerate(test_reader()):
cur_time = time.time()
if args.model_name == 'ETS':
feat_data = [items[:3] for items in data]
vinfo = [items[3:] for items in data]
test_outs = exe.run(fetch_list=test_fetch_list,
feed=test_feeder.feed(feat_data),
return_numpy=False)
test_outs += [vinfo]
elif args.model_name == 'TALL':
feat_data = [items[:2] for items in data]
vinfo = [items[2:] for items in data]
test_outs = exe.run(fetch_list=test_fetch_list,
feed=test_feeder.feed(feat_data),
return_numpy=True)
test_outs += [vinfo]
elif args.model_name == 'EDVR':
feat_data = [items[:2] for items in data]
vinfo = [items[2:] for items in data]
test_outs = exe.run(fetch_list=test_fetch_list,
feed=test_feeder.feed(feat_data),
return_numpy=True)
test_outs += [vinfo]
else:
test_outs = exe.run(fetch_list=test_fetch_list,
feed=test_feeder.feed(data))
period = time.time() - cur_time
epoch_period.append(period)
test_metrics.accumulate(test_outs)
# metric here
if args.log_interval > 0 and test_iter % args.log_interval == 0:
info_str = '[EVAL] Batch {}'.format(test_iter)
test_metrics.calculate_and_log_out(test_outs, info_str)
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
test_metrics.finalize_and_log_out("[EVAL] eval finished. ", args.save_dir)
if __name__ == "__main__":
args = parse_args()
# check whether the installed paddle is compiled with GPU
check_cuda(args.use_gpu)
check_version()
logger.info(args)
test(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import time
import logging
import argparse
import ast
import numpy as np
try:
import cPickle as pickle
except:
import pickle
import paddle.fluid as fluid
from utils.config_utils import *
import models
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_name',
type=str,
default='AttentionCluster',
help='name of model to export.')
parser.add_argument(
'--config',
type=str,
default='configs/attention_cluster.txt',
help='path to config file of model')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='whether to use GPU.')
parser.add_argument(
'--weights',
type=str,
default=None,
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--batch_size',
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument(
'--save_dir',
type=str,
default='./',
help='directory to store model and params file')
args = parser.parse_args()
return args
def save_inference_model(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
print_configs(infer_config, "Infer")
infer_model = models.get_model(args.model_name, infer_config, mode='infer')
infer_model.build_input(use_dataloader=False)
infer_model.build_model()
infer_feeds = infer_model.feeds()
infer_outputs = infer_model.outputs()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if args.weights:
assert os.path.exists(
args.weights), "Given weight dir {} not exist.".format(args.weights)
# if no weight files specified, download weights from paddle
weights = args.weights or infer_model.get_weights()
infer_model.load_test_weights(exe, weights,
fluid.default_main_program(), place)
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
# saving inference model
fluid.io.save_inference_model(
args.save_dir,
feeded_var_names=[item.name for item in infer_feeds],
target_vars=infer_outputs,
executor=exe,
main_program=fluid.default_main_program(),
model_filename=args.model_name + "_model.pdmodel",
params_filename=args.model_name + "_params.pdparams")
print("save inference model at %s" % (args.save_dir))
if __name__ == "__main__":
args = parse_args()
# check whether the installed paddle is compiled with GPU
check_cuda(args.use_gpu)
logger.info(args)
save_inference_model(args)
from .metrics_util import get_metrics
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import numpy as np
import datetime
import logging
import json
import os
import cv2
import math
logger = logging.getLogger(__name__)
class MetricsCalculator():
def __init__(
self,
name='EDVR',
mode='train'):
self.name = name
self.mode = mode # 'train', 'valid', 'test', 'infer'
self.reset()
self.total_frames = 9002 #100
self.border_frames = 2  # frames this close to a clip boundary count as border frames
def reset(self):
logger.info('Resetting {} metrics...'.format(self.mode))
if (self.mode == 'train') or (self.mode == 'valid'):
self.aggr_loss = 0.0
elif (self.mode == 'test') or (self.mode == 'infer'):
self.result_dict = dict()
def calculate_and_log_out(self, fetch_list, info):
pass
def accumulate(self, fetch_list):
loss = fetch_list[0]
pred = fetch_list[1]
gt = fetch_list[2]
videoinfo = fetch_list[-1]
videonames = [item[0] for item in videoinfo]
framenames = [item[1] for item in videoinfo]
for i in range(len(pred)):
pred_i = pred[i]
gt_i = gt[i]
videoname_i = videonames[i]
framename_i = framenames[i]
if videoname_i not in self.result_dict.keys():
self.result_dict[videoname_i] = {}
if framename_i in self.result_dict[videoname_i].keys():
raise ValueError("frame {} already processed in video {}, please check it".format(framename_i, videoname_i))
is_border = (int(framename_i) > (self.total_frames - self.border_frames - 1)
or int(framename_i) < self.border_frames)
psnr_i = get_psnr(pred_i, gt_i)
img_i = get_img(pred_i)
self.result_dict[videoname_i][framename_i] = [is_border, psnr_i]
is_save = True
if is_save and (i == len(pred) - 1):
save_img(img_i, framename_i)
logger.info("video {}, frame {}, border {}, psnr = {}".format(videoname_i, framename_i, is_border, psnr_i))
def finalize_metrics(self, savedir):
avg_psnr = 0.
avg_psnr_center = 0.
avg_psnr_border = 0.
center_num = 0.
border_num = 0.
for videoname in self.result_dict.keys():
videoresult = self.result_dict[videoname]
framelist = list(videoresult.keys())
video_psnr_center = 0.
video_psnr_border = 0.
video_center_num = 0.
video_border_num = 0.
for frame in framelist:
frameresult = videoresult[frame]
is_border = frameresult[0]
psnr = frameresult[1]
if is_border:
video_border_num += 1
video_psnr_border += psnr
else:
video_center_num += 1
video_psnr_center += psnr
video_num = video_border_num + video_center_num
video_psnr = video_psnr_center + video_psnr_border
avg_psnr_border += video_psnr_border
avg_psnr_center += video_psnr_center
border_num += video_border_num
center_num += video_center_num
logger.info("video {}, total frame num/psnr {}/{}, center num/psnr {}/{}, border num/psnr {}/{}".format(
videoname, video_num, video_psnr/video_num,
video_center_num, video_psnr_center/video_center_num,
video_border_num, video_psnr_border/video_border_num))
avg_psnr = avg_psnr_border + avg_psnr_center
total_num = border_num + center_num
avg_psnr = avg_psnr / total_num
avg_psnr_center = avg_psnr_center / center_num
avg_psnr_border = avg_psnr_border / border_num
logger.info("Average psnr {}, center {}, border {}".format(avg_psnr, avg_psnr_center, avg_psnr_border))
def get_psnr(pred, gt):
# pred and gt have range [0, 1]
pred = pred.squeeze().astype(np.float64)
pred = pred * 255.
pred = pred.round()
gt = gt.squeeze().astype(np.float64)
gt = gt * 255.
gt = gt.round()
mse = np.mean((pred - gt)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
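# A minimal sanity check for get_psnr (an illustrative sketch; the helper
# name _psnr_sanity_check is not part of the original code). Inputs are
# assumed to lie in [0, 1]: identical images give infinite PSNR, and a
# uniform error of 0.2 rescales to 51/255, i.e. 20*log10(255/51) ~= 14.0 dB.
def _psnr_sanity_check():
    a = np.zeros((3, 8, 8), dtype='float32')
    b = np.full((3, 8, 8), 0.2, dtype='float32')
    assert get_psnr(a, a) == float('inf')
    assert abs(get_psnr(a, b) - 20 * math.log10(255.0 / 51.0)) < 1e-6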
def get_img(pred):
pred = pred.squeeze()
pred = np.clip(pred, a_min=0., a_max=1.0)
pred = pred * 255
pred = pred.round()
pred = pred.astype('uint8')
pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
pred = pred[:, :, ::-1] # rgb -> bgr
return pred
def save_img(img, framename):
dirname = './demo/resultpng'
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = os.path.join(dirname, framename+'.png')
cv2.imwrite(filename, img)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import logging
import numpy as np
import json
from metrics.edvr_metrics import edvr_metrics
logger = logging.getLogger(__name__)
class Metrics(object):
def __init__(self, name, mode, metrics_args):
"""Not implemented"""
pass
def calculate_and_log_out(self, fetch_list, info=''):
"""Not implemented"""
pass
def accumulate(self, fetch_list, info=''):
"""Not implemented"""
pass
def finalize_and_log_out(self, info='', savedir='./'):
"""Not implemented"""
pass
def reset(self):
"""Not implemented"""
pass
class EDVRMetrics(Metrics):
def __init__(self, name, mode, cfg):
self.name = name
self.mode = mode
args = {}
args['mode'] = mode
args['name'] = name
self.calculator = edvr_metrics.MetricsCalculator(**args)
def calculate_and_log_out(self, fetch_list, info=''):
if (self.mode == 'train') or (self.mode == 'valid'):
loss = np.array(fetch_list[0])
logger.info(info + '\tLoss = {}'.format('%.04f' % np.mean(loss)))
elif self.mode == 'test':
pass
def accumulate(self, fetch_list):
self.calculator.accumulate(fetch_list)
def finalize_and_log_out(self, info='', savedir='./'):
self.calculator.finalize_metrics(savedir)
def reset(self):
self.calculator.reset()
class MetricsZoo(object):
def __init__(self):
self.metrics_zoo = {}
def regist(self, name, metrics):
assert metrics.__base__ == Metrics, "Unknown metrics type {}".format(
type(metrics))
self.metrics_zoo[name] = metrics
def get(self, name, mode, cfg):
for k, v in self.metrics_zoo.items():
if k == name:
return v(name, mode, cfg)
raise KeyError("Metrics {} not found, available: {}".format(name, list(self.metrics_zoo.keys())))
# singleton metrics_zoo
metrics_zoo = MetricsZoo()
def regist_metrics(name, metrics):
metrics_zoo.regist(name, metrics)
def get_metrics(name, mode, cfg):
return metrics_zoo.get(name, mode, cfg)
# sort by alphabet
regist_metrics("EDVR", EDVRMetrics)
from .model import regist_model, get_model
#from .attention_cluster import AttentionCluster
#from .attention_lstm import AttentionLSTM
#from .nextvlad import NEXTVLAD
#from .nonlocal_model import NonLocal
#from .tsm import TSM
#from .tsn import TSN
#from .stnet import STNET
#from .ctcn import CTCN
#from .bmn import BMN
#from .bsn import BsnTem
#from .bsn import BsnPem
#from .ets import ETS
#from .tall import TALL
from .edvr import EDVR
# regist models, sort by alphabet
#regist_model("AttentionCluster", AttentionCluster)
#regist_model("AttentionLSTM", AttentionLSTM)
#regist_model("NEXTVLAD", NEXTVLAD)
#regist_model('NONLOCAL', NonLocal)
#regist_model("TSM", TSM)
#regist_model("TSN", TSN)
#regist_model("STNET", STNET)
#regist_model("CTCN", CTCN)
#regist_model("BMN", BMN)
#regist_model("BsnTem", BsnTem)
#regist_model("BsnPem", BsnPem)
#regist_model("ETS", ETS)
#regist_model("TALL", TALL)
regist_model("EDVR", EDVR)
# TSN Video Classification Model
---
## Contents
- [Introduction](#introduction)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [Inference](#inference)
- [Reference](#reference)
## Introduction
Temporal Segment Network (TSN) is a classic 2D-CNN-based solution for video classification. It targets the recognition of long-range activities in video: sparse frame sampling replaces dense sampling, which captures global video information while removing redundancy and reducing computation. The per-frame features are averaged into a video-level feature that is used for classification. The model implemented here is the single-stream RGB variant of TSN with a ResNet-50 backbone.
For details, please refer to the ECCV 2016 paper [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859)
## Data Preparation
TSN is trained on the Kinetics-400 action recognition dataset released by DeepMind. For downloading and preparing the data, please refer to the [data instructions](../../data/dataset/README.md)
## Training
Once the data is ready, training can be launched in either of the following two ways:

    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
    export FLAGS_fast_eager_deletion_mode=1
    export FLAGS_eager_delete_tensor_gb=0.0
    export FLAGS_fraction_of_gpu_memory_to_use=0.98
    python train.py --model_name=TSN \
        --config=./configs/tsn.yaml \
        --log_interval=10 \
        --valid_interval=1 \
        --use_gpu=True \
        --save_dir=./data/checkpoints \
        --fix_random_seed=False \
        --pretrain=$PATH_TO_PRETRAIN_MODEL

    bash run.sh train TSN ./configs/tsn.yaml

- To train from scratch, the model must be initialized with ResNet-50 weights pretrained on ImageNet: download and extract the [model parameters](https://paddlemodels.bj.bcebos.com/video_classification/ResNet50_pretrained.tar.gz), then set the `pretrain` argument in the command line above or in the run.sh script to the extracted directory. If `pretrain` is not set manually, the parameters are downloaded automatically and stored under ~/.paddle/weights/ResNet50\_pretrained
- Alternatively, download the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) and pass its path via `--resume` for finetuning and further development

**Data reader notes:** the reader consumes `mp4` files from Kinetics-400; each video is divided into `seg_num` segments, one frame is sampled per segment, and each frame is randomly augmented and then resized to `target_size`
**Training strategy** (sketched in code below):
* Momentum optimizer with momentum=0.9
* L2 weight decay of 1e-4
* the learning rate is decayed by a factor of 0.1 at 1/3 and 2/3 of the total number of training epochs
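A sketch of this schedule with the `paddle.fluid` 1.x API (illustrative only; `build_optimizer`, `base_lr` and `total_steps` are placeholder names, not part of the training script):

    import paddle.fluid as fluid

    def build_optimizer(base_lr=0.01, total_steps=90000):
        # decay the learning rate by 0.1 at 1/3 and 2/3 of the total steps
        boundaries = [total_steps // 3, total_steps * 2 // 3]
        values = [base_lr, base_lr * 0.1, base_lr * 0.01]
        lr = fluid.layers.piecewise_decay(boundaries=boundaries, values=values)
        # Momentum optimizer with momentum=0.9 and L2 weight decay of 1e-4
        return fluid.optimizer.Momentum(
            learning_rate=lr,
            momentum=0.9,
            regularization=fluid.regularizer.L2Decay(1e-4))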
## Evaluation
The model can be evaluated in either of the following two ways:

    python eval.py --model_name=TSN \
        --config=./configs/tsn.yaml \
        --log_interval=1 \
        --weights=$PATH_TO_WEIGHTS \
        --use_gpu=True

    bash run.sh eval TSN ./configs/tsn.yaml

- When evaluating with `run.sh`, set the `weights` parameter in the script to the weights to be evaluated
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) and evaluates it
- Accuracy metrics such as TOP1\_ACC and TOP5\_ACC are printed directly to the log
- To evaluate on CPU, set `use_gpu` to False in the command line above or in the run.sh script

With the following parameter settings, the accuracy on the Kinetics-400 validation set is:

| seg\_num | target\_size | Top-1 |
| :------: | :----------: | :----: |
| 3 | 224 | 0.66 |
| 7 | 224 | 0.67 |
## Inference
Model inference can be launched in either of the following two ways:

    python predict.py --model_name=TSN \
        --config=./configs/tsn.yaml \
        --log_interval=1 \
        --weights=$PATH_TO_WEIGHTS \
        --filelist=$FILELIST \
        --use_gpu=True \
        --video_path=$VIDEO_PATH

    bash run.sh predict TSN ./configs/tsn.yaml

- When predicting with `run.sh`, set the `weights` parameter in the script to the weights to be used.
- If video\_path is '', the argument is ignored. If video\_path != '', the program runs prediction on the video file specified by video\_path, ignores `filelist`, and outputs the classification probabilities for that video.
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_classification/TSN.pdparams) for inference
- Prediction results, i.e. the classification probabilities of the test samples, are printed directly to the log.
- To run inference on CPU, set `use_gpu` to False in the command line or in the run.sh script
## Reference
- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from ..model import ModelBase
from .edvr_model import EDVRModel
import logging
logger = logging.getLogger(__name__)
__all__ = ["EDVR"]
class EDVR(ModelBase):
def __init__(self, name, cfg, mode='train'):
super(EDVR, self).__init__(name, cfg, mode=mode)
self.get_config()
def get_config(self):
self.num_filters = self.get_config_from_sec('model', 'num_filters')
self.num_frames = self.get_config_from_sec('model', 'num_frames')
self.dcn_groups = self.get_config_from_sec('model', 'deform_conv_groups')
self.front_RBs = self.get_config_from_sec('model', 'front_RBs')
self.back_RBs = self.get_config_from_sec('model', 'back_RBs')
self.center = self.get_config_from_sec('model', 'center', 2)
self.predeblur = self.get_config_from_sec('model', 'predeblur', False)
self.HR_in = self.get_config_from_sec('model', 'HR_in', False)
self.w_TSA = self.get_config_from_sec('model', 'w_TSA', True)
self.crop_size = self.get_config_from_sec(self.mode, 'crop_size')
self.scale = self.get_config_from_sec(self.mode, 'scale', 1)
self.num_gpus = self.get_config_from_sec(self.mode, 'num_gpus', 8)
self.batch_size = self.get_config_from_sec(self.mode, 'batch_size', 256)
# get optimizer related parameters
self.base_learning_rate = self.get_config_from_sec('train', 'base_learning_rate')
self.l2_weight_decay = self.get_config_from_sec('train', 'l2_weight_decay')
self.T_periods = self.get_config_from_sec('train', 'T_periods')
self.restarts = self.get_config_from_sec('train', 'restarts')
self.weights = self.get_config_from_sec('train', 'weights')
self.eta_min = self.get_config_from_sec('train', 'eta_min')
self.TSA_only = self.get_config_from_sec('train', 'TSA_only', False)
def build_input(self, use_dataloader=True):
if self.mode != 'test':
gt_shape = [None, 3, self.crop_size, self.crop_size]
else:
gt_shape = [None, 3, 720, 1280]
if self.HR_in:
img_shape = [-1, self.num_frames, 3, self.crop_size, self.crop_size]
else:
if (self.mode != 'test') and (self.mode != 'infer') :
img_shape = [None, self.num_frames, 3, \
int(self.crop_size / self.scale), int(self.crop_size / self.scale)]
else:
img_shape = [None, self.num_frames, 3, 360, 472] #180, 320]
self.use_dataloader = use_dataloader
image = fluid.data(name='LQ_IMGs', shape=img_shape, dtype='float32')
if self.mode != 'infer':
label = fluid.data(name='GT_IMG', shape=gt_shape, dtype='float32')
else:
label = None
if use_dataloader:
assert self.mode != 'infer', \
'dataloader is not recommended in infer mode, please set use_dataloader to False.'
self.dataloader = fluid.io.DataLoader.from_generator(
feed_list=[image, label], capacity=4, iterable=True)
self.feature_input = [image]
self.label_input = label
def create_model_args(self):
cfg = {}
cfg['nf'] = self.num_filters
cfg['nframes'] = self.num_frames
cfg['groups'] = self.dcn_groups
cfg['front_RBs'] = self.front_RBs
cfg['back_RBs'] = self.back_RBs
cfg['center'] = self.center
cfg['predeblur'] = self.predeblur
cfg['HR_in'] = self.HR_in
cfg['w_TSA'] = self.w_TSA
cfg['mode'] = self.mode
cfg['TSA_only'] = self.TSA_only
return cfg
def build_model(self):
cfg = self.create_model_args()
videomodel = EDVRModel(**cfg)
out = videomodel.net(self.feature_input[0])
self.network_outputs = [out]
def optimizer(self):
assert self.mode == 'train', "optimizer can only be built in train mode"
learning_rate = get_lr(base_lr = self.base_learning_rate,
T_periods=self.T_periods,
restarts=self.restarts,
weights=self.weights,
eta_min=self.eta_min)
l2_weight_decay = self.l2_weight_decay
optimizer = fluid.optimizer.Adam(
learning_rate = learning_rate,
beta1 = 0.9,
beta2 = 0.99,
regularization=fluid.regularizer.L2Decay(l2_weight_decay))
return optimizer
def loss(self):
assert self.mode != 'infer', "invalid loss calculationg in infer mode"
pred = self.network_outputs[0]
label = self.label_input
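# Charbonnier penalty: loss = sum(sqrt((pred - label)^2 + epsilon))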
epsilon = 1e-6
diff = pred - label
diff = diff * diff + epsilon
diff = fluid.layers.sqrt(diff)
self.loss_ = fluid.layers.reduce_sum(diff)
return self.loss_
def outputs(self):
return self.network_outputs
def feeds(self):
return self.feature_input if self.mode == 'infer' else self.feature_input + [
self.label_input
]
def fetches(self):
if self.mode == 'train' or self.mode == 'valid':
losses = self.loss()
fetch_list = [losses, self.network_outputs[0], self.label_input]
elif self.mode == 'test':
losses = self.loss()
fetch_list = [losses, self.network_outputs[0], self.label_input]
elif self.mode == 'infer':
fetch_list = self.network_outputs
else:
raise NotImplementedError('mode {} not implemented'.format(
self.mode))
return fetch_list
def pretrain_info(self):
return (
None,
None
)
def weights_info(self):
return (
None,
None
)
def load_pretrain_params0(self, exe, pretrain, prog, place):
"""load pretrain form .npz which is created by torch"""
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter)
params_list = list(filter(is_parameter, prog.list_vars()))
import numpy as np
state_dict = np.load(pretrain)
for p in params_list:
if p.name in state_dict.keys():
print('########### load param {} from file'.format(p.name))
else:
print('----------- param {} not in file'.format(p.name))
fluid.set_program_state(prog, state_dict)
print('load pretrain from ', pretrain)
def load_test_weights(self, exe, weights, prog, place):
"""load weights from .npz which is created by torch"""
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter)
params_list = list(filter(is_parameter, prog.list_vars()))
import numpy as np
state_dict = np.load(weights)
for p in params_list:
if p.name in state_dict.keys():
print('########### load param {} from file'.format(p.name))
else:
print('----------- param {} not in file'.format(p.name))
fluid.set_program_state(prog, state_dict)
print('load weights from ', weights)
# This is for learning rate cosine annealing restart
Dtype='float32'
def decay_step_counter(begin=0):
# the first global step is zero in learning rate decay
global_step = fluid.layers.autoincreased_step_counter(
counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1)
return global_step
def get_lr(base_lr = 0.001,
T_periods = [250000, 250000, 250000, 250000],
restarts = [250000, 500000, 750000],
weights=[1, 1, 1],
eta_min=0):
with fluid.default_main_program()._lr_schedule_guard():
global_step = decay_step_counter()
lr = fluid.layers.create_global_var(shape=[1], value=base_lr, dtype=Dtype, persistable=True, name="learning_rate")
num_segs = len(restarts)
restart_point = 0
with fluid.layers.Switch() as switch:
with switch.case(global_step == 0):
pass
for i in range(num_segs):
T_max = T_periods[i]
weight = weights[i]
with switch.case(global_step < restarts[i]):
with fluid.layers.Switch() as switch_second:
value_2Tmax = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2*T_max)
step_checker = global_step-restart_point-1-T_max
with switch_second.case(fluid.layers.elementwise_mod(step_checker, value_2Tmax)==0):
var_value = lr + (base_lr - eta_min) * (1 - math.cos(math.pi / float(T_max))) / 2
fluid.layers.assign(var_value, lr)
with switch_second.default():
double_step = fluid.layers.cast(global_step, dtype='float64') - float(restart_point)
double_scale = (1 + fluid.layers.cos(math.pi * double_step / float(T_max))) / \
(1 + fluid.layers.cos(math.pi * (double_step - 1) / float(T_max)))
float_scale = fluid.layers.cast(double_scale, dtype=Dtype)
var_value = float_scale * (lr - eta_min) + eta_min
fluid.layers.assign(var_value, lr)
with switch.case(global_step == restarts[i]):
var_value = fluid.layers.fill_constant(
shape=[1], dtype=Dtype, value=float(base_lr*weight))
fluid.layers.assign(var_value, lr)
restart_point = restarts[i]
T_max = T_periods[num_segs]
with switch.default():
with fluid.layers.Switch() as switch_second:
value_2Tmax = fluid.layers.fill_constant(shape=[1], dtype='int64', value=2*T_max)
step_checker = global_step-restart_point-1-T_max
with switch_second.case(fluid.layers.elementwise_mod(step_checker, value_2Tmax)==0):
var_value = lr + (base_lr - eta_min) * (1 - math.cos(math.pi / float(T_max))) / 2
fluid.layers.assign(var_value, lr)
with switch_second.default():
double_step = fluid.layers.cast(global_step, dtype='float64') - float(restart_point)
double_scale = (1 + fluid.layers.cos(math.pi * double_step / float(T_max))) / \
(1 + fluid.layers.cos(math.pi * (double_step - 1) / float(T_max)))
float_scale = fluid.layers.cast(double_scale, dtype=Dtype)
var_value = float_scale * (lr - eta_min) + eta_min
fluid.layers.assign(var_value, lr)
return lr
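# For reference, a plain-Python sketch of the schedule the static graph above
# is intended to produce (get_lr_value is an illustrative helper, not used by
# the training code): within a segment of length T_max the learning rate
# follows eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2,
# and at each restart boundary eta_max is reset to base_lr * weight.
def get_lr_value(step, base_lr=0.001,
                 T_periods=(250000, 250000, 250000, 250000),
                 restarts=(250000, 500000, 750000),
                 weights=(1, 1, 1), eta_min=0.):
    start, eta_max, seg = 0, base_lr, 0
    for i, r in enumerate(restarts):
        if step < r:
            break
        start, eta_max, seg = r, base_lr * weights[i], i + 1
    t = step - start
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t / T_periods[seg])) / 2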
This diff has been collapsed.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import wget
import logging
try:
from configparser import ConfigParser
except:
from ConfigParser import ConfigParser
import paddle.fluid as fluid
from .utils import download, AttrDict
WEIGHT_DIR = os.path.join(os.path.expanduser('~'), '.paddle', 'weights')
logger = logging.getLogger(__name__)
def is_parameter(var):
return isinstance(var, fluid.framework.Parameter)
class NotImplementError(Exception):
"Error: model function not implement"
def __init__(self, model, function):
super(NotImplementError, self).__init__()
self.model = model.__class__.__name__
self.function = function.__name__
def __str__(self):
return "Function {}() is not implemented in model {}".format(
self.function, self.model)
class ModelNotFoundError(Exception):
"Error: model not found"
def __init__(self, model_name, avail_models):
super(ModelNotFoundError, self).__init__()
self.model_name = model_name
self.avail_models = avail_models
def __str__(self):
msg = "Model {} Not Found.\nAvailiable models:\n".format(
self.model_name)
for model in self.avail_models:
msg += " {}\n".format(model)
return msg
class ModelBase(object):
def __init__(self, name, cfg, mode='train'):
assert mode in ['train', 'valid', 'test', 'infer'], \
"Unknown mode type {}".format(mode)
self.name = name
self.is_training = (mode == 'train')
self.mode = mode
self.cfg = cfg
self.dataloader = None
def build_model(self):
"build model struct"
raise NotImplementError(self, self.build_model)
def build_input(self, use_dataloader):
"build input Variable"
raise NotImplementError(self, self.build_input)
def optimizer(self):
"get model optimizer"
raise NotImplementError(self, self.optimizer)
def outputs(self):
"get output variable"
raise NotImplementError(self, self.outputs)
def loss(self):
"get loss variable"
raise NotImplementError(self, self.loss)
def feeds(self):
"get feed inputs list"
raise NotImplementError(self, self.feeds)
def fetches(self):
"get fetch list of model"
raise NotImplementError(self, self.fetches)
def weights_info(self):
"get model weight default path and download url"
raise NotImplementError(self, self.weights_info)
def get_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.weights_info()
path = os.path.join(WEIGHT_DIR, path)
if not os.path.isdir(WEIGHT_DIR):
logger.info('{} not exists, will be created automatically.'.format(
WEIGHT_DIR))
os.makedirs(WEIGHT_DIR)
if os.path.exists(path):
return path
logger.info("Download weights of {} from {}".format(self.name, url))
wget.download(url, path)
return path
def dataloader(self):
return self.dataloader
def epoch_num(self):
"get train epoch num"
return self.cfg.TRAIN.epoch
def pretrain_info(self):
"get pretrain base model directory"
return (None, None)
def get_pretrain_weights(self):
"get model weight file path, download weight from Paddle if not exist"
path, url = self.pretrain_info()
if not path:
return None
path = os.path.join(WEIGHT_DIR, path)
if not os.path.isdir(WEIGHT_DIR):
logger.info('{} not exists, will be created automatically.'.format(
WEIGHT_DIR))
os.makedirs(WEIGHT_DIR)
if os.path.exists(path):
return path
logger.info("Download pretrain weights of {} from {}".format(self.name,
url))
download(url, path)
return path
def load_pretrain_params(self, exe, pretrain, prog, place):
logger.info("Load pretrain weights from {}".format(pretrain))
state_dict = fluid.load_program_state(pretrain)
fluid.set_program_state(prog, state_dict)
def load_test_weights(self, exe, weights, prog, place):
params_list = list(filter(is_parameter, prog.list_vars()))
fluid.load(prog, weights, executor=exe, var_list=params_list)
def get_config_from_sec(self, sec, item, default=None):
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
class ModelZoo(object):
def __init__(self):
self.model_zoo = {}
def regist(self, name, model):
assert model.__base__ == ModelBase, "Unknown model type {}".format(
type(model))
self.model_zoo[name] = model
def get(self, name, cfg, mode='train'):
for k, v in self.model_zoo.items():
if k.upper() == name.upper():
return v(name, cfg, mode)
raise ModelNotFoundError(name, self.model_zoo.keys())
# singleton model_zoo
model_zoo = ModelZoo()
def regist_model(name, model):
model_zoo.regist(name, model)
def get_model(name, cfg, mode='train'):
return model_zoo.get(name, cfg, mode)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import wget
import tarfile
__all__ = ['decompress', 'download', 'AttrDict']
def decompress(path):
t = tarfile.open(path)
t.extractall(path=os.path.split(path)[0])
t.close()
os.remove(path)
def download(url, path):
weight_dir = os.path.split(path)[0]
if not os.path.exists(weight_dir):
os.makedirs(weight_dir)
path = path + ".tar.gz"
wget.download(url, path)
decompress(path)
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
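# Example (illustrative): AttrDict gives config sections attribute-style
# access on top of plain dict access, e.g.
#   cfg = AttrDict()
#   cfg.name = 'EDVR'
#   assert cfg['name'] == 'EDVR' and cfg.name == 'EDVR'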
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import time
import logging
import argparse
import ast
import numpy as np
try:
import cPickle as pickle
except:
import pickle
import paddle.fluid as fluid
import cv2
from utils.config_utils import *
import models
from reader import get_reader
from metrics import get_metrics
from utils.utility import check_cuda
from utils.utility import check_version
logging.root.handlers = []
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.DEBUG, format=FORMAT, stream=sys.stdout)
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model_name',
type=str,
default='AttentionCluster',
help='name of model to infer.')
parser.add_argument(
'--config',
type=str,
default='configs/attention_cluster.txt',
help='path to config file of model')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='whether to use GPU.')
parser.add_argument(
'--weights',
type=str,
default=None,
help='weight path, None to automatically download weights provided by Paddle.'
)
parser.add_argument(
'--batch_size',
type=int,
default=1,
help='sample number in a batch for inference.')
parser.add_argument(
'--filelist',
type=str,
default=None,
help='path to inferenece data file lists file.')
parser.add_argument(
'--log_interval',
type=int,
default=1,
help='mini-batch interval to log.')
parser.add_argument(
'--infer_topk',
type=int,
default=20,
help='topk predictions to restore.')
parser.add_argument(
'--save_dir',
type=str,
default=os.path.join('data', 'predict_results'),
help='directory to store results')
parser.add_argument(
'--video_path',
type=str,
default=None,
help='path of the video file to run inference on.')
args = parser.parse_args()
return args
def get_img(pred):
pred = pred.squeeze()
pred = np.clip(pred, a_min=0., a_max=1.0)
pred = pred * 255
pred = pred.round()
pred = pred.astype('uint8')
pred = np.transpose(pred, (1, 2, 0)) # chw -> hwc
pred = pred[:, :, ::-1] # rgb -> bgr
return pred
def save_img(img, framename):
dirname = './demo/resultpng'
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = os.path.join(dirname, framename+'.png')
cv2.imwrite(filename, img)
def infer(args):
# parse config
config = parse_config(args.config)
infer_config = merge_configs(config, 'infer', vars(args))
print_configs(infer_config, "Infer")
#infer_model = models.get_model(args.model_name, infer_config, mode='infer')
#infer_model.build_input(use_dataloader=False)
#infer_model.build_model()
#infer_feeds = infer_model.feeds()
#infer_outputs = infer_model.outputs()
model_path = '/workspace/video_test/video/for_eval/data/inference_model'
model_filename = 'EDVR_model.pdmodel'
params_filename = 'EDVR_params.pdparams'
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_list, fetch_list] = fluid.io.load_inference_model(
dirname=model_path, model_filename=model_filename,
params_filename=params_filename, executor=exe)
#filelist = args.filelist or infer_config.INFER.filelist
#filepath = args.video_path or infer_config.INFER.get('filepath', '')
#if filepath != '':
# assert os.path.exists(filepath), "{} not exist.".format(filepath)
#else:
# assert os.path.exists(filelist), "{} not exist.".format(filelist)
# get infer reader
infer_feeds = [inference_program.global_block().var(var_name) for var_name in feed_list]
infer_reader = get_reader(args.model_name.upper(), 'infer', infer_config)
#print(inference_program)
#if args.weights:
# assert os.path.exists(
# args.weights), "Given weight dir {} not exist.".format(args.weights)
# if no weight files specified, download weights from paddle
#weights = args.weights or infer_model.get_weights()
#infer_model.load_test_weights(exe, weights,
# fluid.default_main_program(), place)
infer_feeder = fluid.DataFeeder(place=place, feed_list=infer_feeds)
#fetch_list = infer_model.fetches()
#infer_metrics = get_metrics(args.model_name.upper(), 'infer', infer_config)
#infer_metrics.reset()
periods = []
cur_time = time.time()
for infer_iter, data in enumerate(infer_reader()):
if args.model_name == 'EDVR':
data_feed_in = [items[0] for items in data]
video_info = [items[1:] for items in data]
infer_outs = exe.run(inference_program,
fetch_list=fetch_list,
feed={feed_list[0]:np.array(data_feed_in)})
infer_result_list = [item for item in infer_outs]
videonames = [item[0] for item in video_info]
framenames = [item[1] for item in video_info]
for i in range(len(infer_result_list)):
img_i = get_img(infer_result_list[i])
save_img(img_i, 'img' + videonames[i] + framenames[i])
prev_time = cur_time
cur_time = time.time()
period = cur_time - prev_time
periods.append(period)
#infer_metrics.accumulate(infer_result_list)
if args.log_interval > 0 and infer_iter % args.log_interval == 0:
logger.info('Processed {} samples'.format(infer_iter + 1))
logger.info('[INFER] infer finished. average time: {}'.format(np.mean(periods)))
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
#infer_metrics.finalize_and_log_out(savedir=args.save_dir)
if __name__ == "__main__":
args = parse_args()
# check whether the installed paddle is compiled with GPU
check_cuda(args.use_gpu)
check_version()
logger.info(args)
infer(args)
from .reader_utils import regist_reader, get_reader
#from .feature_reader import FeatureReader
#from .kinetics_reader import KineticsReader
#from .nonlocal_reader import NonlocalReader
#from .ctcn_reader import CTCNReader
#from .bmn_reader import BMNReader
#from .bsn_reader import BSNVideoReader
#from .bsn_reader import BSNProposalReader
#from .ets_reader import ETSReader
#from .tall_reader import TALLReader
from .edvr_reader import EDVRReader
# regist reader, sort by alphabet
#regist_reader("ATTENTIONCLUSTER", FeatureReader)
#regist_reader("ATTENTIONLSTM", FeatureReader)
#regist_reader("NEXTVLAD", FeatureReader)
#regist_reader("NONLOCAL", NonlocalReader)
#regist_reader("TSM", KineticsReader)
#regist_reader("TSN", KineticsReader)
#regist_reader("STNET", KineticsReader)
#regist_reader("CTCN", CTCNReader)
#regist_reader("BMN", BMNReader)
#regist_reader("BSNTEM", BSNVideoReader)
#regist_reader("BSNPEM", BSNProposalReader)
#regist_reader("ETS", ETSReader)
#regist_reader("TALL", TALLReader)
regist_reader("EDVR", EDVRReader)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
random.seed(0)
np.random.seed(0)
class EDVRReader(DataReader):
"""
Data reader for video super resolution task fit for EDVR model.
This is specified for REDS dataset.
"""
def __init__(self, name, mode, cfg):
super(EDVRReader, self).__init__(name, mode, cfg)
self.format = cfg.MODEL.format
self.crop_size = self.get_config_from_sec(mode, 'crop_size')
self.interval_list = self.get_config_from_sec(mode, 'interval_list')
self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
self.number_frames = self.get_config_from_sec(mode, 'number_frames')
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.fileroot = cfg[mode.upper()]['file_root']
self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
if self.mode != 'infer':
self.gtroot = self.get_config_from_sec(mode, 'gt_root')
self.scale = self.get_config_from_sec(mode, 'scale', 1)
self.LR_input = (self.scale > 1)
if self.fix_random_seed:
random.seed(0)
np.random.seed(0)
self.num_reader_threads = 1
def create_reader(self):
logger.info('initialize reader ... ')
self.filelist = []
for video_name in os.listdir(self.fileroot):
if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
continue
for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + frame_idx
# each item in self.filelist looks like '010_00000015' or '260_00000090'
self.filelist.append(video_frame_idx)
if self.mode == 'test' or self.mode == 'infer':
self.filelist.sort()
if self.num_reader_threads == 1:
reader_func = make_reader
else:
reader_func = make_multi_reader
if self.mode != 'infer':
return reader_func(filelist = self.filelist,
num_threads = self.num_reader_threads,
batch_size = self.batch_size,
is_training = (self.mode == 'train'),
number_frames = self.number_frames,
interval_list = self.interval_list,
random_reverse = self.random_reverse,
fileroot = self.fileroot,
crop_size = self.crop_size,
use_flip = self.use_flip,
use_rot = self.use_rot,
gtroot = self.gtroot,
LR_input = self.LR_input,
scale = self.scale,
mode = self.mode)
else:
return reader_func(filelist = self.filelist,
num_threads = self.num_reader_threads,
batch_size = self.batch_size,
is_training = (self.mode == 'train'),
number_frames = self.number_frames,
interval_list = self.interval_list,
random_reverse = self.random_reverse,
fileroot = self.fileroot,
crop_size = self.crop_size,
use_flip = self.use_flip,
use_rot = self.use_rot,
gtroot = '',
LR_input = True,
scale = 4,
mode = self.mode)
def get_sample_data(item, number_frames, interval_list, random_reverse, fileroot,
crop_size, use_flip, use_rot, gtroot, LR_input, scale, mode='train'):
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
ngb_frames, name_b = get_neighbor_frames(frame_name, \
number_frames = number_frames, \
interval_list = interval_list, \
random_reverse = random_reverse)
elif (mode == 'test') or (mode == 'infer'):
ngb_frames, name_b = get_test_neighbor_frames(int(frame_name), number_frames)
else:
raise NotImplementedError('mode {} not implemented'.format(mode))
frame_name = name_b
if mode != 'infer':
img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'), is_gt=True)
#print('gt_mean', np.mean(img_GT))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%04d"%ngb_frm
#img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
#print('img_mean', np.mean(img))
H, W, C = frame_list[0].shape
# add random crop
if (mode == 'train') or (mode == 'valid'):
if LR_input:
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - crop_size))
rnd_w = random.randint(0, max(0, W - crop_size))
frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
# add random flip and rotation
if mode != 'infer':
frame_list.append(img_GT)
if (mode == 'train') or (mode == 'valid'):
rlt = img_augment(frame_list, use_flip, use_rot)
else:
rlt = frame_list
if mode != 'infer':
frame_list = rlt[0:-1]
img_GT = rlt[-1]
else:
frame_list = rlt
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
if mode != 'infer':
img_GT = img_GT[:, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
return img_LQs, img_GT
else:
return img_LQs
def get_test_neighbor_frames(crt_i, N, max_n=100, padding='new_info'):
"""Generate an index list for reading N frames from a sequence of images
Args:
crt_i (int): current center index
max_n (int): total number of images in the sequence (counted from 1)
N (int): reading N frames
padding (str): padding mode, one of replicate | reflection | new_info | circle
Example: crt_i = 0, N = 5
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
new_info: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
return_l (list [int]): a list of indexes
"""
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
name_b = '{:08d}'.format(crt_i)
return return_l, name_b
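# Example (the 'new_info' padding shown in the docstring): calling
# get_test_neighbor_frames(crt_i=0, N=5, max_n=100) returns
# ([4, 3, 0, 1, 2], '00000000').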
def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, max_frame=99, bordermode=False):
center_frame_idx = int(frame_name)
half_N_frames = number_frames // 2
#### determine the neighbor frames
interval = random.choice(interval_list)
if bordermode:
direction = 1 # 1: forward; 0: backward
if random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (number_frames - 1) > max_frame:
direction = 0
elif center_frame_idx - interval * (number_frames - 1) < 0:
direction = 1
# get the neighbor list
if direction == 1:
neighbor_list = list(
range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
else:
neighbor_list = list(
range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + half_N_frames * interval >
max_frame) or (center_frame_idx - half_N_frames * interval < 0):
center_frame_idx = random.randint(0, max_frame)
# get the neighbor list
neighbor_list = list(
range(center_frame_idx - half_N_frames * interval,
center_frame_idx + half_N_frames * interval + 1, interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[half_N_frames])
assert len(neighbor_list) == number_frames, \
"frames selected have length ({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list, name_b
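# Example (illustrative): with frame_name='00000050', number_frames=5,
# interval_list=[1] and random_reverse=False, the sampled window is
# [48, 49, 50, 51, 52] and name_b is '00000050' (the center frame).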
def read_img(path, size=None, is_gt=False):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
#if not is_gt:
# #print(path)
# img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def img_augment(img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def make_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
fileroot,
crop_size,
use_flip,
use_rot,
gtroot,
LR_input,
scale,
mode='train'):
fl = filelist
def reader_():
if is_training:
random.shuffle(fl)
batch_out = []
for item in fl:
if mode != 'infer':
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, fileroot,
crop_size,use_flip, use_rot, gtroot, LR_input, scale, mode)
else:
img_LQs = get_sample_data(item,
number_frames, interval_list, random_reverse, fileroot,
crop_size,use_flip, use_rot, gtroot, LR_input, scale, mode)
videoname = item.split('_')[0]
framename = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
batch_out.append((img_LQs, img_GT))
elif mode == 'test':
batch_out.append((img_LQs, img_GT, videoname, framename))
elif mode == 'infer':
batch_out.append((img_LQs, videoname, framename))
else:
raise NotImplementedError("mode {} not implemented".format(mode))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
return reader_
def make_multi_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
fileroot,
crop_size,
use_flip,
use_rot,
gtroot,
LR_input,
scale,
mode='train'):
def read_into_queue(flq, queue):
batch_out = []
for item in flq:
if mode != 'infer':
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, fileroot,
crop_size,use_flip, use_rot, gtroot, LR_input, scale, mode)
else:
img_LQs = get_sample_data(item,
number_frames, interval_list, random_reverse, fileroot,
crop_size,use_flip, use_rot, gtroot, LR_input, scale, mode)
videoname = item.split('_')[0]
framename = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
batch_out.append((img_LQs, img_GT))
elif mode == 'test':
batch_out.append((img_LQs, img_GT, videoname, framename))
elif mode == 'infer':
batch_out.append((img_LQs, videoname, framename))
else:
raise NotImplementedError("mode {} not implemented".format(mode))
if len(batch_out) == batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
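# Shard the file list across num_threads worker processes. Each worker puts
# complete batches into a bounded queue and a None sentinel when exhausted;
# the loop below yields batches until one sentinel per worker has been seen.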
fl = filelist
if is_training:
random.shuffle(fl)
n = num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(fl) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = fl[i * file_num:(i + 1) * file_num]
else:
tmp_list = fl[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
random.seed(0)
np.random.seed(0)
class EDVRReader(DataReader):
"""
Data reader for video super resolution task fit for EDVR model.
This is specified for REDS dataset.
"""
def __init__(self, name, mode, cfg):
super(EDVRReader, self).__init__(name, mode, cfg)
self.format = cfg.MODEL.format
self.crop_size = self.get_config_from_sec(mode, 'crop_size')
self.interval_list = self.get_config_from_sec(mode, 'interval_list')
self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
self.number_frames = self.get_config_from_sec(mode, 'number_frames')
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.fileroot = cfg[mode.upper()]['file_root']
self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
# gt_root may be absent for pure inference; get_config_from_sec then returns None
self.gtroot = self.get_config_from_sec(mode, 'gt_root')
self.scale = self.get_config_from_sec(mode, 'scale', 1)
self.LR_input = (self.scale > 1)
if self.fix_random_seed:
random.seed(0)
np.random.seed(0)
self.num_reader_threads = 1
def create_reader(self):
logger.info('initialize reader ... ')
self.filelist = []
for video_name in os.listdir(self.fileroot):
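# clips 000/011/015/020 form the REDS4 validation split, so skip them during training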
if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
continue
for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + frame_idx
# for each item in self.filelist is like '010_00000015', '260_00000090'
self.filelist.append(video_frame_idx)
if self.mode == 'test':
self.filelist.sort()
if self.num_reader_threads == 1:
reader_func = make_reader
else:
reader_func = make_multi_reader
return reader_func(filelist = self.filelist,
num_threads = self.num_reader_threads,
batch_size = self.batch_size,
is_training = (self.mode == 'train'),
number_frames = self.number_frames,
interval_list = self.interval_list,
random_reverse = self.random_reverse,
gtroot = self.gtroot,
fileroot = self.fileroot,
LR_input = self.LR_input,
crop_size = self.crop_size,
scale = self.scale,
use_flip = self.use_flip,
use_rot = self.use_rot,
mode = self.mode)
def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot, mode='train'):
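# pipeline: pick neighbor frame indices -> read GT and LQ frames ->
# paired random crop (the GT crop is the LQ crop scaled by `scale`) ->
# joint flip/rotation -> BGR->RGB, HWC->CHW; returns LQs as [N, C, H, W]
# and GT as [C, H, W]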
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
ngb_frames, name_b = get_neighbor_frames(frame_name, \
number_frames = number_frames, \
interval_list = interval_list, \
random_reverse = random_reverse)
elif mode == 'test':
ngb_frames, name_b = get_test_neighbor_frames(int(frame_name), number_frames)
else:
raise NotImplementedError('mode {} not implemented'.format(mode))
frame_name = name_b
#print('key2', ngb_frames, name_b)
img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'), is_gt=True)
#print('gt_mean', np.mean(img_GT))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%08d"%ngb_frm
#img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
#print('img_mean', np.mean(img))
H, W, C = frame_list[0].shape
# add random crop
if (mode == 'train') or (mode == 'valid'):
if LR_input:
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - crop_size))
rnd_w = random.randint(0, max(0, W - crop_size))
frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
# add random flip and rotation
frame_list.append(img_GT)
if (mode == 'train') or (mode == 'valid'):
rlt = img_augment(frame_list, use_flip, use_rot)
else:
rlt = frame_list
frame_list = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
return img_LQs, img_GT
def get_test_neighbor_frames(crt_i, N, max_n=100, padding='new_info'):
"""Generate an index list for reading N frames from a sequence of images
Args:
crt_i (int): current center index
max_n (int): max number of the sequence of images (calculated from 1)
N (int): reading N frames
padding (str): padding mode, one of replicate | reflection | new_info | circle
Example: crt_i = 0, N = 5
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
new_info: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
return_l (list [int]): a list of indexes
"""
max_n = max_n - 1
n_pad = N // 2
return_l = []
for i in range(crt_i - n_pad, crt_i + n_pad + 1):
if i < 0:
if padding == 'replicate':
add_idx = 0
elif padding == 'reflection':
add_idx = -i
elif padding == 'new_info':
add_idx = (crt_i + n_pad) + (-i)
elif padding == 'circle':
add_idx = N + i
else:
raise ValueError('Wrong padding mode')
elif i > max_n:
if padding == 'replicate':
add_idx = max_n
elif padding == 'reflection':
add_idx = max_n * 2 - i
elif padding == 'new_info':
add_idx = (crt_i - n_pad) - (i - max_n)
elif padding == 'circle':
add_idx = i - N
else:
raise ValueError('Wrong padding mode')
else:
add_idx = i
return_l.append(add_idx)
name_b = '{:08d}'.format(crt_i)
return return_l, name_b
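# Illustrative high-boundary case (default padding='new_info'):
# get_test_neighbor_frames(99, 5) -> ([97, 98, 99, 96, 95], '00000099')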
def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, max_frame=99, bordermode=False):
center_frame_idx = int(frame_name)
half_N_frames = number_frames // 2
#### determine the neighbor frames
interval = random.choice(interval_list)
if bordermode:
direction = 1 # 1: forward; 0: backward
if random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (number_frames - 1) > max_frame:
direction = 0
elif center_frame_idx - interval * (number_frames - 1) < 0:
direction = 1
# get the neighbor list
if direction == 1:
neighbor_list = list(
range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
else:
neighbor_list = list(
range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + half_N_frames * interval >
max_frame) or (center_frame_idx - half_N_frames * interval < 0):
center_frame_idx = random.randint(0, max_frame)
# get the neighbor list
neighbor_list = list(
range(center_frame_idx - half_N_frames * interval,
center_frame_idx + half_N_frames * interval + 1, interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[half_N_frames])
assert len(neighbor_list) == number_frames, \
"frames selected have length ({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list, name_b
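# Illustrative call (number_frames=5, interval_list=[1], random_reverse=False,
# center frame away from the sequence borders):
# get_neighbor_frames('00000010', 5, [1], False) -> ([8, 9, 10, 11, 12], '00000010')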
def read_img(path, size=None, is_gt=False):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
#if not is_gt:
# #print(path)
# img = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def img_augment(img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def make_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot,
mode='train'):
fl = filelist
def reader_():
if is_training:
random.shuffle(fl)
batch_out = []
for item in fl:
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot, mode)
videoname = item.split('_')[0]
framename = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
batch_out.append((img_LQs, img_GT))
elif mode == 'test':
batch_out.append((img_LQs, img_GT, videoname, framename))
else:
raise NotImplementedError("mode {} not implemented".format(mode))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
return reader_
def make_multi_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot,
mode='train'):
def read_into_queue(flq, queue):
batch_out = []
for item in flq:
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot, mode)
videoname = item.split('_')[0]
framename = item.split('_')[1]
if (mode == 'train') or (mode == 'valid'):
batch_out.append((img_LQs, img_GT))
elif mode == 'test':
batch_out.append((img_LQs, img_GT, videoname, framename))
else:
raise NotImplementedError("mode {} not implemented".format(mode))
if len(batch_out) == batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
fl = filelist
if is_training:
random.shuffle(fl)
n = num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(fl) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = fl[i * file_num:(i + 1) * file_num]
else:
tmp_list = fl[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import cv2
import math
import random
import multiprocessing
import functools
import numpy as np
import paddle
import logging
from .reader_utils import DataReader
logger = logging.getLogger(__name__)
python_ver = sys.version_info
random.seed(0)
np.random.seed(0)
class EDVRReader(DataReader):
"""
Data reader for video super resolution task fit for EDVR model.
This is specified for REDS dataset.
"""
def __init__(self, name, mode, cfg):
super(EDVRReader, self).__init__(name, mode, cfg)
self.format = cfg.MODEL.format
self.crop_size = self.get_config_from_sec(mode, 'crop_size')
self.interval_list = self.get_config_from_sec(mode, 'interval_list')
self.random_reverse = self.get_config_from_sec(mode, 'random_reverse')
self.number_frames = self.get_config_from_sec(mode, 'number_frames')
# set batch size and file list
self.batch_size = cfg[mode.upper()]['batch_size']
self.fileroot = cfg[mode.upper()]['file_root']
self.use_flip = self.get_config_from_sec(mode, 'use_flip', False)
self.use_rot = self.get_config_from_sec(mode, 'use_rot', False)
self.num_reader_threads = self.get_config_from_sec(mode, 'num_reader_threads', 1)
self.buf_size = self.get_config_from_sec(mode, 'buf_size', 1024)
self.fix_random_seed = self.get_config_from_sec(mode, 'fix_random_seed', False)
# gt_root may be absent for pure inference; get_config_from_sec then returns None
self.gtroot = self.get_config_from_sec(mode, 'gt_root')
self.scale = self.get_config_from_sec(mode, 'scale', 1)
self.LR_input = (self.scale > 1)
if self.fix_random_seed:
random.seed(0)
np.random.seed(0)
self.num_reader_threads = 1
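# The docstring block below disables an earlier single-process version of
# create_reader; it is kept for reference only.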
"""
def create_reader(self):
logger.info('initialize reader ... ')
self.filelist = []
for video_name in os.listdir(self.fileroot):
if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
continue
for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + frame_idx
# for each item in self.filelist is like '010_00000015', '260_00000090'
self.filelist.append(video_frame_idx)
#self.filelist.sort()
def reader_():
### not implemented border mode, maybe add later ############
if self.mode == 'train':
random.shuffle(self.filelist)
for item in self.filelist:
#print(item)
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
ngb_frames, name_b = get_neighbor_frames(frame_name, \
number_frames = self.number_frames, \
interval_list = self.interval_list, \
random_reverse = self.random_reverse)
frame_name = name_b
#print('key2', ngb_frames, name_b)
img_GT = read_img(os.path.join(self.gtroot, video_name, frame_name + '.png'))
#print('gt_mean', np.mean(img_GT))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%08d"%ngb_frm
#img = read_img(os.path.join(self.fileroot, video_name, frame_name + '.png'))
img = read_img(os.path.join(self.fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
#print('img_mean', np.mean(img))
H, W, C = frame_list[0].shape
# add random crop
if self.LR_input:
LQ_size = self.crop_size // self.scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
rnd_h_HR, rnd_w_HR = int(rnd_h * self.scale), int(rnd_w * self.scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + self.crop_size, rnd_w_HR:rnd_w_HR + self.crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - self.crop_size))
rnd_w = random.randint(0, max(0, W - self.crop_size))
frame_list = [v[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :] for v in frame_list]
img_GT = img_GT[rnd_h:rnd_h + self.crop_size, rnd_w:rnd_w + self.crop_size, :]
# add random flip and rotation
frame_list.append(img_GT)
rlt = img_augment(frame_list, self.use_flip, self.use_rot)
frame_list = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
yield img_LQs, img_GT
def _batch_reader():
batch_out = []
for img_LQs, img_GT in reader_():
#print('lq', img_LQs.shape)
#print('gt', img_GT.shape)
batch_out.append((img_LQs, img_GT))
if len(batch_out) == self.batch_size:
yield batch_out
batch_out = []
return _batch_reader
"""
def create_reader(self):
logger.info('initialize reader ... ')
self.filelist = []
for video_name in os.listdir(self.fileroot):
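# clips 000/011/015/020 form the REDS4 validation split, so skip them during training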
if (self.mode == 'train') and (video_name in ['000', '011', '015', '020']):
continue
for frame_name in os.listdir(os.path.join(self.fileroot, video_name)):
frame_idx = frame_name.split('.')[0]
video_frame_idx = video_name + '_' + frame_idx
# for each item in self.filelist is like '010_00000015', '260_00000090'
self.filelist.append(video_frame_idx)
if self.num_reader_threads == 1:
reader_func = make_reader
else:
reader_func = make_multi_reader
return reader_func(filelist = self.filelist,
num_threads = self.num_reader_threads,
batch_size = self.batch_size,
is_training = (self.mode == 'train'),
number_frames = self.number_frames,
interval_list = self.interval_list,
random_reverse = self.random_reverse,
gtroot = self.gtroot,
fileroot = self.fileroot,
LR_input = self.LR_input,
crop_size = self.crop_size,
scale = self.scale,
use_flip = self.use_flip,
use_rot = self.use_rot)
def get_sample_data(item, number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot):
video_name = item.split('_')[0]
frame_name = item.split('_')[1]
ngb_frames, name_b = get_neighbor_frames(frame_name, \
number_frames = number_frames, \
interval_list = interval_list, \
random_reverse = random_reverse)
frame_name = name_b
#print('key2', ngb_frames, name_b)
img_GT = read_img(os.path.join(gtroot, video_name, frame_name + '.png'))
#print('gt_mean', np.mean(img_GT))
frame_list = []
for ngb_frm in ngb_frames:
ngb_name = "%08d"%ngb_frm
#img = read_img(os.path.join(fileroot, video_name, frame_name + '.png'))
img = read_img(os.path.join(fileroot, video_name, ngb_name + '.png'))
frame_list.append(img)
#print('img_mean', np.mean(img))
H, W, C = frame_list[0].shape
# add random crop
if LR_input:
LQ_size = crop_size // scale
rnd_h = random.randint(0, max(0, H - LQ_size))
rnd_w = random.randint(0, max(0, W - LQ_size))
#print('rnd_h {}, rnd_w {}', rnd_h, rnd_w)
frame_list = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in frame_list]
rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
img_GT = img_GT[rnd_h_HR:rnd_h_HR + crop_size, rnd_w_HR:rnd_w_HR + crop_size, :]
else:
rnd_h = random.randint(0, max(0, H - crop_size))
rnd_w = random.randint(0, max(0, W - crop_size))
frame_list = [v[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :] for v in frame_list]
img_GT = img_GT[rnd_h:rnd_h + crop_size, rnd_w:rnd_w + crop_size, :]
# add random flip and rotation
frame_list.append(img_GT)
rlt = img_augment(frame_list, use_flip, use_rot)
frame_list = rlt[0:-1]
img_GT = rlt[-1]
# stack LQ images to NHWC, N is the frame number
img_LQs = np.stack(frame_list, axis=0)
# BGR to RGB, HWC to CHW, numpy to tensor
img_GT = img_GT[:, :, [2, 1, 0]]
img_LQs = img_LQs[:, :, :, [2, 1, 0]]
img_GT = np.transpose(img_GT, (2, 0, 1)).astype('float32')
img_LQs = np.transpose(img_LQs, (0, 3, 1, 2)).astype('float32')
return img_LQs, img_GT
def get_neighbor_frames(frame_name, number_frames, interval_list, random_reverse, bordermode=False):
center_frame_idx = int(frame_name)
half_N_frames = number_frames // 2
#### determine the neighbor frames
interval = random.choice(interval_list)
if bordermode:
direction = 1 # 1: forward; 0: backward
if random_reverse and random.random() < 0.5:
direction = random.choice([0, 1])
if center_frame_idx + interval * (number_frames - 1) > 99:
direction = 0
elif center_frame_idx - interval * (number_frames - 1) < 0:
direction = 1
# get the neighbor list
if direction == 1:
neighbor_list = list(
range(center_frame_idx, center_frame_idx + interval * number_frames, interval))
else:
neighbor_list = list(
range(center_frame_idx, center_frame_idx - interval * number_frames, -interval))
name_b = '{:08d}'.format(neighbor_list[0])
else:
# ensure not exceeding the borders
while (center_frame_idx + half_N_frames * interval >
99) or (center_frame_idx - half_N_frames * interval < 0):
center_frame_idx = random.randint(0, 99)
# get the neighbor list
neighbor_list = list(
range(center_frame_idx - half_N_frames * interval,
center_frame_idx + half_N_frames * interval + 1, interval))
if random_reverse and random.random() < 0.5:
neighbor_list.reverse()
name_b = '{:08d}'.format(neighbor_list[half_N_frames])
assert len(neighbor_list) == number_frames, \
"frames selected have length ({}), but it should be ({})".format(len(neighbor_list), number_frames)
return neighbor_list, name_b
def read_img(path, size=None):
"""read image by cv2
return: Numpy float32, HWC, BGR, [0,1]"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# some images have 4 channels
if img.shape[2] > 3:
img = img[:, :, :3]
return img
def img_augment(img_list, hflip=True, rot=True):
"""horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
hflip = hflip and random.random() < 0.5
vflip = rot and random.random() < 0.5
rot90 = rot and random.random() < 0.5
def _augment(img):
if hflip:
img = img[:, ::-1, :]
if vflip:
img = img[::-1, :, :]
if rot90:
img = img.transpose(1, 0, 2)
return img
return [_augment(img) for img in img_list]
def make_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot):
fl = filelist
def reader_():
if is_training:
random.shuffle(fl)
batch_out = []
for item in fl:
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot)
batch_out.append((img_LQs, img_GT))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
return reader_
def make_multi_reader(filelist,
num_threads,
batch_size,
is_training,
number_frames,
interval_list,
random_reverse,
gtroot,
fileroot,
LR_input,
crop_size,
scale,
use_flip,
use_rot):
def read_into_queue(flq, queue):
batch_out = []
for item in flq:
img_LQs, img_GT = get_sample_data(item,
number_frames, interval_list, random_reverse, gtroot, fileroot,
LR_input, crop_size, scale, use_flip, use_rot)
batch_out.append((img_LQs, img_GT))
if len(batch_out) == batch_size:
queue.put(batch_out)
batch_out = []
queue.put(None)
def queue_reader():
fl = filelist
if is_training:
random.shuffle(fl)
n = num_threads
queue_size = 20
reader_lists = [None] * n
file_num = int(len(fl) // n)
for i in range(n):
if i < len(reader_lists) - 1:
tmp_list = fl[i * file_num:(i + 1) * file_num]
else:
tmp_list = fl[i * file_num:]
reader_lists[i] = tmp_list
queue = multiprocessing.Queue(queue_size)
p_list = [None] * len(reader_lists)
# for reader_list in reader_lists:
for i in range(len(reader_lists)):
reader_list = reader_lists[i]
p_list[i] = multiprocessing.Process(
target=read_into_queue, args=(reader_list, queue))
p_list[i].start()
reader_num = len(reader_lists)
finish_num = 0
while finish_num < reader_num:
sample = queue.get()
if sample is None:
finish_num += 1
else:
yield sample
for i in range(len(p_list)):
if p_list[i].is_alive():
p_list[i].join()
return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import pickle
import cv2
import numpy as np
import random
class ReaderNotFoundError(Exception):
"Error: reader not found"
def __init__(self, reader_name, avail_readers):
super(ReaderNotFoundError, self).__init__()
self.reader_name = reader_name
self.avail_readers = avail_readers
def __str__(self):
msg = "Reader {} Not Found.\nAvailiable readers:\n".format(
self.reader_name)
for reader in self.avail_readers:
msg += " {}\n".format(reader)
return msg
class DataReader(object):
"""data reader for video input"""
def __init__(self, model_name, mode, cfg):
self.name = model_name
self.mode = mode
self.cfg = cfg
def create_reader(self):
"""Not implemented"""
pass
def get_config_from_sec(self, sec, item, default=None):
if sec.upper() not in self.cfg:
return default
return self.cfg[sec.upper()].get(item, default)
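# Example: self.get_config_from_sec('train', 'batch_size', 1) reads
# cfg.TRAIN.batch_size and falls back to 1 if the section or key is missing.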
class ReaderZoo(object):
def __init__(self):
self.reader_zoo = {}
def regist(self, name, reader):
assert reader.__base__ == DataReader, "Unknown reader type {}".format(
type(reader))
self.reader_zoo[name] = reader
def get(self, name, mode, cfg):
for k, v in self.reader_zoo.items():
if k == name:
return v(name, mode, cfg)
raise ReaderNotFoundError(name, self.reader_zoo.keys())
# singleton reader_zoo
reader_zoo = ReaderZoo()
def regist_reader(name, reader):
reader_zoo.regist(name, reader)
def get_reader(name, mode, cfg):
reader_model = reader_zoo.get(name, mode, cfg)
return reader_model.create_reader()
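# Typical usage (illustrative; the reader class must be registered first):
# regist_reader("EDVR", EDVRReader)
# reader = get_reader("EDVR", 'train', cfg)  # returns a callable; iterate reader() for batches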
# examples of running programs:
# bash ./run.sh train CTCN ./configs/ctcn.yaml
# bash ./run.sh eval NEXTVLAD ./configs/nextvlad.yaml
# bash ./run.sh predict NONLOCAL ./configs/nonlocal.yaml
# mode should be one of [train, eval, predict, inference]
# name should be one of [AttentionCluster, AttentionLSTM, NEXTVLAD, NONLOCAL, TSN, TSM, STNET, CTCN, EDVR]
# configs should be ./configs/xxx.yaml
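# e.g. for the EDVR model added in this commit (illustrative; adjust the
# config path to your setup): bash ./run.sh predict EDVR ./configs/edvr.yaml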
mode=$1
name=$2
configs=$3
pretrain="./tmp/name_map/paddle_state_dict.npz" # set pretrain model path if needed
resume="" # set pretrain model path if needed
save_dir="./data/checkpoints"
save_inference_dir="./data/inference_model"
use_gpu=True
fix_random_seed=False
log_interval=1
valid_interval=1
#weights="./data/checkpoints/EDVR_epoch721.pdparams" #set the path of weights to enable eval and predicut, just ignore this when training
#weights="./data/checkpoints_with_tsa/EDVR_epoch821.pdparams"
weights="./weights/paddle_state_dict_L.npz"
#weights="./weights/"
#export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=4,5,6,7 #0,1,5,6 fast, 2,3,4,7 slow
#export CUDA_VISIBLE_DEVICES=7
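# Paddle memory flags: enable fast eager deletion of intermediate tensors and
# allow the process to use up to 98% of GPU memory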
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.98
if [ "$mode" == "train" ]; then
echo $mode $name $configs $resume $pretrain
if [ "$resume"x != ""x ]; then
python train.py --model_name=$name \
--config=$configs \
--resume=$resume \
--log_interval=$log_interval \
--valid_interval=$valid_interval \
--use_gpu=$use_gpu \
--save_dir=$save_dir \
--fix_random_seed=$fix_random_seed
elif [ "$pretrain"x != ""x ]; then
python train.py --model_name=$name \
--config=$configs \
--pretrain=$pretrain \
--log_interval=$log_interval \
--valid_interval=$valid_interval \
--use_gpu=$use_gpu \
--save_dir=$save_dir \
--fix_random_seed=$fix_random_seed
else
python train.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--valid_interval=$valid_interval \
--use_gpu=$use_gpu \
--save_dir=$save_dir \
--fix_random_seed=$fix_random_seed
fi
elif [ "$mode"x == "eval"x ]; then
echo $mode $name $configs $weights
if [ "$weights"x != ""x ]; then
python eval.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--weights=$weights \
--use_gpu=$use_gpu
else
python eval.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--use_gpu=$use_gpu
fi
elif [ "$mode"x == "predict"x ]; then
echo $mode $name $configs $weights
if [ "$weights"x != ""x ]; then
python predict.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--weights=$weights \
--video_path='' \
--use_gpu=$use_gpu
else
python predict.py --model_name=$name \
--config=$configs \
--log_interval=$log_interval \
--use_gpu=$use_gpu \
--video_path=''
fi
elif [ "$mode"x == "inference"x ]; then
echo $mode $name $configs $weights
if [ "$weights"x != ""x ]; then
python inference_model.py --model_name=$name \
--config=$configs \
--weights=$weights \
--use_gpu=$use_gpu \
--save_dir=$save_inference_dir
else
python inference_model.py --model_name=$name \
--config=$configs \
--use_gpu=$use_gpu \
--save_dir=$save_inference_dir
fi
else
echo "Not implemented mode " $mode
fi
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import yaml
from .utility import AttrDict
import logging
logger = logging.getLogger(__name__)
CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
create_attr_dict(yaml_config)
return yaml_config
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return
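# Note: literal_eval recovers numeric types that PyYAML's default resolver
# leaves as strings, e.g. the config value eta_min: 1e-7 is loaded as the
# str '1e-7' (no decimal point) and converted to the float 1e-07 here.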
def merge_configs(cfg, sec, args_dict):
assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
sec_dict = getattr(cfg, sec.upper())
for k, v in args_dict.items():
if v is None:
continue
try:
if hasattr(sec_dict, k):
setattr(sec_dict, k, v)
except:
pass
return cfg
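# Example (illustrative): merge_configs(cfg, 'train', {'batch_size': 16, 'weights': None})
# overwrites cfg.TRAIN.batch_size with 16 and skips 'weights' since its value is None.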
def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import signal
import logging
import paddle
import paddle.fluid as fluid
__all__ = ['AttrDict']
logger = logging.getLogger(__name__)
def _term(sig_num, addition):
print('current pid is %s, group id is %s' % (os.getpid(), os.getpgrp()))
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
signal.signal(signal.SIGTERM, _term)
signal.signal(signal.SIGINT, _term)
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
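# Minimal sketch of AttrDict access (attribute and key styles are equivalent):
# cfg = AttrDict({'MODEL': {'name': 'EDVR'}})
# cfg.MODEL['name'] == cfg['MODEL']['name']  # -> True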
def check_cuda(use_cuda, err = \
"\nYou can not set use_gpu = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_gpu = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)