From e8bed7f3a08871cc1cc91d10459946b560e46e46 Mon Sep 17 00:00:00 2001
From: liuyuhui
Date: Sat, 8 May 2021 07:15:38 +0000
Subject: [PATCH] fix xpu reader

---
 ppcls/data/reader.py | 46 ++++++++++++++++++++++++++++----------------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py
index 90bff358..cda7077a 100755
--- a/ppcls/data/reader.py
+++ b/ppcls/data/reader.py
@@ -197,7 +197,7 @@ class CommonDataset(Dataset):
 
     def __len__(self):
         return self.num_samples
-    
+

 class MultiLabelDataset(Dataset):
     """
@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
             labels = label_str.split(',')
             labels = [int(i) for i in labels]
 
-            return (transform(img, self.ops), np.array(labels).astype("float32"))
+            return (transform(img, self.ops),
+                    np.array(labels).astype("float32"))
         except Exception as e:
-            logger.error("data read failed: {}, exception info: {}".format(line, e))
+            logger.error("data read failed: {}, exception info: {}".format(
+                line, e))
             return self.__getitem__(random.randint(0, len(self)))
 
     def __len__(self):
@@ -263,6 +265,7 @@ class Reader:
             self.collate_fn = self.mix_collate_fn
 
         self.places = places
+        self.use_xpu = config.get("use_xpu", False)
         self.multilabel = config.get("multilabel", False)
 
     def mix_collate_fn(self, batch):
@@ -285,20 +288,29 @@ class Reader:
             dataset = MultiLabelDataset(self.params)
         else:
             dataset = CommonDataset(self.params)
-
-        is_train = self.params['mode'] == "train"
-        batch_sampler = DistributedBatchSampler(
-            dataset,
-            batch_size=batch_size,
-            shuffle=self.shuffle and is_train,
-            drop_last=is_train)
-        loader = DataLoader(
-            dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=self.collate_fn if is_train else None,
-            places=self.places,
-            return_list=True,
-            num_workers=self.params["num_workers"])
+        if (self.params['mode'] != "train") and self.use_xpu:
+            loader = DataLoader(
+                dataset,
+                places=self.places,
+                batch_size=batch_size,
+                drop_last=False,
+                return_list=True,
+                shuffle=False,
+                num_workers=self.params["num_workers"])
+        else:
+            is_train = self.params['mode'] == "train"
+            batch_sampler = DistributedBatchSampler(
+                dataset,
+                batch_size=batch_size,
+                shuffle=self.shuffle and is_train,
+                drop_last=is_train)
+            loader = DataLoader(
+                dataset,
+                batch_sampler=batch_sampler,
+                collate_fn=self.collate_fn if is_train else None,
+                places=self.places,
+                return_list=True,
+                num_workers=self.params["num_workers"])
 
         return loader
--
GitLab
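
Note (not part of the patch): the diff above adds a `use_xpu` flag to the Reader config and, for non-train mode on XPU, builds a plain DataLoader instead of wrapping the dataset in a DistributedBatchSampler. The following is a minimal, self-contained sketch of those two construction paths using paddle.io; `ToyDataset` and `build_loader` are hypothetical stand-ins for illustration only, and paddle>=2.0 is assumed.

    import numpy as np
    from paddle.io import Dataset, DataLoader, DistributedBatchSampler


    class ToyDataset(Dataset):
        # Hypothetical stand-in for CommonDataset / MultiLabelDataset.
        def __init__(self, n=16):
            self.n = n

        def __getitem__(self, idx):
            img = np.random.rand(3, 8, 8).astype("float32")
            label = np.int64(idx % 2)
            return img, label

        def __len__(self):
            return self.n


    def build_loader(dataset, mode="train", use_xpu=False, batch_size=4):
        if mode != "train" and use_xpu:
            # XPU eval path introduced by the patch: plain DataLoader,
            # no distributed sampler, no shuffling, keep all samples.
            return DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                return_list=True,
                num_workers=0)
        # Default path, as before the patch: DistributedBatchSampler
        # drives batching; shuffle/drop_last only in train mode.
        is_train = mode == "train"
        sampler = DistributedBatchSampler(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            drop_last=is_train)
        return DataLoader(
            dataset,
            batch_sampler=sampler,
            return_list=True,
            num_workers=0)


    if __name__ == "__main__":
        # Exercise the plain-DataLoader path the patch selects when
        # mode != "train" and use_xpu is enabled.
        loader = build_loader(ToyDataset(), mode="valid", use_xpu=True)
        for imgs, labels in loader:
            print(imgs.shape, labels.shape)
            break

Running the script iterates one batch through the non-distributed path; presumably the point of the change is that XPU evaluation does not go through the distributed sampler, though the patch itself states only "fix xpu reader".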