Commit 8580dce3 authored by 武毅, committed by GitHub

Refine accuracy_op CUDA kernel (#4097)

* refine accuracy_op

* follow comments

* follow comments
Parent 59c48f98
paddle/operators/accuracy_op.cu
@@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <thrust/execution_policy.h>
+#include <thrust/reduce.h>
 #include "paddle/operators/accuracy_op.h"
+#include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
 namespace operators {
+using platform::PADDLE_CUDA_NUM_THREADS;
 
-__global__ void AccuracySingleKernel(const int N, const int D, const int top_k,
-                                     const int* Xdata, const int* labelData,
-                                     float* accuracy) {
-  int correct = 0;
-  for (int row = 0; row < N; row++) {
-    const int label = labelData[row];
-    for (int col = 0; col < D; col++) {
-      const int pred = Xdata[row * D + col];
-      if (pred == label) {
-        ++correct;
+template <int BlockSize>
+__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
+                                   const int* labeldata, float* accuracy) {
+  int count = 0;
+  __shared__ int total[BlockSize];
+
+  // support only 1 block
+  for (int i = threadIdx.x; i < (N); i += BlockSize) {
+    for (int j = 0; j < D; ++j) {
+      if (Xdata[i * D + j] == labeldata[i]) {
+        ++count;
         break;
       }
     }
   }
-  *accuracy = static_cast<float>(correct) / static_cast<float>(N);
+  total[threadIdx.x] = count;
+  __syncthreads();
+
+  // reduce the count with init value 0, and output accuracy.
+  int result = thrust::reduce(thrust::device, total, total + BlockSize, 0);
+  if (threadIdx.x == 0) {
+    *accuracy = static_cast<float>(result) / static_cast<float>(N);
+  }
 }
 
 template <typename T>
@@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
       return;
     }
 
-    AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data,
-                                   label_data, accuracy_data);
+    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS>>>(
+        num_samples, infer_width, inference_data, label_data, accuracy_data);
   }
 };
 
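The refined kernel works in two stages: each thread of the single block walks rows i = threadIdx.x, threadIdx.x + BlockSize, ... and counts how many of them contain the label among their D predictions, then the per-thread counts are staged in shared memory and summed so that thread 0 can write the final accuracy. For readers unfamiliar with the reduction step, the following is a minimal, hypothetical sketch of an equivalent block-level sum done with a plain shared-memory tree reduction instead of thrust::reduce; it is not part of this commit and assumes BlockSize is a power of two.

// Illustrative alternative to thrust::reduce for the in-block sum.
// Assumes BlockSize is a power of two and that every thread of the
// block calls this function (it contains __syncthreads()).
template <int BlockSize>
__device__ int BlockSum(int* total) {
  for (int stride = BlockSize / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      total[threadIdx.x] += total[threadIdx.x + stride];
    }
    __syncthreads();
  }
  return total[0];  // visible to all threads after the last barrier
}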
paddle/platform/cuda_helper.h
@@ -24,6 +24,11 @@ namespace platform {
 #define USE_CUDA_ATOMIC(op, T) \
   CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
 
+// Default thread count per block(or block size).
+// TODO(typhoonzero): need to benchmark against setting this value
+// to 1024.
+constexpr int PADDLE_CUDA_NUM_THREADS = 512;
+
 // For atomicAdd.
 USE_CUDA_ATOMIC(Add, float);
 
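PADDLE_CUDA_NUM_THREADS gives CUDA operators a shared default block size. The accuracy kernel above launches one block of that size, but the constant is also the natural choice for ordinary grid-sized launches. A hedged sketch of that pattern follows; SomeElementwiseKernel, n, in_data, and out_data are placeholders, not code from this commit.

// Hypothetical launch: derive the grid size from the element count n
// and the default block size declared in paddle/platform/cuda_helper.h.
using platform::PADDLE_CUDA_NUM_THREADS;
int grid = (n + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
SomeElementwiseKernel<<<grid, PADDLE_CUDA_NUM_THREADS>>>(n, in_data, out_data);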
test_accuracy_op.py
@@ -6,16 +6,17 @@ from op_test import OpTest
 class TestAccuracyOp(OpTest):
     def setUp(self):
         self.op_type = "accuracy"
-        infer = np.random.randint(0, 2, (32, 1)).astype("int")
-        label = np.random.randint(0, 2, (32, )).astype("int")
+        n = 8192
+        infer = np.random.randint(0, 2, (n, 1)).astype("int")
+        label = np.random.randint(0, 2, (n, )).astype("int")
         self.inputs = {'Inference': infer, "Label": label}
         num_correct = 0
-        for rowid in xrange(32):
+        for rowid in xrange(n):
             for ele in infer[rowid]:
                 if ele == label[rowid]:
                     num_correct += 1
                     break
-        self.outputs = {'Accuracy': [num_correct / 32.0]}
+        self.outputs = {'Accuracy': [num_correct / float(n)]}
 
     def test_check_output(self):
         self.check_output()
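A likely motivation for raising the batch size from 32 to 8192: with PADDLE_CUDA_NUM_THREADS = 512 threads in the single block, each thread of the CUDA kernel now processes 16 rows in its strided loop, so the test exercises both the per-thread accumulation and the cross-thread reduction instead of leaving most threads idle. The expected accuracy is computed the same way as before, just over n rows.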