diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py
index 90bff3589d88d036da717946de3fe6d6821edb37..cda7077af68f26cb470581bb3e51caf61c4213c8 100755
--- a/ppcls/data/reader.py
+++ b/ppcls/data/reader.py
@@ -197,7 +197,7 @@ class CommonDataset(Dataset):
 
     def __len__(self):
         return self.num_samples
-    
+
 
 class MultiLabelDataset(Dataset):
     """
@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
 
             labels = label_str.split(',')
             labels = [int(i) for i in labels]
-            return (transform(img, self.ops), np.array(labels).astype("float32"))
+            return (transform(img, self.ops),
+                    np.array(labels).astype("float32"))
         except Exception as e:
-            logger.error("data read failed: {}, exception info: {}".format(line, e))
+            logger.error("data read failed: {}, exception info: {}".format(
+                line, e))
             return self.__getitem__(random.randint(0, len(self)))
 
     def __len__(self):
@@ -263,6 +265,7 @@ class Reader:
             self.collate_fn = self.mix_collate_fn
 
         self.places = places
+        self.use_xpu = config.get("use_xpu", False)
         self.multilabel = config.get("multilabel", False)
 
     def mix_collate_fn(self, batch):
@@ -285,20 +288,29 @@ class Reader:
             dataset = MultiLabelDataset(self.params)
         else:
             dataset = CommonDataset(self.params)
-
-        is_train = self.params['mode'] == "train"
-        batch_sampler = DistributedBatchSampler(
-            dataset,
-            batch_size=batch_size,
-            shuffle=self.shuffle and is_train,
-            drop_last=is_train)
-        loader = DataLoader(
-            dataset,
-            batch_sampler=batch_sampler,
-            collate_fn=self.collate_fn if is_train else None,
-            places=self.places,
-            return_list=True,
-            num_workers=self.params["num_workers"])
+        if (self.params['mode'] != "train") and self.use_xpu:
+            loader = DataLoader(
+                dataset,
+                places=self.places,
+                batch_size=batch_size,
+                drop_last=False,
+                return_list=True,
+                shuffle=False,
+                num_workers=self.params["num_workers"])
+        else:
+            is_train = self.params['mode'] == "train"
+            batch_sampler = DistributedBatchSampler(
+                dataset,
+                batch_size=batch_size,
+                shuffle=self.shuffle and is_train,
+                drop_last=is_train)
+            loader = DataLoader(
+                dataset,
+                batch_sampler=batch_sampler,
+                collate_fn=self.collate_fn if is_train else None,
+                places=self.places,
+                return_list=True,
+                num_workers=self.params["num_workers"])
         return loader
 
 
diff --git a/tools/program.py b/tools/program.py
index fd155ab8f7334906ebb0aa60f6c2b7f350b06c1b..4666d96242adc4769eec28bf22a863f6a1512196 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -119,7 +119,8 @@ def create_metric(out,
                   classes_num=1000,
                   use_distillation=False,
                   multilabel=False,
-                  mode="train"):
+                  mode="train",
+                  use_xpu=False):
     """
     Create measures of model accuracy, such as top1 and top5
 
@@ -175,11 +176,12 @@ def create_metric(out,
             fetch_list.append(ham_dist)
 
     # multi cards' eval
-    if mode != "train" and paddle.distributed.get_world_size() > 1:
-        for idx, fetch in enumerate(fetch_list):
-            fetch_list[idx] = paddle.distributed.all_reduce(
-                fetch, op=paddle.distributed.ReduceOp.
-                SUM) / paddle.distributed.get_world_size()
+    if not use_xpu:
+        if mode != "train" and paddle.distributed.get_world_size() > 1:
+            for idx, fetch in enumerate(fetch_list):
+                fetch_list[idx] = paddle.distributed.all_reduce(
+                    fetch, op=paddle.distributed.ReduceOp.
+                    SUM) / paddle.distributed.get_world_size()
 
     fetchs = OrderedDict()
     for idx, name in enumerate(metric_names):
@@ -213,6 +215,7 @@ def create_fetchs(feeds, net, config, mode="train"):
     use_mix = config.get('use_mix') and mode == 'train'
     use_distillation = config.get('use_distillation')
     multilabel = config.get('multilabel', False)
+    use_xpu = config.get("use_xpu", False)
 
     out = net(feeds["image"])
 
@@ -229,7 +232,8 @@ def create_fetchs(feeds, net, config, mode="train"):
             classes_num,
             use_distillation,
             multilabel=multilabel,
-            mode=mode)
+            mode=mode,
+            use_xpu=use_xpu)
         fetchs.update(metric)
 
     return fetchs