diff --git a/examples/tsm/kinetics_dataset.py b/examples/tsm/kinetics_dataset.py index c8570018cfbcf808917f28806ab841da874782d3..123d89814a8c631569cd0503750cafac631cca22 100644 --- a/examples/tsm/kinetics_dataset.py +++ b/examples/tsm/kinetics_dataset.py @@ -100,19 +100,12 @@ class KineticsDataset(Dataset): def __getitem__(self, idx): pickle_path = os.path.join(self.pickle_dir, self.pickle_paths[idx]) - try: - if six.PY2: - data = pickle.load(open(pickle_path, 'rb')) - else: - data = pickle.load(open(pickle_path, 'rb'), encoding='bytes') - - vid, label, frames = data - if len(frames) < 1: - logger.error("{} contains no frame".format(pickle_path)) - sys.exit(-1) - except Exception as e: - logger.error("Load {} failed: {}".format(pickle_path, e)) - sys.exit(-1) + if six.PY2: + data = pickle.load(open(pickle_path, 'rb')) + else: + data = pickle.load(open(pickle_path, 'rb'), encoding='bytes') + + vid, label, frames = data if self.label_list is not None: label = self.label_list.index(label) diff --git a/hapi/datasets/coco.py b/hapi/datasets/coco.py index f1ab97281a6e0e20834c33f1e6663903f25349a0..50d31cff06692e30fb153983023d4c8ed7476f2c 100644 --- a/hapi/datasets/coco.py +++ b/hapi/datasets/coco.py @@ -18,7 +18,6 @@ from __future__ import print_function import os import cv2 import numpy as np -from pycocotools.coco import COCO from paddle.io import Dataset @@ -91,6 +90,7 @@ class COCODataset(Dataset): self._load_roidb_and_cname2cid() def _load_roidb_and_cname2cid(self): + from pycocotools.coco import COCO assert self._anno_path.endswith('.json'), \ 'invalid coco annotation file: ' + anno_path coco = COCO(self._anno_path) diff --git a/hapi/model.py b/hapi/model.py index f4e6744df5107d345c873f6fa45269f704615708..ed891f58a95f399d02475bfec16c53d9c82e8422 100644 --- a/hapi/model.py +++ b/hapi/model.py @@ -798,12 +798,12 @@ class Model(fluid.dygraph.Layer): "{} receives a shape {}, but the expected shape is {}.". format(key, list(state.shape), list(param.shape))) return param, state - - def _strip_postfix(path): - path, ext = os.path.splitext(path) - assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ - "Unknown postfix {} from weights".format(ext) - return path + + def _strip_postfix(path): + path, ext = os.path.splitext(path) + assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ + "Unknown postfix {} from weights".format(ext) + return path path = _strip_postfix(path) param_state = _load_state_from_path(path + ".pdparams")