提交 2d623364 编写于 作者: 武毅 提交者: GitHub

Accuracy op (#3907)

* init add

* add topk op

* someupdate

* fix style check

* add test py file

* update top k cuda kernel

* follow comments

* remove debug print

* accuracy_op

* fix casting error

* fix casting error

* fix casting error

* fix rename bug...

* make it smaller

* update cast
上级 b3f6b5a9
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/accuracy_op.h"
namespace paddle {
namespace operators {
class AccuracyOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"),
"Input of Inference must be initialized.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input of Inference must be initialized.");
auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *label = ctx.Input<framework::Tensor>("Label");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
"inference size must be the same as label size");
ctx.Output<Tensor>("Accuracy")->Resize({1});
}
};
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AccuracyOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
// TODO(typhoonzero): support both inference value and indices.
AddInput("Inference", "topk(indices) the network output");
AddInput("Label", "Label of the training data");
// TODO(typhoonzero): AddInput("Weight", ...
AddOutput("Accuracy", "The accuracy of current batch");
AddComment(
R"DOC(Accuracy. It will print accuracy rate for classification.
The accuracy is:
.. math::
accuracy = \\frac{NumOfCorrectPredicts}{NumOfAllSamples})DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(accuracy, ops::AccuracyOp, ops::AccuracyOpMaker);
REGISTER_OP_CPU_KERNEL(accuracy,
ops::AccuracyKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/accuracy_op.h"
namespace paddle {
namespace operators {
__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
const int* Xdata, const int* labelData,
float* accuracy) {
int correct = 0;
for (int row = 0; row < N; row++) {
const int label = labelData[row];
for (int col = 0; col < D; col++) {
const int pred = Xdata[row * D + col];
if (pred == label) {
++correct;
break;
}
}
}
*accuracy = static_cast<float>(correct) / static_cast<float>(N);
}
template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto* inference = ctx.Input<Tensor>("Inference");
auto* label = ctx.Input<Tensor>("Label");
auto* accuracy = ctx.Output<Tensor>("Accuracy");
// FIXME(typhoonzero): only support indices currently
// if add support for output values, how to detect the data type?
const int* inference_data = inference->data<int>();
const int* label_data = label->data<int>();
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
size_t num_samples = inference->dims()[0];
size_t infer_width = inference->dims()[1];
cudaMemset((void**)&accuracy_data, 0, sizeof(float));
if (num_samples == 0) {
return;
}
AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
label_data, accuracy_data);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(accuracy,
paddle::operators::AccuracyOpCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
auto* label = ctx.Input<Tensor>("Label");
auto* accuracy = ctx.Output<Tensor>("Accuracy");
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
const T* inference_data = inference->data<T>();
const T* label_data = label->data<T>();
size_t num_samples = inference->dims()[0];
size_t class_dim = inference->dims()[1];
*accuracy_data = 0.0f;
if (num_samples == 0) {
return;
}
int num_correct = 0;
// assume inference is already the topk of the output
for (size_t i = 0; i < num_samples; ++i) {
PADDLE_ENFORCE_GE(label_data[i], 0, "label must >= 0");
for (size_t j = 0; j < class_dim; ++j) {
if (inference_data[i * class_dim + j] == label_data[i]) {
++num_correct;
break;
}
}
}
// FIXME(typhoonzero): we don't accumulate the accuracy for now.
*accuracy_data =
static_cast<float>(num_correct) / static_cast<float>(num_samples);
}
};
} // namespace operators
} // namespace paddle
......@@ -52,6 +52,7 @@ USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_OP(pad);
USE_CPU_ONLY_OP(scatter);
USE_OP(accuracy);
USE_CPU_ONLY_OP(concat);
USE_OP(top_k);
USE_OP(squared_l2_distance);
......
import unittest
import numpy as np
from op_test import OpTest
class TestAccuracyOp(OpTest):
def setUp(self):
self.op_type = "accuracy"
infer = np.random.randint(0, 2, (32, 1)).astype("int")
label = np.random.randint(0, 2, (32, )).astype("int")
self.inputs = {'Inference': infer, "Label": label}
num_correct = 0
for rowid in xrange(32):
for ele in infer[rowid]:
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {'Accuracy': [num_correct / 32.0]}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册