未验证 提交 463980f3 编写于 作者: littletomatodonkey 提交者: GitHub

add support for multi-card eval (#413)

* add support for multi-card eval
上级 6be63583
...@@ -242,29 +242,19 @@ class Reader: ...@@ -242,29 +242,19 @@ class Reader:
dataset = CommonDataset(self.params) dataset = CommonDataset(self.params)
if self.params['mode'] == "train": is_train = self.params['mode'] == "train"
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=self.shuffle, shuffle=self.shuffle and is_train,
drop_last=True) drop_last=is_train)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
collate_fn=self.collate_fn, collate_fn=self.collate_fn if is_train else None,
places=self.places, places=self.places,
return_list=True, return_list=True,
num_workers=self.params["num_workers"]) num_workers=self.params["num_workers"])
else:
loader = DataLoader(
dataset,
places=self.places,
batch_size=batch_size,
drop_last=False,
return_list=True,
shuffle=False,
num_workers=self.params["num_workers"])
return loader return loader
......
...@@ -253,7 +253,7 @@ class ResNet_vd(nn.Layer): ...@@ -253,7 +253,7 @@ class ResNet_vd(nn.Layer):
for block in range(len(depth)): for block in range(len(depth)):
shortcut = False shortcut = False
for i in range(depth[block]): for i in range(depth[block]):
if layers in [101, 152] and block == 2: if layers in [101, 152, 200] and block == 2:
if i == 0: if i == 0:
conv_name = "res" + str(block + 2) + "a" conv_name = "res" + str(block + 2) + "a"
else: else:
......
...@@ -143,6 +143,8 @@ def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'): ...@@ -143,6 +143,8 @@ def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
""" """
save model to the target path save model to the target path
""" """
if paddle.distributed.get_rank() != 0:
return
model_path = os.path.join(model_path, str(epoch_id)) model_path = os.path.join(model_path, str(epoch_id))
_mkdir_if_not_exist(model_path) _mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
......
...@@ -108,7 +108,8 @@ def create_metric(out, ...@@ -108,7 +108,8 @@ def create_metric(out,
architecture, architecture,
topk=5, topk=5,
classes_num=1000, classes_num=1000,
use_distillation=False): use_distillation=False,
mode="train"):
""" """
Create measures of model accuracy, such as top1 and top5 Create measures of model accuracy, such as top1 and top5
...@@ -117,6 +118,8 @@ def create_metric(out, ...@@ -117,6 +118,8 @@ def create_metric(out,
feeds(dict): dict of model input variables(included label) feeds(dict): dict of model input variables(included label)
topk(int): usually top5 topk(int): usually top5
classes_num(int): num of classes classes_num(int): num of classes
use_distillation(bool): whether to use distillation training
mode(str): mode, train/valid
Returns: Returns:
fetchs(dict): dict of measures fetchs(dict): dict of measures
...@@ -133,10 +136,20 @@ def create_metric(out, ...@@ -133,10 +136,20 @@ def create_metric(out,
fetchs = OrderedDict() fetchs = OrderedDict()
# set top1 to fetchs # set top1 to fetchs
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1) top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
fetchs['top1'] = top1
# set topk to fetchs # set topk to fetchs
k = min(topk, classes_num) k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k) topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
top1 = paddle.distributed.all_reduce(
top1, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
topk = paddle.distributed.all_reduce(
topk, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs['top1'] = top1
topk_name = 'top{}'.format(k) topk_name = 'top{}'.format(k)
fetchs[topk_name] = topk fetchs[topk_name] = topk
...@@ -175,8 +188,14 @@ def create_fetchs(feeds, net, config, mode="train"): ...@@ -175,8 +188,14 @@ def create_fetchs(feeds, net, config, mode="train"):
fetchs['loss'] = create_loss(feeds, out, architecture, classes_num, fetchs['loss'] = create_loss(feeds, out, architecture, classes_num,
epsilon, use_mix, use_distillation) epsilon, use_mix, use_distillation)
if not use_mix: if not use_mix:
metric = create_metric(out, feeds["label"], architecture, topk, metric = create_metric(
classes_num, use_distillation) out,
feeds["label"],
architecture,
topk,
classes_num,
use_distillation,
mode=mode)
fetchs.update(metric) fetchs.update(metric)
return fetchs return fetchs
......
...@@ -77,7 +77,7 @@ def main(args): ...@@ -77,7 +77,7 @@ def main(args):
train_dataloader = Reader(config, 'train', places=place)() train_dataloader = Reader(config, 'train', places=place)()
if config.validate and paddle.distributed.get_rank() == 0: if config.validate:
valid_dataloader = Reader(config, 'valid', places=place)() valid_dataloader = Reader(config, 'valid', places=place)()
last_epoch_id = config.get("last_epoch", -1) last_epoch_id = config.get("last_epoch", -1)
...@@ -89,28 +89,27 @@ def main(args): ...@@ -89,28 +89,27 @@ def main(args):
program.run(train_dataloader, config, net, optimizer, lr_scheduler, program.run(train_dataloader, config, net, optimizer, lr_scheduler,
epoch_id, 'train') epoch_id, 'train')
if paddle.distributed.get_rank() == 0: # 2. validate with validate dataset
# 2. validate with validate dataset if config.validate and epoch_id % config.valid_interval == 0:
if config.validate and epoch_id % config.valid_interval == 0: net.eval()
net.eval() top1_acc = program.run(valid_dataloader, config, net, None, None,
top1_acc = program.run(valid_dataloader, config, net, None, epoch_id, 'valid')
None, epoch_id, 'valid') if top1_acc > best_top1_acc:
if top1_acc > best_top1_acc: best_top1_acc = top1_acc
best_top1_acc = top1_acc best_top1_epoch = epoch_id
best_top1_epoch = epoch_id if epoch_id % config.save_interval == 0:
if epoch_id % config.save_interval == 0: model_path = os.path.join(config.model_save_dir,
model_path = os.path.join(config.model_save_dir, config.ARCHITECTURE["name"])
config.ARCHITECTURE["name"]) save_model(net, optimizer, model_path, "best_model")
save_model(net, optimizer, model_path, "best_model") message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
message = "The best top1 acc {:.5f}, in epoch: {:d}".format( best_top1_acc, best_top1_epoch)
best_top1_acc, best_top1_epoch) logger.info("{:s}".format(logger.coloring(message, "RED")))
logger.info("{:s}".format(logger.coloring(message, "RED")))
# 3. save the persistable model
# 3. save the persistable model if epoch_id % config.save_interval == 0:
if epoch_id % config.save_interval == 0: model_path = os.path.join(config.model_save_dir,
model_path = os.path.join(config.model_save_dir, config.ARCHITECTURE["name"])
config.ARCHITECTURE["name"]) save_model(net, optimizer, model_path, epoch_id)
save_model(net, optimizer, model_path, epoch_id)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册