Commit e46c41b4 · Author: huangjun12 · Committer: SunGaofeng

fix bug of tall when using python3 (#3791)

* fix bug of tall when using python3

* add link of weight and specify the dependency in README

* delete preprocess.sh in dataset/ets/README
Parent 9d31d6e3
......@@ -37,6 +37,10 @@
- CUDNN >= 7.0
- pandas
- h5py
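- When using the Youtube-8M dataset, the tfrecord data must be converted to pickle format, which requires TensorFlow. See the Youtube-8M section of the [data description](./data/dataset/README.md). This applies to the Attention Cluster, Attention LSTM, and NeXtVLAD models; ignore this item if you use other models.
- When using the Kinetics dataset, ffmpeg is required if you need to decode the mp4 files in advance and save them in pickle format. See the Kinetics section of the [data description](./data/dataset/README.md). Note that the Nonlocal model also uses the Kinetics dataset, but its input is the source video files, which are not decoded in advance, so this item does not apply to it. This applies to the TSN, TSM, and StNet models; ignore this item if you use other models.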
......
......@@ -28,7 +28,6 @@ ets
|----feat_data/
|----train.list
|----val.list
-|----preprocess.sh
|----generate_train_pickle.py
|----generate_data.py
|----generate_infer_data.py
......
......@@ -46,7 +46,7 @@ class MetricsCalculator():
if len(x1) == 0:
return pick
-union = map(operator.sub, x2, x1) # union = x2-x1
+union = list(map(operator.sub, x2, x1)) # union = x2-x1
I = [i[0] for i in sorted(
enumerate(sim), key=lambda x: x[1])] # sort and get index
......
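Note on the hunk above: in Python 3, `map` returns a lazy iterator rather than a list, so it cannot be indexed. The sketch below illustrates why wrapping the call in `list(...)` matters. The values of `x1` and `x2` are hypothetical, and the assumption that later code indexes `union` is inferred from the fix, not shown in the diff.

```python
import operator

# Hypothetical clip start and end times; the real values come from predictions.
x1 = [0.0, 2.0, 5.0]
x2 = [1.5, 4.0, 9.0]

# Python 3: map() yields a lazy iterator, which does not support indexing.
lazy = map(operator.sub, x2, x1)
try:
    lazy[0]
    indexable = True
except TypeError:
    indexable = False

# Materializing the iterator restores the list behavior that Python 2 provided.
union = list(map(operator.sub, x2, x1))
first_length = union[0]
```

Under Python 2 the unwrapped call already returned a list, which is presumably why the problem only surfaced on Python 3.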
......@@ -189,4 +189,7 @@ class ETS(ModelBase):
return (None, None)
def weights_info(self):
-pass
+return (
+'ETS_final.pdparams',
+'https://paddlemodels.bj.bcebos.com/video_caption/ETS_final.pdparams'
+)
......@@ -162,4 +162,7 @@ class TALL(ModelBase):
return (None, None)
def weights_info(self):
-pass
+return (
+'TALL_final.pdparams',
+'https://paddlemodels.bj.bcebos.com/video_grounding/TALL_final.pdparams'
+)
......@@ -142,7 +142,7 @@ def infer(args):
if args.model_name == 'ETS':
data_feed_in = [items[:3] for items in data]
vinfo = [items[3:] for items in data]
-video_id = [items[6] for items in vinfo]
+video_id = [items[0] for items in vinfo]
infer_outs = exe.run(fetch_list=fetch_list,
feed=infer_feeder.feed(data_feed_in),
return_numpy=False)
......@@ -150,7 +150,7 @@ def infer(args):
elif args.model_name == 'TALL':
data_feed_in = [items[:2] for items in data]
vinfo = [items[2:] for items in data]
-video_id = [items[0] for items in vinfo]
+video_id = [items[6] for items in vinfo]
infer_outs = exe.run(fetch_list=fetch_list,
feed=infer_feeder.feed(data_feed_in),
return_numpy=True)
......
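Note on the two `infer` hunks above: `video_id` is read from `vinfo`, which is itself a slice of each record, so the index is relative to the slice start. The sketch below shows how the offsets map back to the original record. The 9-field record and its field names are hypothetical; the actual reader output is not shown in this diff.

```python
# Hypothetical 9-field record; real readers yield feature arrays plus metadata.
record = tuple("field%d" % i for i in range(9))

# ETS branch: the first 3 fields are fed to the network, the rest is metadata.
ets_vinfo = record[3:]
ets_video_id = ets_vinfo[0]    # same element as record[3]

# TALL branch: the first 2 fields are fed, so metadata starts at index 2.
tall_vinfo = record[2:]
tall_video_id = tall_vinfo[6]  # same element as record[8]
```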
......@@ -139,8 +139,8 @@ class TALLReader(DataReader):
self.clip_sentence_pairs = []
for l in cs:
-clip_name = l[0]
-sent_vecs = l[1]
+clip_name = l[0].decode('utf-8') #byte object to string
+sent_vecs = l[1] #numpy array
for sent_vec in sent_vecs:
self.clip_sentence_pairs.append((clip_name, sent_vec)) #10146
logger.info(self.mode.upper() + ':' + str(
......@@ -185,7 +185,8 @@ class TALLReader(DataReader):
(start, end))
if nIoL < 0.15:
movie_length = movie_length_info[
-movie_name.split(".")[0]]
+movie_name.split(".")[0].encode(
+'utf-8')] #str to byte
start_offset = o_start - start
end_offset = o_end - end
self.clip_sentence_pairs_iou.append(
......
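Note on the `TALLReader` hunks above: Python 3 treats `bytes` and `str` as distinct types, so a `str` key does not match a `bytes` key in a dictionary. The sketch below assumes, as the comments in the diff suggest, that the loaded clip names and the keys of `movie_length_info` are `bytes`. The dictionary contents and names used here are hypothetical.

```python
# Hypothetical lookup table whose keys are bytes, as assumed for the loaded data.
movie_length_info = {b"s13-d21": 100}

movie_name = "s13-d21.avi"
key = movie_name.split(".")[0]

# A str key never matches a bytes key in Python 3.
found_with_str = key in movie_length_info

# Encoding the key to bytes makes the lookup succeed.
movie_length = movie_length_info[key.encode("utf-8")]

# The reverse direction: decode bytes clip names to str before string handling.
clip_name = b"s13-d21.avi_10_20".decode("utf-8")
```

Python 2 made no such distinction for ASCII data, which would explain why the original code ran there without the explicit conversions.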