Unverified commit f122a5da, authored by F fengjiayi, committed by GitHub

Add accuracy layer (#4958)

* Complete accuracy layer

* Fix error

* Fix error

* Add 'accuracy' to __all__

* update

* Fix Type error

* Fix error

* Refine unit tests

* Fix a unit test error
Parent 8b1c50c6
@@ -32,7 +32,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");
- PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
+ PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
+ PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"inference size must be the same as label size");
@@ -68,7 +69,8 @@ information, or not. But the output only shares the LoD with input `Inference`.
} // namespace paddle
namespace ops = paddle::operators;
- REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker);
+ REGISTER_OPERATOR(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker,
+                   paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
accuracy, ops::AccuracyKernel<paddle::platform::CPUPlace, int>,
ops::AccuracyKernel<paddle::platform::CPUPlace, int64_t>);
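The revised checks change the Label contract from a rank-1 vector to an [N, 1] column matched against the [N, k] index matrix in Inference. A minimal NumPy sketch of that contract and of the metric the op computes (illustration only, not part of the commit; the array values are made up):

import numpy as np

# Toy shapes for illustration: N = 3 samples, k = 2 candidates per row.
inference = np.array([[3, 1], [0, 2], [4, 4]], dtype="int64")  # top-k indices, shape [N, k]
label = np.array([[1], [2], [0]], dtype="int64")               # labels, shape [N, 1], rank 2

assert label.ndim == 2 and label.shape[1] == 1     # mirrors the two new PADDLE_ENFORCE_EQ checks
assert inference.shape[0] == label.shape[0]        # batch sizes must match

num_correct = np.any(inference == label, axis=1).sum()  # a row is correct if any candidate matches
print(float(num_correct) / inference.shape[0])           # 2/3 for this toy input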
@@ -52,7 +52,11 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of Topk op");
AddOutput("Indices", "The indices of Topk elements of input");
AddComment(
R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
R"DOC(If the input is a vector (1d tensor),
finds the k largest entries in the vector
and outputs their values and indices as vectors.
Thus values[j] is the j-th largest entry in input,
and its index is indices[j].
For matrices, computes the top k entries in each row. )DOC");
AddAttr<int>("k",
@@ -66,6 +70,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
- REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
+ REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
+                   paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(top_k,
ops::TopkKernel<paddle::platform::CPUPlace, float>);
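For reference, the per-row behaviour the rewrapped DOC string describes can be sketched in NumPy as follows (an illustrative reference, not the operator's implementation):

import numpy as np

def top_k_ref(x, k):
    # Sort each row in descending order, then keep the first k columns.
    indices = np.argsort(-x, axis=-1)[:, :k]
    values = np.take_along_axis(x, indices, axis=-1)
    return values, indices

x = np.array([[0.1, 0.7, 0.2],
              [0.9, 0.3, 0.5]], dtype="float32")
values, indices = top_k_ref(x, k=2)
# values  -> [[0.7, 0.2], [0.9, 0.5]]   values[j] is the j-th largest entry in each row
# indices -> [[1, 2],     [0, 2]]       indices[j] is its position in the input row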
@@ -5,7 +5,7 @@ import re
__all__ = [
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
-    'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool'
+    'StaticRNN', 'cast', 'sequence_conv', 'sequence_pool', 'accuracy'
]
@@ -229,6 +229,26 @@ def square_error_cost(input, label, **kwargs):
return square_out
+def accuracy(input, label, k=1, **kwargs):
+    helper = LayerHelper("accuracy", **kwargs)
+    topk_out = helper.create_tmp_variable(dtype=input.data_type)
+    topk_indices = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="top_k",
+        inputs={"X": [input]},
+        outputs={"Out": [topk_out],
+                 "Indices": [topk_indices]},
+        attrs={"k": k})
+    acc_out_dtype = kwargs.get("out_dtype", "float32")
+    acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
+    helper.append_op(
+        type="accuracy",
+        inputs={"Inference": [topk_indices],
+                "Label": [label]},
+        outputs={"Accuracy": [acc_out]})
+    return acc_out
def sequence_conv(input,
num_filters,
name=None,
......
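A short usage sketch of the new layer (assumed setup: predict, label, program, and init_program are created as in the MNIST test further down in this commit):

# Usage sketch only; the surrounding variables follow the MNIST test below.
acc = layers.accuracy(
    input=predict, label=label, k=1, program=program, init_program=init_program)
# `acc` holds the batch accuracy as a float32 scalar variable; internally it
# chains a top_k op (k candidates) with the accuracy op registered above, so
# the caller passes only raw predictions and integer labels.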
@@ -8,12 +8,12 @@ class TestAccuracyOp(OpTest):
        self.op_type = "accuracy"
        n = 8192
        infer = np.random.randint(0, 2, (n, 1)).astype("int")
-       label = np.random.randint(0, 2, (n, )).astype("int")
+       label = np.random.randint(0, 2, (n, 1)).astype("int")
        self.inputs = {'Inference': infer, "Label": label}
        num_correct = 0
        for rowid in xrange(n):
            for ele in infer[rowid]:
-               if ele == label[rowid]:
+               if ele == label[rowid][0]:
                    num_correct += 1
                    break
        self.outputs = {
......
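The xrange loop above is the scalar reference; a vectorized NumPy equivalent reusing infer, label, n, and num_correct from the same setUp would be (illustration only, not part of the test):

# A row counts as correct when any of its inferred indices equals that row's label.
num_correct_vec = int(np.any(infer == label, axis=1).sum())
assert num_correct_vec == num_correct
acc = float(num_correct_vec) / float(n)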
@@ -51,12 +51,14 @@ predict = layers.fc(input=conv_pool_2,
cost = layers.cross_entropy(
input=predict, label=label, program=program, init_program=init_program)
avg_cost = layers.mean(x=cost, program=program)
+accuracy = layers.accuracy(
+    input=predict, label=label, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 50
-PASS_NUM = 1
+PASS_NUM = 3
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
@@ -83,10 +85,11 @@ for pass_id in range(PASS_NUM):
        outs = exe.run(program,
                       feed={"pixel": tensor_img,
                             "label": tensor_y},
-                      fetch_list=[avg_cost])
+                      fetch_list=[avg_cost, accuracy])
        loss = np.array(outs[0])
+       acc = np.array(outs[1])
-       if loss < 10.0:
-           exit(0)  # if avg cost less than 10.0, we think our code is good.
+       if loss < 10.0 and acc > 0.9:
+           # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
+           exit(0)
exit(1)