From f122a5da2f27038b48f6ed607e296d762050e920 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 30 Oct 2017 19:35:22 -0700 Subject: [PATCH] Add accuracy layer (#4958) * Complete accuray layer * Fix error * Fix error * Add 'accuracy' to __all__ * update * Fix Type error * Fix error * Refine unit tests * Fix an unit test error --- paddle/operators/accuracy_op.cc | 6 +++-- paddle/operators/top_k_op.cc | 9 ++++++-- python/paddle/v2/framework/layers.py | 22 ++++++++++++++++++- .../v2/framework/tests/test_accuracy_op.py | 4 ++-- .../tests/test_recognize_digits_conv.py | 13 ++++++----- 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index eb8bce8da70..88958e1634c 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -32,7 +32,8 @@ class AccuracyOp : public framework::OperatorWithKernel { auto inference_dim = ctx->GetInputDim("Inference"); auto label_dim = ctx->GetInputDim("Label"); - PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector"); + PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2."); + PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1"); PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0], "inference size must be the same as label size"); @@ -68,7 +69,8 @@ information, or not. But the output only shares the LoD with input `Inference`. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker); +REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker, + paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( accuracy, ops::AccuracyKernel, ops::AccuracyKernel); diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc index d5c2c91a5fb..ac925725954 100644 --- a/paddle/operators/top_k_op.cc +++ b/paddle/operators/top_k_op.cc @@ -52,7 +52,11 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "The output tensor of Topk op"); AddOutput("Indices", "The indices of Topk elements of input"); AddComment( - R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j]. + R"DOC(If the input is a vector (1d tensor), + finds the k largest entries in the vector + and outputs their values and indices as vectors. + Thus values[j] is the j-th largest entry in input, + and its index is indices[j]. For matrices, computes the top k entries in each row. )DOC"); AddAttr("k", @@ -66,6 +70,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker); +REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker, + paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(top_k, ops::TopkKernel); diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 70447e0d816..4727d139a28 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -5,7 +5,7 @@ import re __all__ = [ 'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat', - 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool' + 'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'accuracy' ] @@ -229,6 +229,26 @@ def square_error_cost(input, label, **kwargs): return square_out +def accuracy(input, label, k=1, **kwargs): + helper = LayerHelper("accuracy", **kwargs) + topk_out = helper.create_tmp_variable(dtype=input.data_type) + topk_indices = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="top_k", + inputs={"X": [input]}, + outputs={"Out": [topk_out], + "Indices": [topk_indices]}, + attrs={"k": k}) + acc_out_dtype = kwargs.get("out_dtype", "float32") + acc_out = helper.create_tmp_variable(dtype=acc_out_dtype) + helper.append_op( + type="accuracy", + inputs={"Inference": [topk_indices], + "Label": [label]}, + outputs={"Accuracy": [acc_out]}) + return acc_out + + def sequence_conv(input, num_filters, name=None, diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py index 02be9a02910..f17edd44aef 100644 --- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -8,12 +8,12 @@ class TestAccuracyOp(OpTest): self.op_type = "accuracy" n = 8192 infer = np.random.randint(0, 2, (n, 1)).astype("int") - label = np.random.randint(0, 2, (n, )).astype("int") + label = np.random.randint(0, 2, (n, 1)).astype("int") self.inputs = {'Inference': infer, "Label": label} num_correct = 0 for rowid in xrange(n): for ele in infer[rowid]: - if ele == label[rowid]: + if ele == label[rowid][0]: num_correct += 1 break self.outputs = { diff --git a/python/paddle/v2/framework/tests/test_recognize_digits_conv.py b/python/paddle/v2/framework/tests/test_recognize_digits_conv.py index a9b6c8410e2..92b1d054261 100644 --- a/python/paddle/v2/framework/tests/test_recognize_digits_conv.py +++ b/python/paddle/v2/framework/tests/test_recognize_digits_conv.py @@ -51,12 +51,14 @@ predict = layers.fc(input=conv_pool_2, cost = layers.cross_entropy( input=predict, label=label, program=program, init_program=init_program) avg_cost = layers.mean(x=cost, program=program) +accuracy = layers.accuracy( + input=predict, label=label, program=program, init_program=init_program) sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001) opts = sgd_optimizer.minimize(avg_cost) BATCH_SIZE = 50 -PASS_NUM = 1 +PASS_NUM = 3 train_reader = paddle.batch( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=500), @@ -83,10 +85,11 @@ for pass_id in range(PASS_NUM): outs = exe.run(program, feed={"pixel": tensor_img, "label": tensor_y}, - fetch_list=[avg_cost]) - + fetch_list=[avg_cost, accuracy]) loss = np.array(outs[0]) + acc = np.array(outs[1]) - if loss < 10.0: - exit(0) # if avg cost less than 10.0, we think our code is good. + if loss < 10.0 and acc > 0.9: + # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good. + exit(0) exit(1) -- GitLab