提交 233a305b 编写于 作者: D Dong Zhihong

"need to write math functors"

上级 83e65005
from paddle.v2.framework.framework import Program, g_program, g_init_program from paddle.v2.framework.framework import Program, unique_name
from paddle.v2.framework.layer_helper import LayerHelper
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
...@@ -11,24 +12,14 @@ class Evaluator(object): ...@@ -11,24 +12,14 @@ class Evaluator(object):
add increment operator to accumulate the metric states add increment operator to accumulate the metric states
""" """
def __init__(self, input=None, **kwargs): def __init__(self, evaluator_type, **kwargs):
if "program" in kwargs:
self._program = kwargs.get("program")
else:
self._program = input.program
self._states = [] self._states = []
self._helper = LayerHelper(layer_type=evaluator_type, **kwargs)
def _create_tmp_variable(self, name, dtype):
return self.program.current_block().create_var(
name=unique_name(".".join([self.name, 'tmp'])),
dtype=dtype,
persistable=False)
@staticmethod @staticmethod
def clear(self): def clear(self):
""" """
clear metric states at the begin of each pass/user specified batch clear metric states at the begin of each pass/user specified batch
return a clear
""" """
raise NotImplementedError() raise NotImplementedError()
...@@ -41,7 +32,18 @@ class Evaluator(object): ...@@ -41,7 +32,18 @@ class Evaluator(object):
class Accuracy(Evaluator): class Accuracy(Evaluator):
def __init__(self, input, label, k=1, **kwargs): def __init__(self, input, label, k=1, **kwargs):
super(Accuracy, self).__init__(input=input, **kwargs) super(Accuracy, self).__init__("accuracy", **kwargs)
g_total = helper.create_global_variable(
name=unique_name("Total"),
persistable=True,
dtype="int64",
shape=[1])
g_correct = helper.create_global_variable(
name=unique_name("Correct"),
persistable=True,
dtype="int64",
shape=[1])
topk_out = helper.create_tmp_variable(dtype=input.data_type) topk_out = helper.create_tmp_variable(dtype=input.data_type)
topk_indices = helper.create_tmp_variable(dtype="int64") topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op( helper.append_op(
...@@ -60,3 +62,9 @@ class Accuracy(Evaluator): ...@@ -60,3 +62,9 @@ class Accuracy(Evaluator):
"Label": [label] "Label": [label]
}, },
outputs={"Accuracy": [acc_out]}) outputs={"Accuracy": [acc_out]})
helper.append_op(
type="sum", inputs={"X": [g_total, ], },
outputs={"Out": [g_total]})
return acc_out
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册