提交 e8bed7f3 编写于 作者: L liuyuhui

fix xpu reader

上级 ca51b6f7
......@@ -197,7 +197,7 @@ class CommonDataset(Dataset):
def __len__(self):
return self.num_samples
class MultiLabelDataset(Dataset):
"""
......@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
labels = label_str.split(',')
labels = [int(i) for i in labels]
return (transform(img, self.ops), np.array(labels).astype("float32"))
return (transform(img, self.ops),
np.array(labels).astype("float32"))
except Exception as e:
logger.error("data read failed: {}, exception info: {}".format(line, e))
logger.error("data read failed: {}, exception info: {}".format(
line, e))
return self.__getitem__(random.randint(0, len(self)))
def __len__(self):
......@@ -263,6 +265,7 @@ class Reader:
self.collate_fn = self.mix_collate_fn
self.places = places
self.use_xpu = config.get("use_xpu", False)
self.multilabel = config.get("multilabel", False)
def mix_collate_fn(self, batch):
......@@ -285,20 +288,29 @@ class Reader:
dataset = MultiLabelDataset(self.params)
else:
dataset = CommonDataset(self.params)
is_train = self.params['mode'] == "train"
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size,
shuffle=self.shuffle and is_train,
drop_last=is_train)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn if is_train else None,
places=self.places,
return_list=True,
num_workers=self.params["num_workers"])
if (self.params['mode'] != "train") and self.use_xpu:
loader = DataLoader(
dataset,
places=self.places,
batch_size=batch_size,
drop_last=False,
return_list=True,
shuffle=False,
num_workers=self.params["num_workers"])
else:
is_train = self.params['mode'] == "train"
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size,
shuffle=self.shuffle and is_train,
drop_last=is_train)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn if is_train else None,
places=self.places,
return_list=True,
num_workers=self.params["num_workers"])
return loader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册