未验证 提交 838197ea 编写于 作者: D dyning 提交者: GitHub

Merge pull request #92 from WuHaobo/polish_googlenet

polish program for googlenet
......@@ -157,7 +157,11 @@ def create_loss(out,
return loss(out, target)
def create_metric(out, feeds, topk=5, classes_num=1000,
def create_metric(out,
feeds,
architecture,
topk=5,
classes_num=1000,
use_distillation=False):
"""
Create measures of model accuracy, such as top1 and top5
......@@ -171,16 +175,22 @@ def create_metric(out, feeds, topk=5, classes_num=1000,
Returns:
fetchs(dict): dict of measures
"""
# just need student label to get metrics
if use_distillation:
out = out[1]
if architecture["name"] == "GoogLeNet":
assert len(out) == 3, "GoogLeNet should have 3 outputs"
softmax_out = out[0]
else:
# just need student label to get metrics
if use_distillation:
out = out[1]
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
fetchs = OrderedDict()
label = feeds['label']
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
top1 = fluid.layers.accuracy(softmax_out, label=label, k=1)
# set top1 to fetchs
top1 = fluid.layers.accuracy(softmax_out, label=feeds['label'], k=1)
fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))
# set topk to fetchs
k = min(topk, classes_num)
topk = fluid.layers.accuracy(softmax_out, label=label, k=k)
topk = fluid.layers.accuracy(softmax_out, label=feeds['label'], k=k)
topk_name = 'top{}'.format(k)
fetchs[topk_name] = (topk, AverageMeter(topk_name, '.4f', need_avg=True))
......@@ -201,7 +211,8 @@ def create_fetchs(out,
Args:
out(variable): model output variable
feeds(dict): dict of model input variables(included label)
feeds(dict): dict of model input variables.
If use mix_up, it will not include label.
architecture(dict): architecture information,
name(such as ResNet50) is needed
topk(int): usually top5
......@@ -217,7 +228,8 @@ def create_fetchs(out,
use_distillation)
fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
if not use_mix:
metric = create_metric(out, feeds, topk, classes_num, use_distillation)
metric = create_metric(out, feeds, architecture, topk, classes_num,
use_distillation)
fetchs.update(metric)
return fetchs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册