From 632170a3510c61c738cb56e82470ce31327490e2 Mon Sep 17 00:00:00 2001 From: huangjun12 <2399845970@qq.com> Date: Thu, 24 Oct 2019 14:06:19 +0800 Subject: [PATCH] add new grounding model TALL for PaddleVideo (#3690) * add new grounding model TALL for PaddleVideo * modify configs and add infer data * delete inference data of tall * delete some redundant notes after gf's review * refine code of metrics for prediction and evaluation * delete notes of multi-process reader --- PaddleCV/PaddleVideo/configs/tall.yaml | 67 ++++ PaddleCV/PaddleVideo/eval.py | 7 + PaddleCV/PaddleVideo/metrics/metrics_util.py | 55 ++- .../metrics/tall_metrics/__init__.py | 0 .../metrics/tall_metrics/tall_metrics.py | 297 ++++++++++++++++ PaddleCV/PaddleVideo/models/__init__.py | 2 + PaddleCV/PaddleVideo/models/tall/__init__.py | 1 + PaddleCV/PaddleVideo/models/tall/tall.py | 165 +++++++++ PaddleCV/PaddleVideo/models/tall/tall_net.py | 151 ++++++++ PaddleCV/PaddleVideo/models/tsn/README.md | 2 +- PaddleCV/PaddleVideo/predict.py | 11 +- PaddleCV/PaddleVideo/reader/__init__.py | 2 + PaddleCV/PaddleVideo/reader/tall_reader.py | 323 ++++++++++++++++++ PaddleCV/PaddleVideo/utils/utility.py | 21 +- 14 files changed, 1086 insertions(+), 18 deletions(-) create mode 100644 PaddleCV/PaddleVideo/configs/tall.yaml create mode 100644 PaddleCV/PaddleVideo/metrics/tall_metrics/__init__.py create mode 100644 PaddleCV/PaddleVideo/metrics/tall_metrics/tall_metrics.py create mode 100644 PaddleCV/PaddleVideo/models/tall/__init__.py create mode 100644 PaddleCV/PaddleVideo/models/tall/tall.py create mode 100644 PaddleCV/PaddleVideo/models/tall/tall_net.py create mode 100644 PaddleCV/PaddleVideo/reader/tall_reader.py diff --git a/PaddleCV/PaddleVideo/configs/tall.yaml b/PaddleCV/PaddleVideo/configs/tall.yaml new file mode 100644 index 00000000..46cb181e --- /dev/null +++ b/PaddleCV/PaddleVideo/configs/tall.yaml @@ -0,0 +1,67 @@ +MODEL: + name: "TALL" + visual_feature_dim : 12288 + sentence_embedding_size : 4800 + semantic_size : 1024 + hidden_size : 1000 + output_size : 3 + +TRAIN: + epoch : 25 + use_gpu : True + num_gpus : 1 + batch_size : 56 + + off_size: 2 + clip_norm: 5.0 + learning_rate: 1e-3 + + semantic_size : 1024 + feats_dimen : 4096 + context_num : 1 + context_size : 128 + sent_vec_dim : 4800 + sliding_clip_path : "./data/dataset/tall/Interval64_128_256_512_overlap0.8_c3d_fc6/" + clip_sentvec : "./data/dataset/tall/train_clip-sentvec.pkl" + movie_length_info : "./data/dataset/tall/video_allframes_info.pkl" + +VALID: + use_gpu : True + num_gpus : 1 + batch_size : 56 + + off_size: 2 + clip_norm: 5.0 + learning_rate: 1e-3 + + semantic_size : 1024 + feats_dimen : 4096 + context_num : 1 + context_size : 128 + sent_vec_dim : 4800 + sliding_clip_path : "./data/dataset/tall/Interval64_128_256_512_overlap0.8_c3d_fc6/" + clip_sentvec : "./data/dataset/tall/train_clip-sentvec.pkl" + movie_length_info : "./data/dataset/tall/video_allframes_info.pkl" + +TEST: + batch_size : 1 + + feats_dimen : 4096 + context_num : 1 + context_size : 128 + sent_vec_dim : 4800 + semantic_size : 4800 + sliding_clip_path : "./data/dataset/tall/Interval128_256_overlap0.8_c3d_fc6/" + clip_sentvec : "./data/dataset/tall/test_clip-sentvec.pkl" + +INFER: + batch_size: 1 + feats_dimen: 4096 + context_num: 1 + context_size: 128 + sent_vec_dim: 4800 + semantic_size: 4800 + + filelist: "./data/dataset/tall/infer" + sliding_clip_path : "./data/dataset/tall/infer/infer_feat" + clip_sentvec : "./data/dataset/tall/infer/infer_clip-sen.pkl" diff --git 
a/PaddleCV/PaddleVideo/eval.py b/PaddleCV/PaddleVideo/eval.py index 1264f4a4..a800f9c7 100644 --- a/PaddleCV/PaddleVideo/eval.py +++ b/PaddleCV/PaddleVideo/eval.py @@ -120,6 +120,13 @@ def test(args): feed=test_feeder.feed(feat_data), return_numpy=False) test_outs += [vinfo] + elif args.model_name == 'TALL': + feat_data = [items[:2] for items in data] + vinfo = [items[2:] for items in data] + test_outs = exe.run(fetch_list=test_fetch_list, + feed=test_feeder.feed(feat_data), + return_numpy=True) + test_outs += [vinfo] else: test_outs = exe.run(fetch_list=test_fetch_list, feed=test_feeder.feed(data)) diff --git a/PaddleCV/PaddleVideo/metrics/metrics_util.py b/PaddleCV/PaddleVideo/metrics/metrics_util.py index 9ec25e2f..f84f4b95 100644 --- a/PaddleCV/PaddleVideo/metrics/metrics_util.py +++ b/PaddleCV/PaddleVideo/metrics/metrics_util.py @@ -29,6 +29,7 @@ from metrics.bmn_metrics import bmn_proposal_metrics as bmn_proposal_metrics from metrics.bsn_metrics import bsn_tem_metrics as bsn_tem_metrics from metrics.bsn_metrics import bsn_pem_metrics as bsn_pem_metrics from metrics.ets_metrics import ets_metrics as ets_metrics +from metrics.tall_metrics import tall_metrics as tall_metrics logger = logging.getLogger(__name__) @@ -432,18 +433,17 @@ class ETSMetrics(Metrics): args['name'] = name self.calculator = ets_metrics.MetricsCalculator(**args) - def calculate_and_log_out(self, fetch_list, info=''): if (self.mode == 'train') or (self.mode == 'valid'): loss = np.array(fetch_list[0]) - logger.info( - info + '\tLoss = {}'.format('%.08f' % np.mean(loss))) + logger.info(info + '\tLoss = {}'.format('%.08f' % np.mean(loss))) elif self.mode == "test": translation_ids = np.array(fetch_list[0]) translation_scores = np.array(fetch_list[1]) logger.info( - info + '\ttranslation_ids = {}, \ttranslation_scores = {}'.format( - '%.01f' % np.mean(translation_ids), '%.04f' % np.mean(translation_scores))) + info + '\ttranslation_ids = {}, \ttranslation_scores = {}'. 
+ format('%.01f' % np.mean(translation_ids), '%.04f' % np.mean( + translation_scores))) def accumulate(self, fetch_list): self.calculator.accumulate(fetch_list) @@ -454,7 +454,49 @@ class ETSMetrics(Metrics): else: #test or infer self.calculator.finalize_metrics(savedir) if self.mode == 'test': - logger.info(info + 'please refer to metrics/ets_metrics/README.md to get accuracy') + logger.info( + info + + 'please refer to metrics/ets_metrics/README.md to get accuracy' + ) + + def reset(self): + self.calculator.reset() + + +class TALLMetrics(Metrics): + def __init__(self, name, mode, cfg): + self.name = name + self.mode = mode + args = {} + args['mode'] = mode + args['name'] = name + self.calculator = tall_metrics.MetricsCalculator(**args) + + def calculate_and_log_out(self, fetch_list, info=''): + if (self.mode == 'train') or (self.mode == 'valid'): + loss = np.array(fetch_list[0]) + logger.info(info + '\tLoss = {}'.format('%.04f' % np.mean(loss))) + elif self.mode == "test": + sim_score_mat = np.array(fetch_list[0]) + logger.info(info + '\tsim_score_mat = {}'.format('%.01f' % np.mean( + sim_score_mat))) + + def accumulate(self, fetch_list): + self.calculator.accumulate(fetch_list) + + def finalize_and_log_out(self, info='', savedir='./'): + if self.mode == 'valid': + logger.info(info) + elif self.mode == 'infer': + self.calculator.finalize_infer_metrics(savedir) + else: + self.calculator.finalize_metrics(savedir) + metrics_dict = self.calculator.get_computed_metrics() + R1_IOU5 = metrics_dict['best_R1_IOU5'] + R5_IOU5 = metrics_dict['best_R5_IOU5'] + + logger.info("best_R1_IOU5: {}\n".format(" %0.3f" % R1_IOU5)) + logger.info("best_R5_IOU5: {}\n".format(" %0.3f" % R5_IOU5)) def reset(self): self.calculator.reset() @@ -501,3 +543,4 @@ regist_metrics("BMN", BmnMetrics) regist_metrics("BSNTEM", BsnTemMetrics) regist_metrics("BSNPEM", BsnPemMetrics) regist_metrics("ETS", ETSMetrics) +regist_metrics("TALL", TALLMetrics) diff --git a/PaddleCV/PaddleVideo/metrics/tall_metrics/__init__.py b/PaddleCV/PaddleVideo/metrics/tall_metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/PaddleCV/PaddleVideo/metrics/tall_metrics/tall_metrics.py b/PaddleCV/PaddleVideo/metrics/tall_metrics/tall_metrics.py new file mode 100644 index 00000000..0ed0f39d --- /dev/null +++ b/PaddleCV/PaddleVideo/metrics/tall_metrics/tall_metrics.py @@ -0,0 +1,297 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+#See the License for the specific language governing permissions and + +import numpy as np +import datetime +import logging +import json +import os +import operator + +logger = logging.getLogger(__name__) + + +class MetricsCalculator(): + def __init__( + self, + name='TALL', + mode='train', ): + self.name = name + self.mode = mode # 'train', 'valid', 'test', 'infer' + self.reset() + + def reset(self): + logger.info('Resetting {} metrics...'.format(self.mode)) + if (self.mode == 'train') or (self.mode == 'valid'): + self.aggr_loss = 0.0 + elif (self.mode == 'test') or (self.mode == 'infer'): + self.result_dict = dict() + self.save_res = dict() + self.out_file = self.name + '_' + self.mode + '_res_' + '.json' + + def nms_temporal(self, x1, x2, sim, overlap): + pick = [] + assert len(x1) == len(sim) + assert len(x2) == len(sim) + if len(x1) == 0: + return pick + + union = map(operator.sub, x2, x1) # union = x2-x1 + + I = [i[0] for i in sorted( + enumerate(sim), key=lambda x: x[1])] # sort and get index + + while len(I) > 0: + i = I[-1] + pick.append(i) + + xx1 = [max(x1[i], x1[j]) for j in I[:-1]] + xx2 = [min(x2[i], x2[j]) for j in I[:-1]] + inter = [max(0.0, k2 - k1) for k1, k2 in zip(xx1, xx2)] + o = [ + inter[u] / (union[i] + union[I[u]] - inter[u]) + for u in range(len(I) - 1) + ] + I_new = [] + for j in range(len(o)): + if o[j] <= overlap: + I_new.append(I[j]) + I = I_new + return pick + + def calculate_IoU(self, i0, i1): + # calculate temporal intersection over union + union = (min(i0[0], i1[0]), max(i0[1], i1[1])) + inter = (max(i0[0], i1[0]), min(i0[1], i1[1])) + iou = 1.0 * (inter[1] - inter[0]) / (union[1] - union[0]) + return iou + + def compute_IoU_recall_top_n_forreg(self, top_n, iou_thresh, + sentence_image_mat, + sentence_image_reg_mat, sclips): + correct_num = 0.0 + for k in range(sentence_image_mat.shape[0]): + gt = sclips[k] + gt_start = float(gt.split("_")[1]) + gt_end = float(gt.split("_")[2]) + sim_v = [v for v in sentence_image_mat[k]] + starts = [s for s in sentence_image_reg_mat[k, :, 0]] + ends = [e for e in sentence_image_reg_mat[k, :, 1]] + picks = self.nms_temporal(starts, ends, sim_v, iou_thresh - 0.05) + if top_n < len(picks): + picks = picks[0:top_n] + for index in picks: + pred_start = sentence_image_reg_mat[k, index, 0] + pred_end = sentence_image_reg_mat[k, index, 1] + iou = self.calculate_IoU((gt_start, gt_end), + (pred_start, pred_end)) + if iou >= iou_thresh: + correct_num += 1 + break + return correct_num + + def accumulate(self, fetch_list): + if self.mode == 'valid': + loss = fetch_list[0] + self.aggr_loss += np.mean(np.array(loss)) + elif (self.mode == 'test') or (self.mode == 'infer'): + outputs = fetch_list[0] + b_start = [item[0] for item in fetch_list[1]] + b_end = [item[1] for item in fetch_list[1]] + b_k = [item[2] for item in fetch_list[1]] + b_t = [item[3] for item in fetch_list[1]] + b_movie_clip_sentences = [item[4] for item in fetch_list[1]] + b_movie_clip_featmaps = [item[5] for item in fetch_list[1]] + b_movie_name = [item[6] for item in fetch_list[1]] + + batch_size = len(b_start) + for i in range(batch_size): + start = b_start[i] + end = b_end[i] + k = b_k[i] + t = b_t[i] + movie_clip_sentences = b_movie_clip_sentences[i] + movie_clip_featmaps = b_movie_clip_featmaps[i] + movie_name = b_movie_name[i] + + item_res = [outputs, start, end, k, t] + + if movie_name not in self.result_dict.keys(): + self.result_dict[movie_name] = [] + self.result_dict[movie_name].append(movie_clip_sentences) + 
self.result_dict[movie_name].append(movie_clip_featmaps) + + self.result_dict[movie_name].append(item_res) + + def accumulate_infer_results(self, fetch_list): + # the same as test + pass + + def finalize_metrics(self, savedir): + # init + IoU_thresh = [0.1, 0.3, 0.5, 0.7] + all_correct_num_10 = [0.0] * 5 + all_correct_num_5 = [0.0] * 5 + all_correct_num_1 = [0.0] * 5 + all_retrievd = 0.0 + + idx = 0 + all_number = len(self.result_dict) + for movie_name in self.result_dict.keys(): + idx += 1 + logger.info('{} / {}'.format('%d' % idx, '%d' % all_number)) + + movie_clip_sentences = self.result_dict[movie_name][0] + movie_clip_featmaps = self.result_dict[movie_name][1] + + ls = len(movie_clip_sentences) + lf = len(movie_clip_featmaps) + sentence_image_mat = np.zeros([ls, lf]) + sentence_image_reg_mat = np.zeros([ls, lf, 2]) + + movie_res = self.result_dict[movie_name][2:] + for item_res in movie_res: + outputs, start, end, k, t = item_res + + outputs = np.squeeze(outputs) + sentence_image_mat[k, t] = outputs[0] + reg_end = end + outputs[2] + reg_start = start + outputs[1] + + sentence_image_reg_mat[k, t, 0] = reg_start + sentence_image_reg_mat[k, t, 1] = reg_end + + sclips = [b[0] for b in movie_clip_sentences] + + for i in range(len(IoU_thresh)): + IoU = IoU_thresh[i] + correct_num_10 = self.compute_IoU_recall_top_n_forreg( + 10, IoU, sentence_image_mat, sentence_image_reg_mat, sclips) + correct_num_5 = self.compute_IoU_recall_top_n_forreg( + 5, IoU, sentence_image_mat, sentence_image_reg_mat, sclips) + correct_num_1 = self.compute_IoU_recall_top_n_forreg( + 1, IoU, sentence_image_mat, sentence_image_reg_mat, sclips) + + logger.info( + movie_name + + " IoU= {}, R@10: {}; IoU= {}, R@5: {}; IoU= {}, R@1: {}". + format('%s' % str(IoU), '%s' % str(correct_num_10 / len( + sclips)), '%s' % str(IoU), '%s' % str( + correct_num_5 / len(sclips)), '%s' % str(IoU), '%s' + % str(correct_num_1 / len(sclips)))) + + all_correct_num_10[i] += correct_num_10 + all_correct_num_5[i] += correct_num_5 + all_correct_num_1[i] += correct_num_1 + + all_retrievd += len(sclips) + + for j in range(len(IoU_thresh)): + logger.info( + " IoU= {}, R@10: {}; IoU= {}, R@5: {}; IoU= {}, R@1: {}".format( + '%s' % str(IoU_thresh[j]), '%s' % str(all_correct_num_10[ + j] / all_retrievd), '%s' % str(IoU_thresh[j]), '%s' % + str(all_correct_num_5[j] / all_retrievd), '%s' % str( + IoU_thresh[j]), '%s' % str(all_correct_num_1[j] / + all_retrievd))) + + self.R1_IOU5 = all_correct_num_1[2] / all_retrievd + self.R5_IOU5 = all_correct_num_5[2] / all_retrievd + + self.save_res["best_R1_IOU5"] = self.R1_IOU5 + self.save_res["best_R5_IOU5"] = self.R5_IOU5 + + self.filepath = os.path.join(savedir, self.out_file) + with open(self.filepath, 'w') as f: + f.write( + json.dumps( + { + 'version': 'VERSION 1.0', + 'results': self.save_res, + 'external_data': {} + }, + indent=2)) + logger.info('results has been saved into file: {}'.format( + self.filepath)) + + def finalize_infer_metrics(self, savedir): + idx = 0 + all_number = len(self.result_dict) + res = dict() + for movie_name in self.result_dict.keys(): + res[movie_name] = [] + idx += 1 + logger.info('{} / {}'.format('%d' % idx, '%d' % all_number)) + + movie_clip_sentences = self.result_dict[movie_name][0] + movie_clip_featmaps = self.result_dict[movie_name][1] + + ls = len(movie_clip_sentences) + lf = len(movie_clip_featmaps) + sentence_image_mat = np.zeros([ls, lf]) + sentence_image_reg_mat = np.zeros([ls, lf, 2]) + + movie_res = self.result_dict[movie_name][2:] + for item_res in movie_res: + 
outputs, start, end, k, t = item_res + + outputs = np.squeeze(outputs) + sentence_image_mat[k, t] = outputs[0] + reg_end = end + outputs[2] + reg_start = start + outputs[1] + + sentence_image_reg_mat[k, t, 0] = reg_start + sentence_image_reg_mat[k, t, 1] = reg_end + + sclips = [b[0] for b in movie_clip_sentences] + IoU = 0.5 #pre-define + for k in range(sentence_image_mat.shape[0]): + #ground_truth for compare + gt = sclips[k] + gt_start = float(gt.split("_")[1]) + gt_end = float(gt.split("_")[2]) + + sim_v = [v for v in sentence_image_mat[k]] + starts = [s for s in sentence_image_reg_mat[k, :, 0]] + ends = [e for e in sentence_image_reg_mat[k, :, 1]] + picks = self.nms_temporal(starts, ends, sim_v, IoU - 0.05) + + if 1 < len(picks): #top1 + picks = picks[0:1] + + for index in picks: + pred_start = sentence_image_reg_mat[k, index, 0] + pred_end = sentence_image_reg_mat[k, index, 1] + res[movie_name].append((k, pred_start, pred_end)) + + logger.info( + 'movie_name: {}, sentence_id: {}, pred_start_time: {}, pred_end_time: {}, gt_start_time: {}, gt_end_time: {}'. + format('%s' % movie_name, '%s' % str(k), '%s' % str( + pred_start), '%s' % str(pred_end), '%s' % str(gt_start), + '%s' % str(gt_end))) + + self.filepath = os.path.join(savedir, self.out_file) + with open(self.filepath, 'w') as f: + f.write( + json.dumps( + { + 'version': 'VERSION 1.0', + 'results': res, + 'external_data': {} + }, + indent=2)) + logger.info('results has been saved into file: {}'.format( + self.filepath)) + + def get_computed_metrics(self): + return self.save_res diff --git a/PaddleCV/PaddleVideo/models/__init__.py b/PaddleCV/PaddleVideo/models/__init__.py index a5e01826..d4550e2f 100644 --- a/PaddleCV/PaddleVideo/models/__init__.py +++ b/PaddleCV/PaddleVideo/models/__init__.py @@ -11,6 +11,7 @@ from .bmn import BMN from .bsn import BsnTem from .bsn import BsnPem from .ets import ETS +from .tall import TALL # regist models, sort by alphabet regist_model("AttentionCluster", AttentionCluster) @@ -25,3 +26,4 @@ regist_model("BMN", BMN) regist_model("BsnTem", BsnTem) regist_model("BsnPem", BsnPem) regist_model("ETS", ETS) +regist_model("TALL", TALL) diff --git a/PaddleCV/PaddleVideo/models/tall/__init__.py b/PaddleCV/PaddleVideo/models/tall/__init__.py new file mode 100644 index 00000000..c0ddead7 --- /dev/null +++ b/PaddleCV/PaddleVideo/models/tall/__init__.py @@ -0,0 +1 @@ +from .tall import * diff --git a/PaddleCV/PaddleVideo/models/tall/tall.py b/PaddleCV/PaddleVideo/models/tall/tall.py new file mode 100644 index 00000000..c09eea21 --- /dev/null +++ b/PaddleCV/PaddleVideo/models/tall/tall.py @@ -0,0 +1,165 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +import numpy as np + +from ..model import ModelBase +from . 
import tall_net + +import logging +logger = logging.getLogger(__name__) + +__all__ = ["TALL"] + + +class TALL(ModelBase): + """TALL model""" + + def __init__(self, name, cfg, mode='train'): + super(TALL, self).__init__(name, cfg, mode=mode) + self.get_config() + + def get_config(self): + self.visual_feature_dim = self.get_config_from_sec('MODEL', + 'visual_feature_dim') + self.sentence_embedding_size = self.get_config_from_sec( + 'MODEL', 'sentence_embedding_size') + self.semantic_size = self.get_config_from_sec('MODEL', 'semantic_size') + self.hidden_size = self.get_config_from_sec('MODEL', 'hidden_size') + self.output_size = self.get_config_from_sec('MODEL', 'output_size') + self.batch_size = self.get_config_from_sec(self.mode, 'batch_size') + + self.off_size = self.get_config_from_sec('train', + 'off_size') # in train of yaml + self.clip_norm = self.get_config_from_sec('train', 'clip_norm') + self.learning_rate = self.get_config_from_sec('train', 'learning_rate') + + def build_input(self, use_dataloader=True): + visual_shape = self.visual_feature_dim + sentence_shape = self.sentence_embedding_size + offset_shape = self.off_size + + # set init data to None + images = None + sentences = None + offsets = None + + self.use_dataloader = use_dataloader + + images = fluid.data( + name='train_visual', shape=[None, visual_shape], dtype='float32') + + sentences = fluid.data( + name='train_sentences', + shape=[None, sentence_shape], + dtype='float32') + + feed_list = [] + feed_list.append(images) + feed_list.append(sentences) + if (self.mode == 'train') or (self.mode == 'valid'): + offsets = fluid.data( + name='train_offsets', + shape=[None, offset_shape], + dtype='float32') + + feed_list.append(offsets) + elif (self.mode == 'test') or (self.mode == 'infer'): + # input images and sentences + pass + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + + if use_dataloader: + assert self.mode != 'infer', \ + 'dataloader is not recommendated when infer, please set use_dataloader to be false.' 
+ self.dataloader = fluid.io.DataLoader.from_generator( + feed_list=feed_list, capacity=16, iterable=True) + + self.images = [images] + self.sentences = sentences + self.offsets = offsets + + def create_model_args(self): + cfg = {} + + cfg['semantic_size'] = self.semantic_size + cfg['sentence_embedding_size'] = self.sentence_embedding_size + cfg['hidden_size'] = self.hidden_size + cfg['output_size'] = self.output_size + cfg['batch_size'] = self.batch_size + return cfg + + def build_model(self): + cfg = self.create_model_args() + self.videomodel = tall_net.TALLNET( + semantic_size=cfg['semantic_size'], + sentence_embedding_size=cfg['sentence_embedding_size'], + hidden_size=cfg['hidden_size'], + output_size=cfg['output_size'], + batch_size=cfg['batch_size'], + mode=self.mode) + outs = self.videomodel.net(images=self.images[0], + sentences=self.sentences) + self.network_outputs = [outs] + + def optimizer(self): + clip_norm = self.clip_norm + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm)) + + optimizer = fluid.optimizer.Adam(learning_rate=self.learning_rate) + + return optimizer + + def loss(self): + assert self.mode != 'infer', "invalid loss calculationg in infer mode" + self.loss_ = self.videomodel.loss(self.network_outputs[0], self.offsets) + return self.loss_ + + def outputs(self): + preds = self.network_outputs[0] + return [preds] + + def feeds(self): + if (self.mode == 'train') or (self.mode == 'valid'): + return self.images + [self.sentences, self.offsets] + elif self.mode == 'test' or (self.mode == 'infer'): + return self.images + [self.sentences] + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + + def fetches(self): + if (self.mode == 'train') or (self.mode == 'valid'): + losses = self.loss() + fetch_list = [item for item in losses] + elif (self.mode == 'test') or (self.mode == 'infer'): + preds = self.outputs() + fetch_list = [item for item in preds] + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + return fetch_list + + def pretrain_info(self): + return (None, None) + + def weights_info(self): + pass diff --git a/PaddleCV/PaddleVideo/models/tall/tall_net.py b/PaddleCV/PaddleVideo/models/tall/tall_net.py new file mode 100644 index 00000000..6ba4a59e --- /dev/null +++ b/PaddleCV/PaddleVideo/models/tall/tall_net.py @@ -0,0 +1,151 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+ +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +import numpy as np + + +class TALLNET(object): + def __init__(self, + semantic_size, + sentence_embedding_size, + hidden_size, + output_size, + batch_size, + mode='train'): + + self.semantic_size = semantic_size + self.sentence_embedding_size = sentence_embedding_size + self.hidden_size = hidden_size + self.output_size = output_size + self.batch_size = batch_size #divide train and test + + self.mode = mode + + def cross_modal_comb(self, visual_feat, sentence_embed): + visual_feat = fluid.layers.reshape(visual_feat, + [1, -1, self.semantic_size]) + vv_feature = fluid.layers.expand(visual_feat, [self.batch_size, 1, 1]) + sentence_embed = fluid.layers.reshape(sentence_embed, + [-1, 1, self.semantic_size]) + ss_feature = fluid.layers.expand(sentence_embed, + [1, self.batch_size, 1]) + + concat_feature = fluid.layers.concat( + [vv_feature, ss_feature], axis=2) #B,B,2048 + + mul_feature = vv_feature * ss_feature # B,B,1024 + add_feature = vv_feature + ss_feature # B,B,1024 + + comb_feature = fluid.layers.concat( + [mul_feature, add_feature, concat_feature], axis=2) + return comb_feature + + def net(self, images, sentences): + # visual2semantic + transformed_clip = fluid.layers.fc( + input=images, + size=self.semantic_size, + act=None, + name='v2s_lt', + param_attr=fluid.ParamAttr( + name='v2s_lt_weights', + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=1.0, seed=0)), + bias_attr=False) + + #l2_normalize + transformed_clip = fluid.layers.l2_normalize(x=transformed_clip, axis=1) + + # sentenct2semantic + transformed_sentence = fluid.layers.fc( + input=sentences, + size=self.semantic_size, + act=None, + name='s2s_lt', + param_attr=fluid.ParamAttr( + name='s2s_lt_weights', + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=1.0, seed=0)), + bias_attr=False) + + #l2_normalize + transformed_sentence = fluid.layers.l2_normalize( + x=transformed_sentence, axis=1) + + cross_modal_vec = self.cross_modal_comb(transformed_clip, + transformed_sentence) + cross_modal_vec = fluid.layers.unsqueeze( + input=cross_modal_vec, axes=[0]) + cross_modal_vec = fluid.layers.transpose( + cross_modal_vec, perm=[0, 3, 1, 2]) + + mid_output = fluid.layers.conv2d( + input=cross_modal_vec, + num_filters=self.hidden_size, + filter_size=1, + stride=1, + act="relu", + param_attr=fluid.param_attr.ParamAttr(name="mid_out_weights"), + bias_attr=False) + + sim_score_mat = fluid.layers.conv2d( + input=mid_output, + num_filters=self.output_size, + filter_size=1, + stride=1, + act=None, + param_attr=fluid.param_attr.ParamAttr(name="sim_mat_weights"), + bias_attr=False) + sim_score_mat = fluid.layers.squeeze(input=sim_score_mat, axes=[0]) + return sim_score_mat + + def loss(self, outs, offs): + sim_score_mat = outs[0] + p_reg_mat = outs[1] + l_reg_mat = outs[2] + # loss cls, not considering iou + input_size = outs.shape[1] + I = fluid.layers.diag(np.array([1] * input_size).astype('float32')) + I_2 = -2 * I + all1 = fluid.layers.ones( + shape=[input_size, input_size], dtype="float32") + + mask_mat = I_2 + all1 + + alpha = 1.0 / input_size + lambda_regression = 0.01 + batch_para_mat = alpha * all1 + para_mat = I + batch_para_mat + + sim_mask_mat = fluid.layers.exp(mask_mat * sim_score_mat) + loss_mat = fluid.layers.log(all1 + sim_mask_mat) + loss_mat = loss_mat * para_mat + loss_align = fluid.layers.mean(loss_mat) + + # regression loss + reg_ones = fluid.layers.ones(shape=[input_size, 1], dtype="float32") + l_reg_diag = 
fluid.layers.matmul( + l_reg_mat * I, reg_ones, transpose_x=True, transpose_y=False) + p_reg_diag = fluid.layers.matmul( + p_reg_mat * I, reg_ones, transpose_x=True, transpose_y=False) + offset_pred = fluid.layers.concat( + input=[p_reg_diag, l_reg_diag], axis=1) + loss_reg = fluid.layers.mean( + fluid.layers.abs(offset_pred - offs)) # L1 loss + loss = lambda_regression * loss_reg + loss_align + avg_loss = fluid.layers.mean(loss) + + return [avg_loss] diff --git a/PaddleCV/PaddleVideo/models/tsn/README.md b/PaddleCV/PaddleVideo/models/tsn/README.md index 80ca3268..b1333450 100644 --- a/PaddleCV/PaddleVideo/models/tsn/README.md +++ b/PaddleCV/PaddleVideo/models/tsn/README.md @@ -15,7 +15,7 @@ Temporal Segment Network (TSN) 是视频分类领域经典的基于2D-CNN的解决方案。该方法主要解决视频的长时间行为判断问题,通过稀疏采样视频帧的方式代替稠密采样,既能捕获视频全局信息,也能去除冗余,降低计算量。最终将每帧特征平均融合后得到视频的整体特征,并用于分类。本代码实现的模型为基于单路RGB图像的TSN网络结构,Backbone采用ResNet-50结构。 -详细内容请参考ECCV 2016年论文[StNet:Local and Global Spatial-Temporal Modeling for Human Action Recognition](https://arxiv.org/abs/1608.00859) +详细内容请参考ECCV 2016年论文[Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859) ## 数据准备 diff --git a/PaddleCV/PaddleVideo/predict.py b/PaddleCV/PaddleVideo/predict.py index 939a6ad8..d21deafd 100644 --- a/PaddleCV/PaddleVideo/predict.py +++ b/PaddleCV/PaddleVideo/predict.py @@ -142,11 +142,19 @@ def infer(args): if args.model_name == 'ETS': data_feed_in = [items[:3] for items in data] vinfo = [items[3:] for items in data] - video_id = [items[0] for items in vinfo] + video_id = [items[6] for items in vinfo] infer_outs = exe.run(fetch_list=fetch_list, feed=infer_feeder.feed(data_feed_in), return_numpy=False) infer_result_list = infer_outs + [vinfo] + elif args.model_name == 'TALL': + data_feed_in = [items[:2] for items in data] + vinfo = [items[2:] for items in data] + video_id = [items[0] for items in vinfo] + infer_outs = exe.run(fetch_list=fetch_list, + feed=infer_feeder.feed(data_feed_in), + return_numpy=True) + infer_result_list = infer_outs + [vinfo] elif args.model_name == 'BsnPem': data_feed_in = [items[:1] for items in data] vinfo = [items[1:] for items in data] @@ -155,7 +163,6 @@ def infer(args): feed=infer_feeder.feed(data_feed_in), return_numpy=False) infer_result_list = infer_outs + [vinfo] - else: data_feed_in = [items[:-1] for items in data] video_id = [items[-1] for items in data] diff --git a/PaddleCV/PaddleVideo/reader/__init__.py b/PaddleCV/PaddleVideo/reader/__init__.py index a697e4bd..d394b9b8 100644 --- a/PaddleCV/PaddleVideo/reader/__init__.py +++ b/PaddleCV/PaddleVideo/reader/__init__.py @@ -7,6 +7,7 @@ from .bmn_reader import BMNReader from .bsn_reader import BSNVideoReader from .bsn_reader import BSNProposalReader from .ets_reader import ETSReader +from .tall_reader import TALLReader # regist reader, sort by alphabet regist_reader("ATTENTIONCLUSTER", FeatureReader) @@ -21,3 +22,4 @@ regist_reader("BMN", BMNReader) regist_reader("BSNTEM", BSNVideoReader) regist_reader("BSNPEM", BSNProposalReader) regist_reader("ETS", ETSReader) +regist_reader("TALL", TALLReader) diff --git a/PaddleCV/PaddleVideo/reader/tall_reader.py b/PaddleCV/PaddleVideo/reader/tall_reader.py new file mode 100644 index 00000000..b59b7718 --- /dev/null +++ b/PaddleCV/PaddleVideo/reader/tall_reader.py @@ -0,0 +1,323 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. 
+#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import os +import random +import sys +import numpy as np +import h5py +import multiprocessing +import functools +import paddle + +random.seed(0) + +import logging +logger = logging.getLogger(__name__) + +try: + import cPickle as pickle +except: + import pickle + +from .reader_utils import DataReader + +python_ver = sys.version_info + + +class TALLReader(DataReader): + """ + Data reader for TALL model, which was stored as features extracted by prior networks + """ + + def __init__(self, name, mode, cfg): + self.name = name + self.mode = mode + + self.visual_feature_dim = cfg.MODEL.visual_feature_dim + self.movie_length_info = cfg.TRAIN.movie_length_info + + self.feats_dimen = cfg[mode.upper()]['feats_dimen'] + self.context_num = cfg[mode.upper()]['context_num'] + self.context_size = cfg[mode.upper()]['context_size'] + self.sent_vec_dim = cfg[mode.upper()]['sent_vec_dim'] + self.sliding_clip_path = cfg[mode.upper()]['sliding_clip_path'] + self.clip_sentvec = cfg[mode.upper()]['clip_sentvec'] + self.semantic_size = cfg[mode.upper()]['semantic_size'] + + self.batch_size = cfg[mode.upper()]['batch_size'] + self.init_data() + + def get_context_window(self, clip_name): + # compute left (pre) and right (post) context features based on read_unit_level_feats(). + movie_name = clip_name.split("_")[0] + start = int(clip_name.split("_")[1]) + end = int(clip_name.split("_")[2].split(".")[0]) + clip_length = self.context_size + left_context_feats = np.zeros( + [self.context_num, self.feats_dimen], dtype=np.float32) + right_context_feats = np.zeros( + [self.context_num, self.feats_dimen], dtype=np.float32) + last_left_feat = np.load( + os.path.join(self.sliding_clip_path, clip_name)) + last_right_feat = np.load( + os.path.join(self.sliding_clip_path, clip_name)) + for k in range(self.context_num): + left_context_start = start - clip_length * (k + 1) + left_context_end = start - clip_length * k + right_context_start = end + clip_length * k + right_context_end = end + clip_length * (k + 1) + left_context_name = movie_name + "_" + str( + left_context_start) + "_" + str(left_context_end) + ".npy" + right_context_name = movie_name + "_" + str( + right_context_start) + "_" + str(right_context_end) + ".npy" + if os.path.exists( + os.path.join(self.sliding_clip_path, left_context_name)): + left_context_feat = np.load( + os.path.join(self.sliding_clip_path, left_context_name)) + last_left_feat = left_context_feat + else: + left_context_feat = last_left_feat + if os.path.exists( + os.path.join(self.sliding_clip_path, right_context_name)): + right_context_feat = np.load( + os.path.join(self.sliding_clip_path, right_context_name)) + last_right_feat = right_context_feat + else: + right_context_feat = last_right_feat + left_context_feats[k] = left_context_feat + right_context_feats[k] = right_context_feat + return np.mean( + left_context_feats, axis=0), np.mean( + right_context_feats, axis=0) + + def init_data(self): + def calculate_IoU(i0, i1): + # calculate temporal intersection over union + union = (min(i0[0], i1[0]), max(i0[1], i1[1])) + inter = (max(i0[0], i1[0]), min(i0[1], i1[1])) + iou = 1.0 * (inter[1] - 
inter[0]) / (union[1] - union[0]) + return iou + + def calculate_nIoL(base, sliding_clip): + # calculate the non Intersection part over Length ratia, make sure the input IoU is larger than 0 + inter = (max(base[0], sliding_clip[0]), min(base[1], + sliding_clip[1])) + inter_l = inter[1] - inter[0] + length = sliding_clip[1] - sliding_clip[0] + nIoL = 1.0 * (length - inter_l) / length + return nIoL + + # load file + if (self.mode == 'train') or (self.mode == 'valid'): + if python_ver < (3, 0): + cs = pickle.load(open(self.clip_sentvec, 'rb')) + movie_length_info = pickle.load( + open(self.movie_length_info, 'rb')) + else: + cs = pickle.load( + open(self.clip_sentvec, 'rb'), encoding='bytes') + movie_length_info = pickle.load( + open(self.movie_length_info, 'rb'), encoding='bytes') + elif (self.mode == 'test') or (self.mode == 'infer'): + if python_ver < (3, 0): + cs = pickle.load(open(self.clip_sentvec, 'rb')) + else: + cs = pickle.load( + open(self.clip_sentvec, 'rb'), encoding='bytes') + + self.clip_sentence_pairs = [] + for l in cs: + clip_name = l[0] + sent_vecs = l[1] + for sent_vec in sent_vecs: + self.clip_sentence_pairs.append((clip_name, sent_vec)) #10146 + logger.info(self.mode.upper() + ':' + str( + len(self.clip_sentence_pairs)) + " clip-sentence pairs are readed") + + movie_names_set = set() + movie_clip_names = {} + # read groundtruth sentence-clip pairs + for k in range(len(self.clip_sentence_pairs)): + clip_name = self.clip_sentence_pairs[k][0] + movie_name = clip_name.split("_")[0] + if not movie_name in movie_names_set: + movie_names_set.add(movie_name) + movie_clip_names[movie_name] = [] + movie_clip_names[movie_name].append(k) + self.movie_names = list(movie_names_set) + logger.info(self.mode.upper() + ':' + str(len(self.movie_names)) + + " movies.") + + # read sliding windows, and match them with the groundtruths to make training samples + sliding_clips_tmp = os.listdir(self.sliding_clip_path) #161396 + self.clip_sentence_pairs_iou = [] + if self.mode == 'valid': + # TALL model doesn't take validation during training, it will test after all the training epochs finish. + return + if self.mode == 'train': + for clip_name in sliding_clips_tmp: + if clip_name.split(".")[2] == "npy": + movie_name = clip_name.split("_")[0] + for clip_sentence in self.clip_sentence_pairs: + original_clip_name = clip_sentence[0] + original_movie_name = original_clip_name.split("_")[0] + if original_movie_name == movie_name: + start = int(clip_name.split("_")[1]) + end = int(clip_name.split("_")[2].split(".")[0]) + o_start = int(original_clip_name.split("_")[1]) + o_end = int( + original_clip_name.split("_")[2].split(".")[0]) + iou = calculate_IoU((start, end), (o_start, o_end)) + if iou > 0.5: + nIoL = calculate_nIoL((o_start, o_end), + (start, end)) + if nIoL < 0.15: + movie_length = movie_length_info[ + movie_name.split(".")[0]] + start_offset = o_start - start + end_offset = o_end - end + self.clip_sentence_pairs_iou.append( + (clip_sentence[0], clip_sentence[1], + clip_name, start_offset, end_offset)) + logger.info('TRAIN:' + str(len(self.clip_sentence_pairs_iou)) + + " iou clip-sentence pairs are readed") + + elif (self.mode == 'test') or (self.mode == 'infer'): + for clip_name in sliding_clips_tmp: + if clip_name.split(".")[2] == "npy": + movie_name = clip_name.split("_")[0] + if movie_name in movie_clip_names: + self.clip_sentence_pairs_iou.append( + clip_name.split(".")[0] + "." 
+ clip_name.split(".") + [1]) + + logger.info('TEST:' + str(len(self.clip_sentence_pairs_iou)) + + " iou clip-sentence pairs are readed") + + def load_movie_slidingclip(self, clip_sentence_pairs, + clip_sentence_pairs_iou, movie_name): + # load unit level feats and sentence vector + movie_clip_sentences = [] + movie_clip_featmap = [] + for k in range(len(clip_sentence_pairs)): + if movie_name in clip_sentence_pairs[k][0]: + movie_clip_sentences.append( + (clip_sentence_pairs[k][0], + clip_sentence_pairs[k][1][:self.semantic_size])) + for k in range(len(clip_sentence_pairs_iou)): + if movie_name in clip_sentence_pairs_iou[k]: + visual_feature_path = os.path.join( + self.sliding_clip_path, clip_sentence_pairs_iou[k] + ".npy") + left_context_feat, right_context_feat = self.get_context_window( + clip_sentence_pairs_iou[k] + ".npy") + feature_data = np.load(visual_feature_path) + comb_feat = np.hstack( + (left_context_feat, feature_data, right_context_feat)) + movie_clip_featmap.append( + (clip_sentence_pairs_iou[k], comb_feat)) + return movie_clip_featmap, movie_clip_sentences + + def create_reader(self): + """reader creator for ets model""" + if self.mode == 'infer': + return self.make_infer_reader() + else: + return self.make_reader() + + def make_infer_reader(self): + """reader for inference""" + + def reader(): + batch_out = [] + idx = 0 + for movie_name in self.movie_names: + idx += 1 + movie_clip_featmaps, movie_clip_sentences = self.load_movie_slidingclip( + self.clip_sentence_pairs, self.clip_sentence_pairs_iou, + movie_name) + for k in range(len(movie_clip_sentences)): + sent_vec = movie_clip_sentences[k][1] + sent_vec = np.reshape(sent_vec, [1, sent_vec.shape[0]]) + for t in range(len(movie_clip_featmaps)): + featmap = movie_clip_featmaps[t][1] + visual_clip_name = movie_clip_featmaps[t][0] + start = float(visual_clip_name.split("_")[1]) + end = float( + visual_clip_name.split("_")[2].split("_")[0]) + featmap = np.reshape(featmap, [1, featmap.shape[0]]) + + batch_out.append((featmap, sent_vec, start, end, k, t, + movie_clip_sentences, + movie_clip_featmaps, movie_name)) + if len(batch_out) == self.batch_size: + yield batch_out + batch_out = [] + + return reader + + def make_reader(self): + def reader(): + batch_out = [] + if self.mode == 'valid': + return + elif self.mode == 'train': + random.shuffle(self.clip_sentence_pairs_iou) + for clip_sentence_pair in self.clip_sentence_pairs_iou: + offset = np.zeros(2, dtype=np.float32) + clip_name = clip_sentence_pair[0] + feat_path = os.path.join(self.sliding_clip_path, + clip_sentence_pair[2]) + featmap = np.load(feat_path) + left_context_feat, right_context_feat = self.get_context_window( + clip_sentence_pair[2]) + image = np.hstack( + (left_context_feat, featmap, right_context_feat)) + sentence = clip_sentence_pair[1][:self.sent_vec_dim] + p_offset = clip_sentence_pair[3] + l_offset = clip_sentence_pair[4] + offset[0] = p_offset + offset[1] = l_offset + batch_out.append((image, sentence, offset)) + if len(batch_out) == self.batch_size: + yield batch_out + batch_out = [] + + elif self.mode == 'test': + for movie_name in self.movie_names: + movie_clip_featmaps, movie_clip_sentences = self.load_movie_slidingclip( + self.clip_sentence_pairs, self.clip_sentence_pairs_iou, + movie_name) + for k in range(len(movie_clip_sentences)): + sent_vec = movie_clip_sentences[k][1] + sent_vec = np.reshape(sent_vec, [1, sent_vec.shape[0]]) + for t in range(len(movie_clip_featmaps)): + featmap = movie_clip_featmaps[t][1] + visual_clip_name = 
movie_clip_featmaps[t][0] + start = float(visual_clip_name.split("_")[1]) + end = float( + visual_clip_name.split("_")[2].split("_")[0]) + featmap = np.reshape(featmap, [1, featmap.shape[0]]) + + batch_out.append((featmap, sent_vec, start, end, k, + t, movie_clip_sentences, + movie_clip_featmaps, movie_name)) + if len(batch_out) == self.batch_size: + yield batch_out + batch_out = [] + else: + raise NotImplementedError('mode {} not implemented'.format( + self.mode)) + + return reader diff --git a/PaddleCV/PaddleVideo/utils/utility.py b/PaddleCV/PaddleVideo/utils/utility.py index ea3f580d..ced1e7d7 100644 --- a/PaddleCV/PaddleVideo/utils/utility.py +++ b/PaddleCV/PaddleVideo/utils/utility.py @@ -13,6 +13,7 @@ #limitations under the License. import os +import sys import signal import logging import paddle @@ -22,6 +23,7 @@ __all__ = ['AttrDict'] logger = logging.getLogger(__name__) + def _term(sig_num, addition): print('current pid is %s, group id is %s' % (os.getpid(), os.getpgrp())) os.killpg(os.getpgid(os.getpid()), signal.SIGKILL) @@ -52,17 +54,18 @@ def check_cuda(use_cuda, err = \ except Exception as e: pass + def check_version(): - """ + """ Log error and exit when the installed version of paddlepaddle is not satisfied. """ - err = "PaddlePaddle version 1.6 or higher is required, " \ - "or a suitable develop version is satisfied as well. \n" \ - "Please make sure the version is good with your code." \ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ - try: - fluid.require_version('1.6.0') - except Exception as e: - logger.error(err) - sys.exit(1) + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) -- GitLab
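
Notes on the main building blocks introduced by this patch, restated as small standalone NumPy sketches for readers who want to sanity-check the pieces outside Paddle. All helper names and toy numbers below are illustrative only and are not part of the patch.

The reader (tall_reader.py) builds each visual input by loading the C3D feature of one sliding-window clip and stacking it with the averaged features of its left and right neighbours (context_num clips at stride context_size on each side), which is why visual_feature_dim in tall.yaml is 3 x feats_dimen = 12288. A minimal sketch of that assembly, with a hypothetical build_clip_feature helper and toy arrays standing in for the .npy files:

import numpy as np

def build_clip_feature(center_feat, left_neighbor_feats, right_neighbor_feats):
    # Mirror of TALLReader.get_context_window followed by np.hstack:
    # average the neighbouring clip features on each side, then
    # concatenate [left_context, center, right_context].
    left = np.mean(np.stack(left_neighbor_feats), axis=0)
    right = np.mean(np.stack(right_neighbor_feats), axis=0)
    return np.hstack([left, center_feat, right])

feats_dim = 4096                                        # feats_dimen in tall.yaml
center = np.zeros(feats_dim, dtype=np.float32)
left_ctx = [np.ones(feats_dim, dtype=np.float32)]       # context_num = 1
right_ctx = [np.full(feats_dim, 2.0, dtype=np.float32)]
print(build_clip_feature(center, left_ctx, right_ctx).shape)   # (12288,)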
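
In tall_net.py, cross_modal_comb fuses every sentence embedding with every clip feature in the batch: the element-wise product, element-wise sum and raw concatenation are stacked into a (batch, batch, 4 x semantic_size) tensor, which the two 1x1 convolutions then turn into a score/offset map. A rough NumPy restatement (this cross_modal_comb is a sketch of the computation, not the fluid implementation):

import numpy as np

def cross_modal_comb(visual, sentence):
    # Pairwise fusion: entry (i, j) combines sentence i with clip j via
    # element-wise product, element-wise sum and raw concatenation,
    # giving a (B, B, 4*D) tensor.
    B, D = visual.shape
    vv = np.broadcast_to(visual[None, :, :], (B, B, D))    # vv[i, j] = clip j
    ss = np.broadcast_to(sentence[:, None, :], (B, B, D))  # ss[i, j] = sentence i
    return np.concatenate([vv * ss, vv + ss,
                           np.concatenate([vv, ss], axis=2)], axis=2)

rng = np.random.default_rng(0)
D = 1024                                                  # semantic_size in tall.yaml
vis = rng.standard_normal((4, D)).astype(np.float32)
sen = rng.standard_normal((4, D)).astype(np.float32)
print(cross_modal_comb(vis, sen).shape)                   # (4, 4, 4096)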
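
The loss in tall_net.py treats the diagonal of the score matrix as the matched sentence/clip pairs and all off-diagonal entries as negatives, then adds an L1 regression term on the predicted start/end offsets of the matched pairs, weighted by lambda_regression = 0.01. The sketch below restates that computation in NumPy under the assumption that the network output is a (3, B, B) array of [score, start_offset, end_offset] maps, as in the fluid code; tall_loss and the toy inputs are illustrative only:

import numpy as np

def tall_loss(net_out, gt_offsets, lambda_reg=0.01):
    # net_out: (3, B, B) score and offset maps; gt_offsets: (B, 2).
    sim, p_off, l_off = net_out
    B = sim.shape[0]
    eye = np.eye(B, dtype=np.float32)
    ones = np.ones((B, B), dtype=np.float32)

    # Alignment term: push diagonal scores up and off-diagonal scores down,
    # weighting the matched (diagonal) pairs more heavily (I + all1 / B).
    mask = ones - 2.0 * eye                  # +1 off-diagonal, -1 on diagonal
    weight = eye + ones / B
    loss_align = np.mean(np.log(1.0 + np.exp(mask * sim)) * weight)

    # Regression term: L1 on the offsets of the matched (diagonal) pairs.
    pred = np.stack([np.diag(p_off), np.diag(l_off)], axis=1)   # (B, 2)
    loss_reg = np.mean(np.abs(pred - gt_offsets))

    return loss_align + lambda_reg * loss_reg

rng = np.random.default_rng(0)
B = 8
out = rng.standard_normal((3, B, B)).astype(np.float32)
offs = rng.standard_normal((B, 2)).astype(np.float32)
print(tall_loss(out, offs))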
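
Evaluation in tall_metrics.py ranks the candidate windows for each sentence, applies temporal non-maximum suppression at threshold iou_thresh - 0.05, and counts a hit when one of the top-n surviving windows has IoU >= iou_thresh with the ground-truth segment (R@1/R@5/R@10). The helpers below are a simplified, vectorised sketch of that temporal IoU / NMS logic, not a line-for-line copy of nms_temporal:

import numpy as np

def temporal_iou(seg_a, seg_b):
    # Intersection-over-union of two 1-D segments (start, end).
    inter = max(0.0, min(seg_a[1], seg_b[1]) - max(seg_a[0], seg_b[0]))
    union = max(seg_a[1], seg_b[1]) - min(seg_a[0], seg_b[0])
    return inter / union if union > 0 else 0.0

def nms_temporal(starts, ends, scores, overlap_thresh):
    # Greedy temporal NMS: keep the highest-scoring segment, drop the
    # remaining segments whose IoU with it exceeds overlap_thresh, repeat.
    starts, ends = np.asarray(starts), np.asarray(ends)
    order = np.argsort(scores)[::-1]          # highest score first
    lengths = ends - starts
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        inter = np.maximum(0.0, np.minimum(ends[i], ends[rest]) -
                                np.maximum(starts[i], starts[rest]))
        iou = inter / (lengths[i] + lengths[rest] - inter)
        order = rest[iou <= overlap_thresh]
    return keep

# Toy example: three candidate windows for one sentence.
starts = np.array([0.0, 8.0, 40.0])
ends = np.array([32.0, 36.0, 72.0])
scores = np.array([0.9, 0.8, 0.7])
print(nms_temporal(starts, ends, scores, overlap_thresh=0.45))   # -> [0, 2]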