diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py index 47bcca0b79dbff02322a840f02abe1c5022503e3..7536aa6ea190d9e7df91383a74319ce75f50f3ef 100644 --- a/python/paddle/v2/framework/evaluator.py +++ b/python/paddle/v2/framework/evaluator.py @@ -1,4 +1,5 @@ -from paddle.v2.framework.framework import Program, g_program, g_init_program +from paddle.v2.framework.framework import Program, unique_name +from paddle.v2.framework.layer_helper import LayerHelper import paddle.v2.framework.core as core @@ -11,24 +12,14 @@ class Evaluator(object): add increment operator to accumulate the metric states """ - def __init__(self, input=None, **kwargs): - if "program" in kwargs: - self._program = kwargs.get("program") - else: - self._program = input.program + def __init__(self, evaluator_type, **kwargs): self._states = [] - - def _create_tmp_variable(self, name, dtype): - return self.program.current_block().create_var( - name=unique_name(".".join([self.name, 'tmp'])), - dtype=dtype, - persistable=False) + self._helper = LayerHelper(layer_type=evaluator_type, **kwargs) @staticmethod def clear(self): """ clear metric states at the begin of each pass/user specified batch - return a clear """ raise NotImplementedError() @@ -41,7 +32,18 @@ class Evaluator(object): class Accuracy(Evaluator): def __init__(self, input, label, k=1, **kwargs): - super(Accuracy, self).__init__(input=input, **kwargs) + super(Accuracy, self).__init__("accuracy", **kwargs) + g_total = helper.create_global_variable( + name=unique_name("Total"), + persistable=True, + dtype="int64", + shape=[1]) + g_correct = helper.create_global_variable( + name=unique_name("Correct"), + persistable=True, + dtype="int64", + shape=[1]) + topk_out = helper.create_tmp_variable(dtype=input.data_type) topk_indices = helper.create_tmp_variable(dtype="int64") helper.append_op( @@ -60,3 +62,9 @@ class Accuracy(Evaluator): "Label": [label] }, outputs={"Accuracy": [acc_out]}) + + helper.append_op( + type="sum", inputs={"X": [g_total, ], }, + outputs={"Out": [g_total]}) + + return acc_out diff --git a/python/paddle/v2/framework/math_ops.py b/python/paddle/v2/framework/math_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..408656a75d676def1a9c026578ea9886f1505151 --- /dev/null +++ b/python/paddle/v2/framework/math_ops.py @@ -0,0 +1,3 @@ +import paddle.v2.framework.core as core +from paddle.v2.framework.framework import OpProtoHolder, Variable, Program, \ + Operator