未验证 提交 2bd52bf5 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #719 from vslyu/2.1/fix_xpu_eval

[Kunlun]2.1cherry-pick: xpu use one cards for evaluation in multi cards training
...@@ -197,7 +197,7 @@ class CommonDataset(Dataset): ...@@ -197,7 +197,7 @@ class CommonDataset(Dataset):
def __len__(self): def __len__(self):
return self.num_samples return self.num_samples
class MultiLabelDataset(Dataset): class MultiLabelDataset(Dataset):
""" """
...@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset): ...@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
labels = label_str.split(',') labels = label_str.split(',')
labels = [int(i) for i in labels] labels = [int(i) for i in labels]
return (transform(img, self.ops), np.array(labels).astype("float32")) return (transform(img, self.ops),
np.array(labels).astype("float32"))
except Exception as e: except Exception as e:
logger.error("data read failed: {}, exception info: {}".format(line, e)) logger.error("data read failed: {}, exception info: {}".format(
line, e))
return self.__getitem__(random.randint(0, len(self))) return self.__getitem__(random.randint(0, len(self)))
def __len__(self): def __len__(self):
...@@ -263,6 +265,7 @@ class Reader: ...@@ -263,6 +265,7 @@ class Reader:
self.collate_fn = self.mix_collate_fn self.collate_fn = self.mix_collate_fn
self.places = places self.places = places
self.use_xpu = config.get("use_xpu", False)
self.multilabel = config.get("multilabel", False) self.multilabel = config.get("multilabel", False)
def mix_collate_fn(self, batch): def mix_collate_fn(self, batch):
...@@ -285,20 +288,29 @@ class Reader: ...@@ -285,20 +288,29 @@ class Reader:
dataset = MultiLabelDataset(self.params) dataset = MultiLabelDataset(self.params)
else: else:
dataset = CommonDataset(self.params) dataset = CommonDataset(self.params)
if (self.params['mode'] != "train") and self.use_xpu:
is_train = self.params['mode'] == "train" loader = DataLoader(
batch_sampler = DistributedBatchSampler( dataset,
dataset, places=self.places,
batch_size=batch_size, batch_size=batch_size,
shuffle=self.shuffle and is_train, drop_last=False,
drop_last=is_train) return_list=True,
loader = DataLoader( shuffle=False,
dataset, num_workers=self.params["num_workers"])
batch_sampler=batch_sampler, else:
collate_fn=self.collate_fn if is_train else None, is_train = self.params['mode'] == "train"
places=self.places, batch_sampler = DistributedBatchSampler(
return_list=True, dataset,
num_workers=self.params["num_workers"]) batch_size=batch_size,
shuffle=self.shuffle and is_train,
drop_last=is_train)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn if is_train else None,
places=self.places,
return_list=True,
num_workers=self.params["num_workers"])
return loader return loader
......
...@@ -119,7 +119,8 @@ def create_metric(out, ...@@ -119,7 +119,8 @@ def create_metric(out,
classes_num=1000, classes_num=1000,
use_distillation=False, use_distillation=False,
multilabel=False, multilabel=False,
mode="train"): mode="train",
use_xpu=False):
""" """
Create measures of model accuracy, such as top1 and top5 Create measures of model accuracy, such as top1 and top5
...@@ -175,11 +176,12 @@ def create_metric(out, ...@@ -175,11 +176,12 @@ def create_metric(out,
fetch_list.append(ham_dist) fetch_list.append(ham_dist)
# multi cards' eval # multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1: if not use_xpu:
for idx, fetch in enumerate(fetch_list): if mode != "train" and paddle.distributed.get_world_size() > 1:
fetch_list[idx] = paddle.distributed.all_reduce( for idx, fetch in enumerate(fetch_list):
fetch, op=paddle.distributed.ReduceOp. fetch_list[idx] = paddle.distributed.all_reduce(
SUM) / paddle.distributed.get_world_size() fetch, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs = OrderedDict() fetchs = OrderedDict()
for idx, name in enumerate(metric_names): for idx, name in enumerate(metric_names):
...@@ -213,6 +215,7 @@ def create_fetchs(feeds, net, config, mode="train"): ...@@ -213,6 +215,7 @@ def create_fetchs(feeds, net, config, mode="train"):
use_mix = config.get('use_mix') and mode == 'train' use_mix = config.get('use_mix') and mode == 'train'
use_distillation = config.get('use_distillation') use_distillation = config.get('use_distillation')
multilabel = config.get('multilabel', False) multilabel = config.get('multilabel', False)
use_xpu = config.get("use_xpu", False)
out = net(feeds["image"]) out = net(feeds["image"])
...@@ -229,7 +232,8 @@ def create_fetchs(feeds, net, config, mode="train"): ...@@ -229,7 +232,8 @@ def create_fetchs(feeds, net, config, mode="train"):
classes_num, classes_num,
use_distillation, use_distillation,
multilabel=multilabel, multilabel=multilabel,
mode=mode) mode=mode,
use_xpu=use_xpu)
fetchs.update(metric) fetchs.update(metric)
return fetchs return fetchs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册