未验证 提交 15b18973 编写于 作者: L littletomatodonkey 提交者: GitHub

fix eval script (#464)

* fix eval script

* fix dali shell
上级 e3801c55
...@@ -179,7 +179,8 @@ class Branches(nn.Layer): ...@@ -179,7 +179,8 @@ class Branches(nn.Layer):
outs = [] outs = []
for idx, input in enumerate(inputs): for idx, input in enumerate(inputs):
conv = input conv = input
for basic_block_func in self.basic_block_list[idx]: basic_block_list = self.basic_block_list[idx]
for basic_block_func in basic_block_list:
conv = basic_block_func(conv) conv = basic_block_func(conv)
outs.append(conv) outs.append(conv)
return outs return outs
......
...@@ -165,6 +165,7 @@ class SplatConv(nn.Layer): ...@@ -165,6 +165,7 @@ class SplatConv(nn.Layer):
atten = self.conv3(gap) atten = self.conv3(gap)
atten = self.rsoftmax(atten) atten = self.rsoftmax(atten)
atten = paddle.reshape(x=atten, shape=[-1, atten.shape[1], 1, 1])
if self.radix > 1: if self.radix > 1:
attens = paddle.split(atten, num_or_sections=self.radix, axis=1) attens = paddle.split(atten, num_or_sections=self.radix, axis=1)
......
...@@ -48,12 +48,12 @@ class AverageMeter(object): ...@@ -48,12 +48,12 @@ class AverageMeter(object):
@property @property
def total_minute(self): def total_minute(self):
return '{self.name}_sum: {s:{self.fmt}}{self.postfix} min'.format( return '{self.name} {s:{self.fmt}}{self.postfix} min'.format(
s=self.sum / 60, self=self) s=self.sum / 60, self=self)
@property @property
def mean(self): def mean(self):
return '{self.name}_avg: {self.avg:{self.fmt}}{self.postfix}'.format( return '{self.name}: {self.avg:{self.fmt}}{self.postfix}'.format(
self=self) if self.need_avg else '' self=self) if self.need_avg else ''
@property @property
......
...@@ -67,7 +67,7 @@ def main(args, return_dict={}): ...@@ -67,7 +67,7 @@ def main(args, return_dict={}):
init_model(config, net, optimizer=None) init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)() valid_dataloader = Reader(config, 'valid', places=place)()
net.eval() net.eval()
with paddle.no_grad():
top1_acc = program.run(valid_dataloader, config, net, None, None, 0, top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
'valid') 'valid')
return_dict["top1_acc"] = top1_acc return_dict["top1_acc"] = top1_acc
......
python -m paddle.distributed.launch \ python3.7 -m paddle.distributed.launch \
--selected_gpus="0" \ --gpus="0,1,2,3" \
tools/eval.py \ tools/eval.py \
-c ./configs/eval.yaml \ -c ./configs/ResNet/ResNet50.yaml \
-o load_static_weights=True \ -o pretrained_model="./ResNet50_pretrained" \
-o use_gpu=False -o use_gpu=True
...@@ -298,6 +298,11 @@ def run(dataloader, ...@@ -298,6 +298,11 @@ def run(dataloader,
tic = time.time() tic = time.time()
for idx, batch in enumerate(dataloader()): for idx, batch in enumerate(dataloader()):
# avoid statistics from warmup time
if idx == 10:
metric_list["batch_time"].reset()
metric_list["reader_time"].reset()
metric_list['reader_time'].update(time.time() - tic) metric_list['reader_time'].update(time.time() - tic)
batch_size = len(batch[0]) batch_size = len(batch[0])
feeds = create_feeds(batch, use_mix) feeds = create_feeds(batch, use_mix)
...@@ -327,11 +332,15 @@ def run(dataloader, ...@@ -327,11 +332,15 @@ def run(dataloader,
metric_list["batch_time"].update(time.time() - tic) metric_list["batch_time"].update(time.time() - tic)
tic = time.time() tic = time.time()
fetchs_str = ' '.join([str(m.value) for m in metric_list.values()]) fetchs_str = ' '.join([
str(metric_list[key].mean)
if "time" in key else str(metric_list[key].value)
for key in metric_list
])
if idx % print_interval == 0: if idx % print_interval == 0:
ips_info = "ips: {:.5f} images/sec.".format( ips_info = "ips: {:.5f} images/sec.".format(
batch_size / metric_list["batch_time"].val) batch_size / metric_list["batch_time"].avg)
if mode == 'eval': if mode == 'eval':
logger.info("{:s} step:{:<4d}, {:s} {:s}".format( logger.info("{:s} step:{:<4d}, {:s} {:s}".format(
mode, idx, fetchs_str, ips_info)) mode, idx, fetchs_str, ips_info))
......
#!/usr/bin/env bash #!/usr/bin/env bash
python -m paddle.distributed.launch \ python3.7 -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \ --gpus="0,1,2,3" \
tools/train.py \ tools/train.py \
-c ./configs/ResNet/ResNet50.yaml \ -c ./configs/ResNet/ResNet50.yaml \
-o print_interval=10 -o print_interval=10
...@@ -4,7 +4,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" ...@@ -4,7 +4,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
export FLAGS_fraction_of_gpu_memory_to_use=0.80 export FLAGS_fraction_of_gpu_memory_to_use=0.80
python3.7 -m paddle.distributed.launch \ python3.7 -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \ --gpus="0,1,2,3" \
tools/static/train.py \ tools/static/train.py \
-c ./configs/ResNet/ResNet50.yaml \ -c ./configs/ResNet/ResNet50.yaml \
-o print_interval=10 \ -o print_interval=10 \
......
...@@ -92,8 +92,9 @@ def main(args): ...@@ -92,8 +92,9 @@ def main(args):
# 2. validate with validate dataset # 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
net.eval() net.eval()
top1_acc = program.run(valid_dataloader, config, net, None, None, with paddle.no_grad():
epoch_id, 'valid') top1_acc = program.run(valid_dataloader, config, net, None,
None, epoch_id, 'valid')
if top1_acc > best_top1_acc: if top1_acc > best_top1_acc:
best_top1_acc = top1_acc best_top1_acc = top1_acc
best_top1_epoch = epoch_id best_top1_epoch = epoch_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册