diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index 90bff3589d88d036da717946de3fe6d6821edb37..cda7077af68f26cb470581bb3e51caf61c4213c8 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -197,7 +197,7 @@ class CommonDataset(Dataset): def __len__(self): return self.num_samples - + class MultiLabelDataset(Dataset): """ @@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset): labels = label_str.split(',') labels = [int(i) for i in labels] - return (transform(img, self.ops), np.array(labels).astype("float32")) + return (transform(img, self.ops), + np.array(labels).astype("float32")) except Exception as e: - logger.error("data read failed: {}, exception info: {}".format(line, e)) + logger.error("data read failed: {}, exception info: {}".format( + line, e)) return self.__getitem__(random.randint(0, len(self))) def __len__(self): @@ -263,6 +265,7 @@ class Reader: self.collate_fn = self.mix_collate_fn self.places = places + self.use_xpu = config.get("use_xpu", False) self.multilabel = config.get("multilabel", False) def mix_collate_fn(self, batch): @@ -285,20 +288,29 @@ class Reader: dataset = MultiLabelDataset(self.params) else: dataset = CommonDataset(self.params) - - is_train = self.params['mode'] == "train" - batch_sampler = DistributedBatchSampler( - dataset, - batch_size=batch_size, - shuffle=self.shuffle and is_train, - drop_last=is_train) - loader = DataLoader( - dataset, - batch_sampler=batch_sampler, - collate_fn=self.collate_fn if is_train else None, - places=self.places, - return_list=True, - num_workers=self.params["num_workers"]) + if (self.params['mode'] != "train") and self.use_xpu: + loader = DataLoader( + dataset, + places=self.places, + batch_size=batch_size, + drop_last=False, + return_list=True, + shuffle=False, + num_workers=self.params["num_workers"]) + else: + is_train = self.params['mode'] == "train" + batch_sampler = DistributedBatchSampler( + dataset, + batch_size=batch_size, + shuffle=self.shuffle and is_train, + drop_last=is_train) + loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=self.collate_fn if is_train else None, + places=self.places, + return_list=True, + num_workers=self.params["num_workers"]) return loader