提交 233a305b 编写于 作者: D Dong Zhihong

"need to write math functors"

上级 83e65005
from paddle.v2.framework.framework import Program, g_program, g_init_program
from paddle.v2.framework.framework import Program, unique_name
from paddle.v2.framework.layer_helper import LayerHelper
import paddle.v2.framework.core as core
......@@ -11,24 +12,14 @@ class Evaluator(object):
add increment operator to accumulate the metric states
"""
def __init__(self, input=None, **kwargs):
if "program" in kwargs:
self._program = kwargs.get("program")
else:
self._program = input.program
def __init__(self, evaluator_type, **kwargs):
self._states = []
def _create_tmp_variable(self, name, dtype):
return self.program.current_block().create_var(
name=unique_name(".".join([self.name, 'tmp'])),
dtype=dtype,
persistable=False)
self._helper = LayerHelper(layer_type=evaluator_type, **kwargs)
@staticmethod
def clear(self):
"""
clear metric states at the begin of each pass/user specified batch
return a clear
"""
raise NotImplementedError()
......@@ -41,7 +32,18 @@ class Evaluator(object):
class Accuracy(Evaluator):
def __init__(self, input, label, k=1, **kwargs):
super(Accuracy, self).__init__(input=input, **kwargs)
super(Accuracy, self).__init__("accuracy", **kwargs)
g_total = helper.create_global_variable(
name=unique_name("Total"),
persistable=True,
dtype="int64",
shape=[1])
g_correct = helper.create_global_variable(
name=unique_name("Correct"),
persistable=True,
dtype="int64",
shape=[1])
topk_out = helper.create_tmp_variable(dtype=input.data_type)
topk_indices = helper.create_tmp_variable(dtype="int64")
helper.append_op(
......@@ -60,3 +62,9 @@ class Accuracy(Evaluator):
"Label": [label]
},
outputs={"Accuracy": [acc_out]})
helper.append_op(
type="sum", inputs={"X": [g_total, ], },
outputs={"Out": [g_total]})
return acc_out
import paddle.v2.framework.core as core
from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \
Operator
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册