From c09ad73c33533a120ecdc4aed71f676c11cd1c8f Mon Sep 17 00:00:00 2001
From: Dong Zhihong
Date: Mon, 6 Nov 2017 23:06:59 -0800
Subject: [PATCH] "add fit a line test"

---
 paddle/operators/accuracy_op.cc            |  4 ++
 paddle/operators/accuracy_op.h             |  3 ++
 python/paddle/v2/framework/evaluator.py    | 47 ++++++++++++++-----
 .../v2/framework/tests/test_fit_a_line.py  |  4 ++
 4 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index 142883d9eae..f50e41bc410 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -32,6 +32,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
                    "Output (Accuracy) of AccuracyOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Correct"),
                    "Output (Correct) of AccuracyOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Total"),
+                   "Output (Total) of AccuracyOp should not be null.");
 
     auto inference_dim = ctx->GetInputDim("Out");
     auto label_dim = ctx->GetInputDim("Label");
@@ -46,6 +48,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
 
     ctx->SetOutputDim("Accuracy", {1});
     ctx->SetOutputDim("Correct", {1});
+    ctx->SetOutputDim("Total", {1});
     ctx->ShareLoD("Out", /*->*/ "Accuracy");
   }
 
@@ -69,6 +72,7 @@ class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
     // TODO(typhoonzero): AddInput("Weight", ...
     AddOutput("Accuracy", "The accuracy of current batch");
     AddOutput("Correct", "The correct samples count of current batch");
+    AddOutput("Total", "The samples count of current batch");
 
     AddComment(R"DOC(
 Accuracy. It will print accuracy rate for classification.
diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h
index cc0ea802f9a..e130d9a4ffb 100644
--- a/paddle/operators/accuracy_op.h
+++ b/paddle/operators/accuracy_op.h
@@ -43,9 +43,11 @@ class AccuracyKernel : public framework::OpKernel {
     auto* label = ctx.Input<Tensor>("Label");
     auto* accuracy = ctx.Output<Tensor>("Accuracy");
     auto* correct = ctx.Output<Tensor>("Correct");
+    auto* total = ctx.Output<Tensor>("Total");
 
     int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
     float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
+    int* total_data = total->mutable_data<int>(ctx.GetPlace());
 
     const int64_t* indices_data = indices->data<int64_t>();
     const int64_t* label_data = label->data<int64_t>();
@@ -71,6 +73,7 @@ class AccuracyKernel : public framework::OpKernel {
     }
 
     *correct_data = num_correct;
+    *total_data = num_samples;
     *accuracy_data =
         static_cast<float>(num_correct) / static_cast<float>(num_samples);
   }
diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py
index 4d305f899bc..ba2a0618789 100644
--- a/python/paddle/v2/framework/evaluator.py
+++ b/python/paddle/v2/framework/evaluator.py
@@ -1,4 +1,4 @@
-from paddle.v2.framework.framework import Program, unique_name
+from paddle.v2.framework.framework import Program, g_program, unique_name
 from paddle.v2.framework.layer_helper import LayerHelper
 import paddle.v2.framework.core as core
 
@@ -13,8 +13,12 @@ class Evaluator(object):
     """
 
     def __init__(self, name, **kwargs):
-        self._states = []
+        self._states = {}
         self._helper = LayerHelper(layer_type=name, **kwargs)
+        # if kwargs.has_key("program"):
+        #     self._program = kwargs.get("program")
+        # else:
+        #     self._program = g_program
 
     # def _update(self):
     #     """
@@ -22,12 +26,15 @@ class Evaluator(object):
     #     """
     #     raise NotImplementedError()
 
-    def reset(self):
+    def reset(self, executor, program=None):
         """
         Clear metric states at the begin of each pass/user specified batch
         """
-        reset_program = Program()
-        for var in self._states:
+        if program is None:
+            reset_program = Program()
+        else:
+            reset_program = program
+        for k, var in self._states.iteritems():
             zeros = helper.create_tmp_variable(dtype=var.data_type)
             self._helper.append_op(
                 type="fill_constant",
@@ -38,7 +45,7 @@ class Evaluator(object):
                 })
             self._helper.append_op(
                 type="scale", inputs={"X": zeros}, outputs={"Out": var})
-        return reset_program
+        executor.run(reset_program)
 
     def eval(self):
         """
@@ -64,8 +71,8 @@ class Accuracy(Evaluator):
             persistable=True,
             dtype="int64",
             shape=[1])
-        self._states.append(g_total)
-        self._states.append(g_correct)
+        self._states["Total"] = g_total
+        self._states["Correct"] = g_correct
 
         topk_out = helper.create_tmp_variable(dtype=input.data_type)
         topk_indices = helper.create_tmp_variable(dtype="int64")
@@ -86,18 +93,32 @@ class Accuracy(Evaluator):
             },
             outputs={
                 "Accuracy": [acc_out],
-                "Correct": [tp_out],
+                "Correct": [correct],
+                "Total": [total],
             })
 
         helper.append_op(
             type="sum",
-            inputs={"X": [g_total, tp_out]},
-            outputs={"Out": [g_total]})
+            inputs={"X": [g_total, total]},
+            outputs={"Out": [g_total]})
+        helper.append_op(
+            type="sum",
+            inputs={"X": [g_correct, correct]},
+            outputs={"Out": [g_correct]})
         return acc_out
 
-    def eval(self):
-        eval_program = Program()
-        g_total = self._program
+    def eval(self, executor, program=None):
+        if program is None:
+            eval_program = Program()
+        else:
+            eval_program = program
+        eval_out = self._helper.create_tmp_variable(dtype=self._helper.input_dtype())
+        self._helper.append_op(
+            type="elementwise_div",
+            inputs={"X": self._states["Correct"],
+                    "Y": self._states["Total"]},
+            outputs={"Out": eval_out})
+        return executor.run(eval_program, fetch_list=[eval_out])
 
 
 # This is demo for composing low level op to compute metric
diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/framework/tests/test_fit_a_line.py
index 944240629ca..588e1d58822 100644
--- a/python/paddle/v2/framework/tests/test_fit_a_line.py
+++ b/python/paddle/v2/framework/tests/test_fit_a_line.py
@@ -6,6 +6,7 @@ import paddle.v2.framework.optimizer as optimizer
 from paddle.v2.framework.framework import Program, g_program
 from paddle.v2.framework.io import save_persistables, load_persistables
 from paddle.v2.framework.executor import Executor
+import paddle.v2.framework.evaluator as evaluator
 
 import numpy as np
 
@@ -31,6 +32,8 @@ y = layers.data(
     program=program,
     init_program=init_program)
 
+accuracy = evaluator.Accuracy(input=y_predict, label=y)
+
 cost = layers.square_error_cost(
     input=y_predict, label=y, program=program, init_program=init_program)
 avg_cost = layers.mean(x=cost, program=program, init_program=init_program)
@@ -54,6 +57,7 @@ PASS_NUM = 100
 for pass_id in range(PASS_NUM):
     save_persistables(exe, "./fit_a_line.model/", program=program)
     load_persistables(exe, "./fit_a_line.model/", program=program)
+    accuracy.eval(exe)
     for data in train_reader():
         x_data = np.array(map(lambda x: x[0], data)).astype("float32")
         y_data = np.array(map(lambda x: x[1], data)).astype("float32")
--
GitLab
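
Note (not part of the patch): the Evaluator API added above is meant to be driven as reset() at the start of every pass, ordinary exe.run() calls per batch so the accuracy op and the two sum ops accumulate into the persistable "Total"/"Correct" states, and eval() at the end of the pass to fetch Correct / Total. Below is a rough sketch of that loop in the style of the existing v2.framework tests; the small MNIST classifier, the reader, and the program=/init_program= keyword arguments forwarded to Accuracy are illustrative assumptions, not code taken from this PR.

import numpy as np
import paddle.v2 as paddle
import paddle.v2.framework.core as core
import paddle.v2.framework.layers as layers
import paddle.v2.framework.optimizer as optimizer
import paddle.v2.framework.evaluator as evaluator
from paddle.v2.framework.framework import Program
from paddle.v2.framework.executor import Executor

program = Program()
init_program = Program()

# A small softmax classifier; the names and shapes are illustrative only.
images = layers.data(name='x', shape=[784], data_type='float32',
                     program=program, init_program=init_program)
label = layers.data(name='y', shape=[1], data_type='int64',
                    program=program, init_program=init_program)
predict = layers.fc(input=images, size=10, act='softmax',
                    program=program, init_program=init_program)
cost = layers.cross_entropy(input=predict, label=label,
                            program=program, init_program=init_program)
avg_cost = layers.mean(x=cost, program=program, init_program=init_program)

sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)

# Assumes the Evaluator forwards program/init_program kwargs to its LayerHelper;
# this appends the top_k/accuracy/sum ops and registers the persistable states.
accuracy = evaluator.Accuracy(input=predict, label=label,
                              program=program, init_program=init_program)

place = core.CPUPlace()
exe = Executor(place)
exe.run(init_program, feed={}, fetch_list=[])

train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=128)

PASS_NUM = 5
for pass_id in range(PASS_NUM):
    accuracy.reset(exe)  # zero the accumulated Total/Correct counters
    for data in train_reader():
        x_data = np.array(map(lambda x: x[0], data)).astype("float32")
        y_data = np.array(map(lambda x: x[1], data)).astype("int64")
        y_data = y_data.reshape([len(y_data), 1])

        tensor_x = core.LoDTensor()
        tensor_x.set(x_data, place)
        tensor_y = core.LoDTensor()
        tensor_y.set(y_data, place)

        # Each run also executes the accuracy/sum ops appended by the evaluator,
        # accumulating the batch counts into the global states.
        exe.run(program,
                feed={'x': tensor_x, 'y': tensor_y},
                fetch_list=[avg_cost])

    # elementwise_div of the accumulated Correct by Total, fetched via the executor.
    pass_acc = accuracy.eval(exe)
    print("pass %d, accuracy %s" % (pass_id, pass_acc))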