Unverified commit d5f74052, authored by SunGaofeng, committed by GitHub

fix test and infer for ctcn (#2562)

* fix test and infer for ctcn

* add instruction to dataset/ctcn/readme

Parent: 356bb7e2
## Introduction

-This tutorial provides developers with convenient, efficient PaddlePaddle-based deep learning models for video understanding, video editing, video generation, and related tasks. It currently covers video classification models and will be extended to more scenarios over time.
+This tutorial provides developers with convenient, efficient PaddlePaddle-based deep learning models for video understanding, video editing, video generation, and related tasks. It currently covers video classification and action localization models and will be extended to more scenarios over time.

-The video classification models currently include:
+The video classification and action localization models currently include:

| Model | Category | Description |
| :--------------- | :--------: | :------------: |
@@ -13,16 +13,17 @@
| [TSM](./models/tsm/README.md) | Video classification | A simple and efficient temporal-shift-based method for spatio-temporal video modeling |
| [TSN](./models/tsn/README.md) | Video classification | The classic 2D-CNN-based solution proposed at ECCV'16 |
| [Non-local](./models/nonlocal_model/README.md) | Video classification | A model for non-local relation modeling in video |
+| [C-TCN](./models/ctcn/README.md) | Video action localization | The winning solution of the 2018 ActivityNet challenge |
### Key Features

-- Includes multiple leading video classification models. Attention LSTM, Attention Cluster, and NeXtVLAD are popular feature-sequence models, while Non-local, TSN, TSM, and StNet are end-to-end video classification models. Attention LSTM is fast and accurate; NeXtVLAD was the best single model in the 2nd YouTube-8M competition; TSN is the classic 2D-CNN-based solution; TSM is a simple and efficient temporal-shift-based spatio-temporal modeling method; Non-local introduced non-local relation modeling for video. Attention Cluster and StNet are Baidu models, published at CVPR 2018 and AAAI 2019 respectively, and were used in the first-place entry of the Kinetics-600 challenge.
+- Includes multiple leading video classification and action localization models. Attention LSTM, Attention Cluster, and NeXtVLAD are popular feature-sequence models, while Non-local, TSN, TSM, and StNet are end-to-end video classification models. Attention LSTM is fast and accurate; NeXtVLAD was the best single model in the 2nd YouTube-8M competition; TSN is the classic 2D-CNN-based solution; TSM is a simple and efficient temporal-shift-based spatio-temporal modeling method; Non-local introduced non-local relation modeling for video. Attention Cluster and StNet are Baidu models, published at CVPR 2018 and AAAI 2019 respectively, and were used in the first-place entry of the Kinetics-600 challenge. C-TCN is also a Baidu model and was the winning solution of the 2018 ActivityNet challenge.

-- Provides general skeleton code suited to video classification tasks, so users can configure a model and run training and evaluation with a single command.
+- Provides general skeleton code suited to video classification and action localization tasks, so users can configure a model and run training and evaluation with a single command.

## Installation

-Running the sample code in this model zoo requires PaddlePaddle Fluid v1.4.0 or later. If the PaddlePaddle in your environment is older, please update it following the [installation guide](http://www.paddlepaddle.org/documentation/docs/zh/1.4/beginners_guide/install/index_cn.html).
+Running the sample code in this model zoo requires PaddlePaddle Fluid v1.5.0 or later. If the PaddlePaddle in your environment is older, please update it following the [installation guide](http://www.paddlepaddle.org/documentation/docs/zh/1.5/beginners_guide/install/index_cn.html).
## Data Preparation
@@ -123,6 +124,14 @@ infer.py
| TSN | 256 | 8x P40 | 7.1 | 0.67 | [model](https://paddlemodels.bj.bcebos.com/video_classification/tsn_kinetics.tar.gz) |
| TSM | 128 | 8x P40 | 7.1 | 0.70 | [model](https://paddlemodels.bj.bcebos.com/video_classification/tsm_kinetics.tar.gz) |
| Non-local | 64 | 8x P40 | 7.1 | 0.74 | [model](https://paddlemodels.bj.bcebos.com/video_classification/nonlocal_kinetics.tar.gz) |

+- Action localization models trained on ActivityNet:
+
+| Model | Batch Size | Environment | cuDNN Version | MAP | Download |
+| :-------: | :---: | :---------: | :----: | :----: | :----------: |
+| C-TCN | 16 | 8x P40 | 7.1 | 0.31 | [model](https://paddlemodels.bj.bcebos.com/video_detection/ctcn.tar.gz) |
## References
- [Attention Clusters: Purely Attention Based Local Feature Integration for Video Classification](https://arxiv.org/abs/1711.09550), Xiang Long, Chuang Gan, Gerard de Melo, Jiajun Wu, Xiao Liu, Shilei Wen
@@ -132,7 +141,9 @@ infer.py
- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool
- [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1), Ji Lin, Chuang Gan, Song Han
- [Non-local Neural Networks](https://arxiv.org/abs/1711.07971v1), Xiaolong Wang, Ross Girshick, Abhinav Gupta, Kaiming He
## Release Notes

- 3/2019: Added the model zoo, releasing five video classification models: Attention Cluster, Attention LSTM, NeXtVLAD, StNet, and TSN.
- 4/2019: Released two video classification models: Non-local and TSM.
+- 6/2019: Released the C-TCN video action localization model; added C2D ResNet101 and I3D ResNet50 backbones to Non-local; optimized speed and memory usage for NeXtVLAD and TSM.
@@ -36,17 +36,17 @@ num_gpus = 8
filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat'
rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val'
flow = 'senet152-201cls-flow-60.9-5seg-331data_val'
-class_label_file = 'dataset/ctcn/test_val_label.list'
+class_label_file = 'dataset/ctcn/labels.txt'
video_duration_file = 'dataset/ctcn/val_duration_frame.list'
batch_size = 1
num_threads = 1
score_thresh = 0.001
-nms_thresh = 0.08
-sigma_thresh = 0.006
-soft_thresh = 0.006
+nms_thresh = 0.8
+sigma_thresh = 0.9
+soft_thresh = 0.004
[INFER]
-filelist = 'dataset/ctcn/Activity1.3_val_rgb.listformat'
+filelist = 'dataset/ctcn/infer.list'
rgb = 'senet152-201cls-rgb-70.3-5seg-331data_331img_val'
flow = 'senet152-201cls-flow-60.9-5seg-331data_val'
batch_size = 1
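For reference, the snippet below is a minimal sketch of reading an INI-style config like `configs/ctcn.txt` with the Python standard library; the repo's own config utilities may parse it differently, and the quote-stripping is an assumption based on the quoted values shown above.

```python
# Minimal sketch: reading configs/ctcn.txt. Assumes INI-style sections
# such as [INFER], matching the file above; the repo's own config
# utilities may differ.
try:
    import configparser                   # Python 3
except ImportError:
    import ConfigParser as configparser   # Python 2

cfg = configparser.ConfigParser()
cfg.read('configs/ctcn.txt')

# Values are written with surrounding quotes, so strip them for paths
# (an assumption based on the file contents shown above).
infer_filelist = cfg.get('INFER', 'filelist').strip("'")
batch_size = cfg.getint('INFER', 'batch_size')
print(infer_filelist, batch_size)
```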
@@ -196,13 +196,74 @@ class CTCNReader(DataReader):
        boxes, labels = Coder.encode(boxes, labels)
        return img, boxes, labels

    def load_file(self, fname):
        """Load the RGB and flow feature pickles for one video and fuse them."""
        rgb_path = os.path.join(self.root, self.rgb, fname + '.pkl')
        flow_path = os.path.join(self.root, self.flow, fname + '.pkl')
        if python_ver < (3, 0):
            rgb_pkl = pickle.load(open(rgb_path, 'rb'))
            flow_pkl = pickle.load(open(flow_path, 'rb'))
        else:
            # Python 3 must open the Python 2 pickles in binary mode and
            # decode them with bytes encoding.
            rgb_pkl = pickle.load(open(rgb_path, 'rb'), encoding='bytes')
            flow_pkl = pickle.load(open(flow_path, 'rb'), encoding='bytes')
        data_flow = np.array(flow_pkl['scores'])
        data_rgb = np.array(rgb_pkl['scores'])
        # Truncate the longer stream so both cover the same snippets.
        if data_flow.shape[0] < data_rgb.shape[0]:
            data_rgb = data_rgb[0:data_flow.shape[0], :]
        elif data_flow.shape[0] > data_rgb.shape[0]:
            data_flow = data_flow[0:data_rgb.shape[0], :]
        # Concatenate RGB and flow scores along the feature axis.
        feats = np.concatenate((data_rgb, data_flow), axis=1)
        if feats.shape[0] == 0 or feats.shape[1] == 0:
            feats = np.zeros((512, 1024), np.float32)
            logger.info('### file loading len = 0 {} ###'.format(fname))
        return feats
    def create_reader(self):
        """reader creator for ctcn model"""
        if self.mode == 'infer':
            return self.make_infer_reader()
        if self.num_threads == 1:
            return self.make_reader()
        else:
            return self.make_multiprocess_reader()

    def make_infer_reader(self):
        """reader for inference"""

        def reader():
            with open(self.filelist) as f:
                reader_list = f.readlines()
            batch_out = []
            for line in reader_list:
                fname = line.strip().split()[0]
                rgb_exist = os.path.exists(
                    os.path.join(self.root, self.rgb, fname + '.pkl'))
                flow_exist = os.path.exists(
                    os.path.join(self.root, self.flow, fname + '.pkl'))
                if not (rgb_exist and flow_exist):
                    logger.info('file not exist: {}'.format(fname))
                    continue
                try:
                    feats = self.load_file(fname)
                    # No ground-truth boxes are needed at inference time.
                    feats, boxes = self.resize(
                        feats, boxes=None, size=self.img_size)
                    h, w = feats.shape[:2]
                    feats = feats.reshape(1, h, w)
                except Exception:
                    logger.info('Error when loading {}'.format(fname))
                    continue
                batch_out.append((feats, fname))
                if len(batch_out) == self.batch_size:
                    yield batch_out
                    batch_out = []

        return reader
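For orientation, here is a hypothetical usage sketch of the infer reader above; the `CTCNReader` constructor arguments are assumed for illustration, not taken from the repo.

```python
# Hypothetical usage of make_infer_reader(); the constructor arguments
# (name, mode, config) are assumed for illustration.
reader_obj = CTCNReader('CTCN', 'infer', infer_config)
reader = reader_obj.create_reader()
for batch in reader():
    for feats, fname in batch:
        print(fname, feats.shape)  # feats has shape (1, h, w)
```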
    def make_reader(self):
        """single process reader"""
@@ -223,17 +284,17 @@ class CTCNReader(DataReader):
                flow_exist = os.path.exists(
                    os.path.join(self.root, self.flow, splited[0] + '.pkl'))
                if not (rgb_exist and flow_exist):
-                    print('file not exist', splited[0])
+                    # logger.info('file not exist', splited[0])
                    continue
                fnames.append(splited[0])
                frames_num = int(splited[1]) // self.snippet_length
                num_boxes = int(splited[2])
                box = []
                label = []
-                for i in range(num_boxes):
-                    c = splited[3 + 3 * i]
-                    xmin = splited[4 + 3 * i]
-                    xmax = splited[5 + 3 * i]
+                for ii in range(num_boxes):
+                    c = splited[3 + 3 * ii]
+                    xmin = splited[4 + 3 * ii]
+                    xmax = splited[5 + 3 * ii]
                    box.append([
                        float(xmin) / self.snippet_length,
                        float(xmax) / self.snippet_length
@@ -247,44 +308,9 @@ class CTCNReader(DataReader):
            for idx in range(num_videos):
                fname = fnames[idx]
                try:
-                    if python_ver < (3, 0):
-                        rgb_pkl = pickle.load(
-                            open(os.path.join(self.root, self.rgb, fname + '.pkl')))
-                        flow_pkl = pickle.load(
-                            open(os.path.join(self.root, self.flow, fname + '.pkl')))
-                    else:
-                        rgb_pkl = pickle.load(
-                            open(os.path.join(self.root, self.rgb, fname + '.pkl')),
-                            encoding='bytes')
-                        flow_pkl = pickle.load(
-                            open(os.path.join(self.root, self.flow, fname + '.pkl')),
-                            encoding='bytes')
-                    data_flow = np.array(flow_pkl['scores'])
-                    data_rgb = np.array(rgb_pkl['scores'])
-                    if data_flow.shape[0] < data_rgb.shape[0]:
-                        data_rgb = data_rgb[0:data_flow.shape[0], :]
-                    elif data_flow.shape[0] > data_rgb.shape[0]:
-                        data_flow = data_flow[0:data_rgb.shape[0], :]
-                    feats = np.concatenate((data_rgb, data_flow), axis=1)
-                    if feats.shape[0] == 0 or feats.shape[1] == 0:
-                        feats = np.zeros((512, 1024), np.float32)
-                        logger.info('### file loading len = 0 {} ###'.format(fname))
+                    feats = self.load_file(fname)
                    boxes = copy.deepcopy(total_boxes[idx])
                    labels = copy.deepcopy(total_labels[idx])
                    feats, boxes, labels = self.transform(feats, boxes, labels,
                                                          self.mode)
                    labels = labels.astype('int64')
@@ -328,17 +354,17 @@ class CTCNReader(DataReader):
                flow_exist = os.path.exists(
                    os.path.join(self.root, self.flow, splited[0] + '.pkl'))
                if not (rgb_exist and flow_exist):
-                    logger.info('file not exist {}'.format(splited[0]))
+                    # logger.info('file not exist {}'.format(splited[0]))
                    continue
                fnames.append(splited[0])
                frames_num = int(splited[1]) // self.snippet_length
                num_boxes = int(splited[2])
                box = []
                label = []
-                for i in range(num_boxes):
-                    c = splited[3 + 3 * i]
-                    xmin = splited[4 + 3 * i]
-                    xmax = splited[5 + 3 * i]
+                for ii in range(num_boxes):
+                    c = splited[3 + 3 * ii]
+                    xmin = splited[4 + 3 * ii]
+                    xmax = splited[5 + 3 * ii]
                    box.append([
                        float(xmin) / self.snippet_length,
                        float(xmax) / self.snippet_length
@@ -352,41 +378,7 @@ class CTCNReader(DataReader):
            for idx in range(num_videos):
                fname = fnames[idx]
                try:
-                    if python_ver < (3, 0):
-                        rgb_pkl = pickle.load(
-                            open(os.path.join(self.root, self.rgb, fname + '.pkl')))
-                        flow_pkl = pickle.load(
-                            open(os.path.join(self.root, self.flow, fname + '.pkl')))
-                    else:
-                        rgb_pkl = pickle.load(
-                            open(os.path.join(self.root, self.rgb, fname + '.pkl')),
-                            encoding='bytes')
-                        flow_pkl = pickle.load(
-                            open(os.path.join(self.root, self.flow, fname + '.pkl')),
-                            encoding='bytes')
-                    data_flow = np.array(flow_pkl['scores'])
-                    data_rgb = np.array(rgb_pkl['scores'])
-                    if data_flow.shape[0] < data_rgb.shape[0]:
-                        data_rgb = data_rgb[0:data_flow.shape[0], :]
-                    elif data_flow.shape[0] > data_rgb.shape[0]:
-                        data_flow = data_flow[0:data_rgb.shape[0], :]
-                    feats = np.concatenate((data_rgb, data_flow), axis=1)
-                    if feats.shape[0] == 0 or feats.shape[1] == 0:
-                        feats = np.zeros((512, 1024), np.float32)
-                        logger.info('### file loading len = 0 {} ###'.format(fname))
+                    feats = self.load_file(fname)
                    boxes = copy.deepcopy(total_boxes[idx])
                    labels = copy.deepcopy(total_labels[idx])
@@ -3,6 +3,7 @@
- [Youtube-8M](#Youtube-8M)
- [Kinetics](#Kinetics)
- [Non-local](#Non-local)
+- [C-TCN](#C-TCN)

## Youtube-8M

This uses the YouTube-8M dataset as updated in 2018. We use the official dataset and convert the TFRecord files into pickle files for use with PaddlePaddle. YouTube-8M officially provides both frame-level and video-level features; only the frame-level features are needed here.
@@ -121,3 +122,7 @@ ActivityNet officially provides a download tool for Kinetics; see its [official repo]
## Non-local

The Non-local model also uses the Kinetics dataset, but its data processing differs from the other models; see the [Non-local data notes](./nonlocal/README.md) for details.

+## C-TCN
+
+The C-TCN model uses the ActivityNet 1.3 dataset; see the [C-TCN data notes](./ctcn/README.md) for usage.
# C-TCN Data Usage Notes

The C-TCN model uses the ActivityNet 1.3 dataset; for downloading it, please refer to the official [download instructions](http://activity-net.org/index.html). Training this model requires first extracting features from the mp4 source files with a pre-trained TSN model; RGB and optical-flow features are extracted separately and stored in pickle format. We will provide a download link for the converted data. The converted data is laid out as:

```
data
|
|----senet152-201cls-flow-60.9-5seg-331data_train
|----senet152-201cls-rgb-70.3-5seg-331data_331img_train
|----senet152-201cls-flow-60.9-5seg-331data_val
|----senet152-201cls-rgb-70.3-5seg-331data_331img_val
```

You also need to download the following files and place them under dataset/ctcn: `Activity1.3_train_rgb.listformat`, `Activity1.3_val_rgb.listformat`, `labels.txt`, `test_val_label.list`, `val_duration_frame.list`.
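Judging from the reader code in this commit, each line of the `.listformat` files holds a video feature file name, a frame count, a number of ground-truth boxes, and then one `(label, xmin, xmax)` triple per box. The line below is a made-up illustration of that layout, not a real entry:

```
v_example_video 1536 2 12 130 480 45 610 900
```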
@@ -119,28 +119,38 @@ def infer(args):
        video_id = [items[-1] for items in data]
        infer_outs = exe.run(fetch_list=fetch_list,
                             feed=infer_feeder.feed(data_feed_in))
-        predictions = np.array(infer_outs[0])
-        for i in range(len(predictions)):
-            topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
-            topk_inds = topk_inds[::-1]
-            preds = predictions[i][topk_inds]
-            results.append((video_id[i], preds.tolist(), topk_inds.tolist()))
        prev_time = cur_time
        cur_time = time.time()
        period = cur_time - prev_time
        periods.append(period)
+        if args.model_name in ['CTCN']:
+            # For detection model
+            loc_predictions = np.array(infer_outs[0])
+            cls_predictions = np.array(infer_outs[1])
+            for i in range(len(video_id)):
+                results.append((video_id[i], loc_predictions[i].tolist(),
+                                cls_predictions[i].tolist()))
+        else:
+            # For classification model
+            predictions = np.array(infer_outs[0])
+            for i in range(len(predictions)):
+                topk_inds = predictions[i].argsort()[0 - args.infer_topk:]
+                topk_inds = topk_inds[::-1]
+                preds = predictions[i][topk_inds]
+                results.append(
+                    (video_id[i], preds.tolist(), topk_inds.tolist()))

        if args.log_interval > 0 and infer_iter % args.log_interval == 0:
-            logger.info('Processed {} samples'.format((infer_iter) * len(
-                predictions)))
+            logger.info('Processed {} samples'.format((infer_iter + 1) * len(
+                video_id)))

    logger.info('[INFER] infer finished. average time: {}'.format(
        np.mean(periods)))

    if not os.path.isdir(args.save_dir):
        os.mkdir(args.save_dir)
-    result_file_name = os.path.join(args.save_dir,
-                                    "{}_infer_result".format(args.model_name))
-    pickle.dump(results, open(result_file_name, 'wb'))
+    result_file_name = os.path.join(
+        args.save_dir, "{}_infer_result.pkl".format(args.model_name))
+    pickle.dump(results, open(result_file_name, 'wb'), protocol=0)
if __name__ == "__main__":
@@ -88,10 +88,10 @@ class MetricsCalculator():
        }
        self.out_file = 'res_decode_' + str(self.score_thresh) + '_' + \
                        str(self.nms_thresh) + '_' + str(self.sigma_thresh) + \
-                       '_' + str(self.soft_thresh)
+                       '_' + str(self.soft_thresh) + '.json'

    def accumulate(self, loss, pred, label):
-        cur_batch_size = loss[0].shape[0]
+        cur_batch_size = 1  # iteration counter
        self.aggr_loss += np.mean(np.array(loss[0]))
        self.aggr_loc_loss += np.mean(np.array(loss[1]))
        self.aggr_cls_loss += np.mean(np.array(loss[2]))
@@ -99,13 +99,13 @@ class MetricsCalculator():
        if self.mode == 'test':
            box_preds, label_preds, score_preds = self.box_coder.decode(
                pred[0].squeeze(), pred[1].squeeze(), **self.box_decode_params)
-            fid = label[-1]
+            fid = label.squeeze()
            fname = self.gt_labels[fid]
-            logger.info("file {}, num of box preds {}:".format(fname,
-                                                               len(box_preds)))
+            logger.info("id {}, file {}, num of box preds {}:".format(
+                fid, fname, len(box_preds)))
            self.results_detect[fname] = []
            for j in range(len(label_preds)):
-                results_detect[fname[0]].append({
+                self.results_detect[fname].append({
                    "score": score_preds[j],
                    "label": self.class_label[label_preds[j]].strip(),
                    "segment": [
@@ -123,7 +123,9 @@ class MetricsCalculator():
        if self.mode == 'test':
            self.res_detect['results'] = self.results_detect
            with open(self.out_file, 'w') as f:
-                json.dump(res_detect, f)
+                json.dump(self.res_detect, f)
+            logger.info('results has been saved into file: {}'.format(
+                self.out_file))
    def get_computed_metrics(self):
        json_stats = {}
@@ -86,8 +86,8 @@ class MetricsCalculator():
            os.makedirs(self.checkpoint_dir)
        pkl_path = os.path.join(self.checkpoint_dir, "results_probs.pkl")
-        with open(pkl_path, 'w') as f:
-            pickle.dump(self.results, f)
+        with open(pkl_path, 'wb') as f:
+            pickle.dump(self.results, f, protocol=0)
        logger.info('Temporary file saved to: {}'.format(pkl_path))
# C-TCN Video Action Localization Model

---
## Contents

- [Introduction](#Introduction)
- [Data Preparation](#Data-Preparation)
- [Training](#Training)
- [Evaluation](#Evaluation)
- [Inference](#Inference)
- [Reference Papers](#Reference-Papers)
## Introduction

The C-TCN action localization model was developed at Baidu and won the 2018 ActivityNet challenge; open-sourced on Paddle for the first time, it gives developers a solution to the video action localization problem. The model introduces a concept-wise temporal convolutional network: for each concept, a convolutional network first extracts information along the temporal dimension, and the per-concept information is then combined. The backbone is a residual network plus FPN, and an SSD-like single-stage detection algorithm predicts and classifies anchor boxes along the temporal dimension.
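As a rough illustration of the concept-wise idea only (not the author's implementation), the numpy sketch below applies an independent temporal convolution to each concept channel before stacking the results back together; the shapes and kernel values are invented for the example.

```python
import numpy as np

def concept_wise_temporal_conv(feats, kernels):
    """Illustrative only, not C-TCN's code: feats has shape
    (num_concepts, T); each concept channel gets its own 1-D temporal
    convolution, and the per-concept responses are then recombined."""
    out = [np.convolve(feats[c], kernels[c], mode='same')
           for c in range(feats.shape[0])]
    return np.stack(out)

feats = np.random.rand(201, 512)   # assumed: 201 concepts, 512 time steps
kernels = np.random.rand(201, 3)   # assumed: one 3-tap kernel per concept
print(concept_wise_temporal_conv(feats, kernels).shape)  # (201, 512)
```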
## Data Preparation

C-TCN is trained on the dataset provided by ActivityNet 1.3; see the [data notes](../../dataset/ctcn/README.md) for downloading and preparing the data.
## Training

Once the data is ready, training can be started in either of the following two ways:

    python train.py --model_name=CTCN \
                    --config=./configs/ctcn.txt \
                    --save_dir=checkpoints \
                    --log_interval=10 \
                    --valid_interval=1 \
                    --pretrain=${path_to_pretrain_model}

    bash scripts/train/train_ctcn.sh
- To train from scratch, just use the launch commands above; no pretrained model is required.
- You can download the released [model](https://paddlemodels.bj.bcebos.com/video_detection/ctcn.tar.gz) and pass the path of the stored weights via `--resume` for finetuning or further development.
**Training strategy:**

* Momentum optimizer with momentum=0.9
* Weight decay coefficient of 1e-4
* The learning rate is decayed once when the iteration count reaches 9000 (see the sketch after this list)
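A minimal sketch of such a schedule in Fluid is shown below; the base learning rate of 0.001 and the 10x decay factor are assumptions for illustration, not values taken from the released config.

```python
import paddle.fluid as fluid

# Sketch of the training strategy above. The base LR (0.001) and the
# 10x decay factor are assumptions for illustration.
base_lr = 0.001
lr = fluid.layers.piecewise_decay(
    boundaries=[9000], values=[base_lr, base_lr * 0.1])
optimizer = fluid.optimizer.Momentum(
    learning_rate=lr,
    momentum=0.9,
    regularization=fluid.regularizer.L2Decay(1e-4))
```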
## Evaluation

The model can be evaluated in either of the following two ways:

    python test.py --model_name=CTCN \
                   --config=configs/ctcn.txt \
                   --log_interval=1 \
                   --weights=$PATH_TO_WEIGHTS

    bash scripts/test/test_ctcn.sh

- When evaluating with `scripts/test/test_ctcn.sh`, modify the `--weights` argument in the script to point at the weights to be evaluated.
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_detection/ctcn.tar.gz) and evaluates it.
With the following parameter settings, evaluation accuracy on the ActivityNet 1.3 dataset is:

| score_thresh | nms_thresh | soft_sigma | soft_thresh | Top-1 |
| :-----------: | :---------: | :---------: | :----------: | :----: |
| 0.001 | 0.8 | 0.9 | 0.004 | 31% |
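These thresholds feed the box decoding step, which applies soft-NMS over the predicted temporal segments. The sketch below is a generic 1-D Gaussian soft-NMS written to illustrate how `sigma`/`score_thresh`-style parameters interact; it is not the repo's own decoder.

```python
import numpy as np

def soft_nms_1d(segments, scores, sigma=0.9, score_thresh=0.004):
    """Generic Gaussian soft-NMS over temporal segments (illustrative,
    not the repo's decoder). segments: (N, 2) [start, end]; scores: (N,)."""
    segments, scores = segments.copy(), scores.copy()
    keep = []
    while len(scores) > 0 and scores.max() > score_thresh:
        i = scores.argmax()
        keep.append((segments[i].tolist(), float(scores[i])))
        # Temporal IoU of the best segment with all remaining segments.
        inter = np.maximum(0.0,
                           np.minimum(segments[:, 1], segments[i, 1]) -
                           np.maximum(segments[:, 0], segments[i, 0]))
        union = (segments[:, 1] - segments[:, 0]) + \
                (segments[i, 1] - segments[i, 0]) - inter
        iou = inter / np.maximum(union, 1e-8)
        # Decay overlapping scores instead of discarding them outright.
        scores = scores * np.exp(-iou ** 2 / sigma)
        segments = np.delete(segments, i, axis=0)
        scores = np.delete(scores, i)
    return keep
```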
## Inference

Inference can be run with the following command:

    python infer.py --model_name=CTCN \
                    --config=configs/ctcn.txt \
                    --log_interval=1 \
                    --weights=$PATH_TO_WEIGHTS \
                    --filelist=$FILELIST

- The inference results are stored in `CTCN_infer_result.pkl` in `pickle` format; a loading sketch follows below.
- If `--weights` is not specified, the script downloads the released [model](https://paddlemodels.bj.bcebos.com/video_detection/ctcn.tar.gz) and uses it for inference.
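Based on how infer.py above assembles and dumps `results`, a minimal sketch for reading the file back:

```python
import pickle

# Per the infer.py change above, each entry for CTCN is a tuple of
# (video_id, loc_predictions, cls_predictions).
with open('CTCN_infer_result.pkl', 'rb') as f:
    results = pickle.load(f)
video_id, loc_preds, cls_preds = results[0]
print(video_id, len(loc_preds), len(cls_preds))
```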
## Reference Papers

- To be published
@@ -106,8 +106,8 @@ class CTCN(ModelBase):
            fileid = fluid.layers.data(
                name='fileid', shape=fileid_shape, dtype='int64')
        elif self.mode == 'infer':
-            fileid = fluid.layers.data(
-                name='fileid', shape=fileid_shape, dtype='int64')
+            # only image feature input when inference
+            pass
        else:
            raise NotImplementedError('mode {} not implemented'.format(
                self.mode))
@@ -168,7 +168,7 @@ class CTCN(ModelBase):
                self.loc_targets, self.cls_targets, self.fileid
            ]
        elif self.mode == 'infer':
-            return self.feature_input + [self.fileid]
+            return self.feature_input
        else:
            raise NotImplemented
@@ -136,10 +136,7 @@ class ModelBase(object):
        fluid.io.load_params(exe, pretrain, main_program=prog)

    def load_test_weights(self, exe, weights, prog, place):
-        def if_exist(var):
-            return os.path.exists(os.path.join(weights, var.name))
-
-        fluid.io.load_vars(exe, weights, predicate=if_exist)
+        fluid.io.load_params(exe, weights, main_program=prog)
    def get_config_from_sec(self, sec, item, default=None):
        if sec.upper() not in self.cfg:
python infer.py --model_name="CTCN" --config=./configs/ctcn.txt --filelist=./dataset/ctcn/infer.list \
--log_interval=1 --weights=./checkpoints/CTCN_epoch0 --save_dir=./save
export CUDA_VISIBLE_DEVICES=0
python test.py --model_name="CTCN" --config=./configs/ctcn.txt \
-                --log_interval=10 --weights=./checkpoints/CTCN_epoch0
+                --log_interval=1 --weights=./checkpoints/CTCN_epoch0
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
#export CUDA_VISIBLE_DEVICES=0
export FLAGS_fast_eager_deletion_mode=1
export FLAGS_eager_delete_tensor_gb=0.0
@@ -94,11 +94,16 @@ def test(args):
    test_metrics = get_metrics(args.model_name.upper(), 'test', test_config)
    test_feeder = fluid.DataFeeder(place=place, feed_list=test_feeds)
-    if test_loss is None:
-        fetch_list = [x.name for x in test_outputs] + [test_feeds[-1].name]
-    else:
-        fetch_list = [test_loss.name] + [x.name for x in test_outputs
-                                         ] + [test_feeds[-1].name]
+    if args.model_name.upper() in ['CTCN']:
+        fetch_list = [x.name for x in test_loss] + \
+                     [x.name for x in test_outputs] + \
+                     [test_feeds[-1].name]
+    else:
+        if test_loss is None:
+            fetch_list = [x.name for x in test_outputs] + [test_feeds[-1].name]
+        else:
+            fetch_list = [test_loss.name] + [x.name for x in test_outputs
+                                             ] + [test_feeds[-1].name]
    epoch_period = []
    for test_iter, data in enumerate(test_reader()):
@@ -106,14 +111,25 @@ def test(args):
        test_outs = exe.run(fetch_list=fetch_list, feed=test_feeder.feed(data))
        period = time.time() - cur_time
        epoch_period.append(period)
-        if test_loss is None:
-            loss = np.zeros(1, ).astype('float32')
-            pred = np.array(test_outs[0])
-            label = np.array(test_outs[-1])
-        else:
-            loss = np.array(test_outs[0])
-            pred = np.array(test_outs[1])
-            label = np.array(test_outs[-1])
+        if args.model_name.upper() in ['CTCN']:
+            total_loss = test_outs[0]
+            loc_loss = test_outs[1]
+            cls_loss = test_outs[2]
+            loc_preds = test_outs[3]
+            cls_preds = test_outs[4]
+            fid = test_outs[-1]
+            loss = [total_loss, loc_loss, cls_loss]
+            pred = [loc_preds, cls_preds]
+            label = fid
+        else:
+            if test_loss is None:
+                loss = np.zeros(1, ).astype('float32')
+                pred = np.array(test_outs[0])
+                label = np.array(test_outs[-1])
+            else:
+                loss = np.array(test_outs[0])
+                pred = np.array(test_outs[1])
+                label = np.array(test_outs[-1])
        test_metrics.accumulate(loss, pred, label)
        # metric here