diff --git a/examples/bmn/bmn_metric.py b/examples/bmn/bmn_metric.py
index f9bf101f825913572803fbb1168260f83a0d96ac..d8e0d3e3ae307c9fa61000e336b4ef6864f956f2 100644
--- a/examples/bmn/bmn_metric.py
+++ b/examples/bmn/bmn_metric.py
@@ -39,11 +39,6 @@ class BmnMetric(Metric):
         elif self.mode == 'infer':
             self.get_infer_dataset_dict()
 
-    def add_metric_op(self, preds, label):
-        pred_bm, pred_start, pred_en = preds
-        video_index = label[-1]
-        return [pred_bm, pred_start, pred_en, video_index]  #return list
-
     def update(self, pred_bm, pred_start, pred_end, fid):
         # generate proposals
         pred_start = pred_start[0]
diff --git a/hapi/metrics.py b/hapi/metrics.py
index 3350853677b62275bb0107addff3f3b3780ea81c..1d24c4ada2e77bba0df59cad75dd0fac3842f80c 100644
--- a/hapi/metrics.py
+++ b/hapi/metrics.py
@@ -48,9 +48,16 @@ class Metric(object):
                                   format(self.__class__.__name__))
 
     @abc.abstractmethod
-    def update(self, *args, **kwargs):
+    def update(self, *args):
         """
         Update states for metric
+
+        Inputs of :code:`update` are the outputs of :code:`Metric.add_metric_op`.
+        If :code:`add_metric_op` is not defined, the inputs of :code:`update`
+        will be the flattened **outputs** of the model and **labels** from data:
+        :code:`update(output1, output2, ..., label1, label2, ...)`
+
+        see :code:`Metric.add_metric_op`
         """
         raise NotImplementedError("function 'update' not implemented in {}.".
                                   format(self.__class__.__name__))
@@ -72,11 +79,26 @@ class Metric(object):
         raise NotImplementedError("function 'name' not implemented in {}.".
                                   format(self.__class__.__name__))
 
-    def add_metric_op(self, pred, label):
+    def add_metric_op(self, *args):
         """
-        Add process op for metric in program
+        This API is an advanced usage to accelerate metric calculation: the
+        calculations from the outputs of the model to the states which the Metric
+        should update can be defined here, and Paddle OPs are also supported.
+        Outputs of this API will be the inputs of :code:`Metric.update`.
+
+        If :code:`add_metric_op` is defined, it will be called with the **outputs**
+        of the model and the **labels** from data as arguments; all outputs and
+        labels will be concatenated and flattened, each field passed as a separate
+        argument as follows:
+        :code:`add_metric_op(output1, output2, ..., label1, label2, ...)`
+
+        If :code:`add_metric_op` is not defined, the default behaviour is to pass
+        the inputs through to the outputs, so the output format will be:
+        :code:`return output1, output2, ..., label1, label2, ...`
+
+        see :code:`Metric.update`
         """
-        return pred, label
+        return args
 
 
 class Accuracy(Metric):
@@ -91,12 +113,12 @@ class Accuracy(Metric):
         self._init_name(name)
         self.reset()
 
-    def add_metric_op(self, pred, label, *args, **kwargs):
-        pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
-        correct = pred == label[0]
+    def add_metric_op(self, pred, label, *args):
+        pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
+        correct = pred == label
         return correct
 
-    def update(self, correct, *args, **kwargs):
+    def update(self, correct, *args):
         accs = []
         for i, k in enumerate(self.topk):
             num_corrects = correct[:, :k].sum()
diff --git a/hapi/model.py b/hapi/model.py
index e2fa2d6e3ba900e783421d0e8b29fa9a3aad5813..f4e6744df5107d345c873f6fa45269f704615708 100644
--- a/hapi/model.py
+++ b/hapi/model.py
@@ -67,7 +67,7 @@ def to_list(value):
     if value is None:
         return value
     if isinstance(value, (list, tuple)):
-        return value
+        return list(value)
     return [value]
 
 
@@ -473,7 +473,7 @@ class StaticGraphAdapter(object):
         if mode != 'test':
             for metric in self.model._metrics:
                 metrics.append(
-                    to_list(metric.add_metric_op(outputs, labels)))
+                    to_list(metric.add_metric_op(*(outputs + labels))))
 
         if mode == 'train' and self.model._optimizer:
             self._loss_endpoint = fluid.layers.sum(losses)
@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object):
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.add_metric_op(
-                to_list(outputs), to_list(labels))
+                *(to_list(outputs) + to_list(labels)))
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
 
@@ -632,7 +632,8 @@ class DynamicGraphAdapter(object):
                 self._merge_count[self.mode + '_total'] += samples
                 self._merge_count[self.mode + '_batch'] = samples
 
-            metric_outs = metric.add_metric_op(to_list(outputs), labels)
+            metric_outs = metric.add_metric_op(
+                *(to_list(outputs) + to_list(labels)))
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
 
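
Note: the docstrings added above describe the new flattened-arguments contract between add_metric_op and update. The snippet below is a minimal sketch of a custom metric written against that contract; it is not part of this patch, and the class name SimpleAccuracy and the sample arrays are hypothetical stand-ins.

# Minimal sketch, not part of this patch: a custom metric under the new
# contract, where add_metric_op and update receive flattened positional
# arguments. "SimpleAccuracy" and the sample arrays below are hypothetical.
import numpy as np

from hapi.metrics import Metric


class SimpleAccuracy(Metric):
    def __init__(self, name='simple_acc'):
        super(SimpleAccuracy, self).__init__()
        self._name = name
        self.reset()

    def add_metric_op(self, pred, label, *args):
        # Called as add_metric_op(output1, ..., label1, ...); here the fields
        # are simply passed through, mirroring the new default behaviour.
        return pred, label

    def update(self, pred, label):
        # Receives the (numpy) outputs of add_metric_op.
        pred = np.argmax(pred, axis=-1)
        label = label.reshape(-1)
        self.correct += int((pred == label).sum())
        self.total += label.shape[0]
        return float(self.correct) / max(self.total, 1)

    def reset(self):
        self.correct = 0
        self.total = 0

    def accumulate(self):
        return float(self.correct) / max(self.total, 1)

    def name(self):
        return self._name


# Usage mirroring the adapters' new call pattern: outputs and labels are
# concatenated and splatted into add_metric_op, and update gets its results.
outputs = [np.array([[0.1, 0.9], [0.8, 0.2]])]  # stand-in model outputs
labels = [np.array([[1], [1]])]                 # stand-in labels from data
metric = SimpleAccuracy()
metric_outs = metric.add_metric_op(*(outputs + labels))
print(metric.update(*list(metric_outs)))        # -> 0.5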
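
The to_list change in hapi/model.py serves the same contract: returning list(value) lets tuple outputs be concatenated with the list of labels before being splatted into add_metric_op. A small illustration, not part of this patch, using strings as stand-ins for tensors:

# Sketch, not part of this patch: why to_list now returns list(value).
# A tuple cannot be concatenated with a list, so copying into a list keeps
# `outputs + labels` valid whatever container the model returned.
from hapi.model import to_list

outputs = to_list(('bm', 'start', 'end'))  # tuple output -> ['bm', 'start', 'end']
labels = to_list('label')                  # single label -> ['label']
assert outputs + labels == ['bm', 'start', 'end', 'label']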