提交 0872cfa2 编写于 作者: D dengkaipeng

fix tsm hang

上级 dc2a5e54
...@@ -100,19 +100,12 @@ class KineticsDataset(Dataset): ...@@ -100,19 +100,12 @@ class KineticsDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx]) pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx])
try: if six.PY2:
if six.PY2: data = pickle.load(open(pickle_path, 'rb'))
data = pickle.load(open(pickle_path, 'rb')) else:
else: data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
data = pickle.load(open(pickle_path, 'rb'), encoding='bytes')
vid, label, frames = data
vid, label, frames = data
if len(frames) < 1:
logger.error("{} contains no frame".format(pickle_path))
sys.exit(-1)
except Exception as e:
logger.error("Load {} failed: {}".format(pickle_path, e))
sys.exit(-1)
if self.label_list is not None: if self.label_list is not None:
label = self.label_list.index(label) label = self.label_list.index(label)
......
...@@ -18,7 +18,6 @@ from __future__ import print_function ...@@ -18,7 +18,6 @@ from __future__ import print_function
import os import os
import cv2 import cv2
import numpy as np import numpy as np
from pycocotools.coco import COCO
from paddle.io import Dataset from paddle.io import Dataset
...@@ -91,6 +90,7 @@ class COCODataset(Dataset): ...@@ -91,6 +90,7 @@ class COCODataset(Dataset):
self._load_roidb_and_cname2cid() self._load_roidb_and_cname2cid()
def _load_roidb_and_cname2cid(self): def _load_roidb_and_cname2cid(self):
from pycocotools.coco import COCO
assert self._anno_path.endswith('.json'), \ assert self._anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path 'invalid coco annotation file: ' + anno_path
coco = COCO(self._anno_path) coco = COCO(self._anno_path)
......
...@@ -798,12 +798,12 @@ class Model(fluid.dygraph.Layer): ...@@ -798,12 +798,12 @@ class Model(fluid.dygraph.Layer):
"{} receives a shape {}, but the expected shape is {}.". "{} receives a shape {}, but the expected shape is {}.".
format(key, list(state.shape), list(param.shape))) format(key, list(state.shape), list(param.shape)))
return param, state return param, state
def _strip_postfix(path): def _strip_postfix(path):
path, ext = os.path.splitext(path) path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
"Unknown postfix {} from weights".format(ext) "Unknown postfix {} from weights".format(ext)
return path return path
path = _strip_postfix(path) path = _strip_postfix(path)
param_state = _load_state_from_path(path + ".pdparams") param_state = _load_state_from_path(path + ".pdparams")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册