未验证 提交 463980f3 编写于 作者: L littletomatodonkey 提交者: GitHub

add support for multi cards eval (#413)

* add support for multi cards eva
上级 6be63583
......@@ -242,29 +242,19 @@ class Reader:
dataset = CommonDataset(self.params)
if self.params['mode'] == "train":
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size,
shuffle=self.shuffle,
drop_last=True)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
places=self.places,
return_list=True,
num_workers=self.params["num_workers"])
else:
loader = DataLoader(
dataset,
places=self.places,
batch_size=batch_size,
drop_last=False,
return_list=True,
shuffle=False,
num_workers=self.params["num_workers"])
is_train = self.params['mode'] == "train"
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size,
shuffle=self.shuffle and is_train,
drop_last=is_train)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn if is_train else None,
places=self.places,
return_list=True,
num_workers=self.params["num_workers"])
return loader
......
......@@ -253,7 +253,7 @@ class ResNet_vd(nn.Layer):
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
......
......@@ -143,6 +143,8 @@ def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
"""
save model to the target path
"""
if paddle.distributed.get_rank() != 0:
return
model_path = os.path.join(model_path, str(epoch_id))
_mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix)
......
......@@ -108,7 +108,8 @@ def create_metric(out,
architecture,
topk=5,
classes_num=1000,
use_distillation=False):
use_distillation=False,
mode="train"):
"""
Create measures of model accuracy, such as top1 and top5
......@@ -117,6 +118,8 @@ def create_metric(out,
feeds(dict): dict of model input variables(included label)
topk(int): usually top5
classes_num(int): num of classes
use_distillation(bool): whether to use distillation training
mode(str): mode, train/valid
Returns:
fetchs(dict): dict of measures
......@@ -133,10 +136,20 @@ def create_metric(out,
fetchs = OrderedDict()
# set top1 to fetchs
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
fetchs['top1'] = top1
# set topk to fetchs
k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
top1 = paddle.distributed.all_reduce(
top1, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
topk = paddle.distributed.all_reduce(
topk, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs['top1'] = top1
topk_name = 'top{}'.format(k)
fetchs[topk_name] = topk
......@@ -175,8 +188,14 @@ def create_fetchs(feeds, net, config, mode="train"):
fetchs['loss'] = create_loss(feeds, out, architecture, classes_num,
epsilon, use_mix, use_distillation)
if not use_mix:
metric = create_metric(out, feeds["label"], architecture, topk,
classes_num, use_distillation)
metric = create_metric(
out,
feeds["label"],
architecture,
topk,
classes_num,
use_distillation,
mode=mode)
fetchs.update(metric)
return fetchs
......
......@@ -77,7 +77,7 @@ def main(args):
train_dataloader = Reader(config, 'train', places=place)()
if config.validate and paddle.distributed.get_rank() == 0:
if config.validate:
valid_dataloader = Reader(config, 'valid', places=place)()
last_epoch_id = config.get("last_epoch", -1)
......@@ -89,28 +89,27 @@ def main(args):
program.run(train_dataloader, config, net, optimizer, lr_scheduler,
epoch_id, 'train')
if paddle.distributed.get_rank() == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
top1_acc = program.run(valid_dataloader, config, net, None,
None, epoch_id, 'valid')
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, "best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info("{:s}".format(logger.coloring(message, "RED")))
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, epoch_id)
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
net.eval()
top1_acc = program.run(valid_dataloader, config, net, None, None,
epoch_id, 'valid')
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
best_top1_epoch = epoch_id
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, "best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info("{:s}".format(logger.coloring(message, "RED")))
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, epoch_id)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册