未验证 提交 838197ea 编写于 作者: D dyning 提交者: GitHub

Merge pull request #92 from WuHaobo/polish_googlenet

polish program for googlenet
...@@ -157,7 +157,11 @@ def create_loss(out, ...@@ -157,7 +157,11 @@ def create_loss(out,
return loss(out, target) return loss(out, target)
def create_metric(out, feeds, topk=5, classes_num=1000, def create_metric(out,
feeds,
architecture,
topk=5,
classes_num=1000,
use_distillation=False): use_distillation=False):
""" """
Create measures of model accuracy, such as top1 and top5 Create measures of model accuracy, such as top1 and top5
...@@ -171,16 +175,22 @@ def create_metric(out, feeds, topk=5, classes_num=1000, ...@@ -171,16 +175,22 @@ def create_metric(out, feeds, topk=5, classes_num=1000,
Returns: Returns:
fetchs(dict): dict of measures fetchs(dict): dict of measures
""" """
# just need student label to get metrics if architecture["name"] == "GoogLeNet":
if use_distillation: assert len(out) == 3, "GoogLeNet should have 3 outputs"
out = out[1] softmax_out = out[0]
else:
# just need student label to get metrics
if use_distillation:
out = out[1]
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
fetchs = OrderedDict() fetchs = OrderedDict()
label = feeds['label'] # set top1 to fetchs
softmax_out = fluid.layers.softmax(out, use_cudnn=False) top1 = fluid.layers.accuracy(softmax_out, label=feeds['label'], k=1)
top1 = fluid.layers.accuracy(softmax_out, label=label, k=1)
fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True)) fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))
# set topk to fetchs
k = min(topk, classes_num) k = min(topk, classes_num)
topk = fluid.layers.accuracy(softmax_out, label=label, k=k) topk = fluid.layers.accuracy(softmax_out, label=feeds['label'], k=k)
topk_name = 'top{}'.format(k) topk_name = 'top{}'.format(k)
fetchs[topk_name] = (topk, AverageMeter(topk_name, '.4f', need_avg=True)) fetchs[topk_name] = (topk, AverageMeter(topk_name, '.4f', need_avg=True))
...@@ -201,7 +211,8 @@ def create_fetchs(out, ...@@ -201,7 +211,8 @@ def create_fetchs(out,
Args: Args:
out(variable): model output variable out(variable): model output variable
feeds(dict): dict of model input variables(included label) feeds(dict): dict of model input variables.
If use mix_up, it will not include label.
architecture(dict): architecture information, architecture(dict): architecture information,
name(such as ResNet50) is needed name(such as ResNet50) is needed
topk(int): usually top5 topk(int): usually top5
...@@ -217,7 +228,8 @@ def create_fetchs(out, ...@@ -217,7 +228,8 @@ def create_fetchs(out,
use_distillation) use_distillation)
fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True)) fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
if not use_mix: if not use_mix:
metric = create_metric(out, feeds, topk, classes_num, use_distillation) metric = create_metric(out, feeds, architecture, topk, classes_num,
use_distillation)
fetchs.update(metric) fetchs.update(metric)
return fetchs return fetchs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册