From 2d6233646a753487d4ecf8a570683803c01c1d10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Thu, 14 Sep 2017 09:42:12 +0800 Subject: [PATCH] Accuracy op (#3907) * init add * add topk op * someupdate * fix style check * add test py file * update top k cuda kernel * follow comments * remove debug print * accuracy_op * fix casting error * fix casting error * fix casting error * fix rename bug... * make it smaller * update cast --- paddle/operators/accuracy_op.cc | 66 ++++++++++++++++ paddle/operators/accuracy_op.cu | 69 +++++++++++++++++ paddle/operators/accuracy_op.h | 77 +++++++++++++++++++ paddle/pybind/pybind.cc | 1 + .../v2/framework/tests/test_accuracy_op.py | 25 ++++++ 5 files changed, 238 insertions(+) create mode 100644 paddle/operators/accuracy_op.cc create mode 100644 paddle/operators/accuracy_op.cu create mode 100644 paddle/operators/accuracy_op.h create mode 100644 python/paddle/v2/framework/tests/test_accuracy_op.py diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc new file mode 100644 index 00000000000..9ca04d40287 --- /dev/null +++ b/paddle/operators/accuracy_op.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/accuracy_op.h" + +namespace paddle { +namespace operators { + +class AccuracyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input of Inference must be initialized."); + auto *inference = ctx.Input("Inference"); + auto *label = ctx.Input("Label"); + + PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector"); + PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0], + "inference size must be the same as label size"); + + ctx.Output("Accuracy")->Resize({1}); + } +}; + +class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AccuracyOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + // TODO(typhoonzero): support both inference value and indices. + AddInput("Inference", "topk(indices) the network output"); + AddInput("Label", "Label of the training data"); + // TODO(typhoonzero): AddInput("Weight", ... + AddOutput("Accuracy", "The accuracy of current batch"); + + AddComment( + R"DOC(Accuracy. It will print accuracy rate for classification. +The accuracy is: +.. math:: +accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples})DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker); +REGISTER_OP_CPU_KERNEL(accuracy, + ops::AccuracyKernel); diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu new file mode 100644 index 00000000000..4e6d1ef9654 --- /dev/null +++ b/paddle/operators/accuracy_op.cu @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/accuracy_op.h" + +namespace paddle { +namespace operators { + +__global__ void AccuracySingleKernel(const int N, const int D, const int top_k, + const int* Xdata, const int* labelData, + float* accuracy) { + int correct = 0; + for (int row = 0; row < N; row++) { + const int label = labelData[row]; + for (int col = 0; col < D; col++) { + const int pred = Xdata[row * D + col]; + if (pred == label) { + ++correct; + break; + } + } + } + *accuracy = static_cast(correct) / static_cast(N); +} + +template +class AccuracyOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + auto* inference = ctx.Input("Inference"); + auto* label = ctx.Input("Label"); + auto* accuracy = ctx.Output("Accuracy"); + // FIXME(typhoonzero): only support indices currently + // if add support for output values, how to detect the data type? + const int* inference_data = inference->data(); + const int* label_data = label->data(); + float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); + + size_t num_samples = inference->dims()[0]; + size_t infer_width = inference->dims()[1]; + cudaMemset((void**)&accuracy_data, 0, sizeof(float)); + + if (num_samples == 0) { + return; + } + + AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data, + label_data, accuracy_data); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_GPU_KERNEL(accuracy, + paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/operators/accuracy_op.h b/paddle/operators/accuracy_op.h new file mode 100644 index 00000000000..fe704efe1c9 --- /dev/null +++ b/paddle/operators/accuracy_op.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; + +template +using EigenScalar = framework::EigenScalar; + +template +class AccuracyKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* inference = ctx.Input("Inference"); + auto* label = ctx.Input("Label"); + auto* accuracy = ctx.Output("Accuracy"); + + float* accuracy_data = accuracy->mutable_data(ctx.GetPlace()); + + const T* inference_data = inference->data(); + const T* label_data = label->data(); + + size_t num_samples = inference->dims()[0]; + size_t class_dim = inference->dims()[1]; + *accuracy_data = 0.0f; + + if (num_samples == 0) { + return; + } + + int num_correct = 0; + // assume inference is already the topk of the output + for (size_t i = 0; i < num_samples; ++i) { + PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0"); + for (size_t j = 0; j < class_dim; ++j) { + if (inference_data[i * class_dim + j] == label_data[i]) { + ++num_correct; + break; + } + } + } + + // FIXME(typhoonzero): we don't accumulate the accuracy for now. + *accuracy_data = + static_cast(num_correct) / static_cast(num_samples); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fe1e50927a4..b0979767f80 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -52,6 +52,7 @@ USE_OP(cos_sim); USE_CPU_ONLY_OP(gather); USE_OP(pad); USE_CPU_ONLY_OP(scatter); +USE_OP(accuracy); USE_CPU_ONLY_OP(concat); USE_OP(top_k); USE_OP(squared_l2_distance); diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py new file mode 100644 index 00000000000..43d60eb90d5 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -0,0 +1,25 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAccuracyOp(OpTest): + def setUp(self): + self.op_type = "accuracy" + infer = np.random.randint(0, 2, (32, 1)).astype("int") + label = np.random.randint(0, 2, (32, )).astype("int") + self.inputs = {'Inference': infer, "Label": label} + num_correct = 0 + for rowid in xrange(32): + for ele in infer[rowid]: + if ele == label[rowid]: + num_correct += 1 + break + self.outputs = {'Accuracy': [num_correct / 32.0]} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main() -- GitLab