diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 4e6d1ef9654012ce6355cbd7561c4fdc1785c11a..0a6a0fd15c73330902552f7a9aa6339de24c1a18 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -12,26 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include +#include #include "paddle/operators/accuracy_op.h" +#include "paddle/platform/cuda_helper.h" namespace paddle { namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; -__global__ void AccuracySingleKernel(const int N, const int D, const int top_k, - const int* Xdata, const int* labelData, - float* accuracy) { - int correct = 0; - for (int row = 0; row < N; row++) { - const int label = labelData[row]; - for (int col = 0; col < D; col++) { - const int pred = Xdata[row * D + col]; - if (pred == label) { - ++correct; +template +__global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata, + const int* labeldata, float* accuracy) { + int count = 0; + __shared__ int total[BlockSize]; + + // support only 1 block + for (int i = threadIdx.x; i < (N); i += BlockSize) { + for (int j = 0; j < D; ++j) { + if (Xdata[i * D + j] == labeldata[i]) { + ++count; break; } } } - *accuracy = static_cast(correct) / static_cast(N); + total[threadIdx.x] = count; + __syncthreads(); + + // reduce the count with init value 0, and output accuracy. + int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); + if (threadIdx.x == 0) { + *accuracy = static_cast(result) / static_cast(N); + } } template @@ -57,8 +69,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { return; } - AccuracySingleKernel<<<1, 1>>>(num_samples, infer_width, 1, inference_data, - label_data, accuracy_data); + AccuracyCudaKernel<<<1, PADDLE_CUDA_NUM_THREADS>>>( + num_samples, infer_width, inference_data, label_data, accuracy_data); } }; diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h index 6feec0d7f8bd5d32d9e5eedee962fcbeff655f1c..a7d99cde106a0a66f122a8c43f49717c03e60dec 100644 --- a/paddle/platform/cuda_helper.h +++ b/paddle/platform/cuda_helper.h @@ -24,6 +24,11 @@ namespace platform { #define USE_CUDA_ATOMIC(op, T) \ CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); } +// Default thread count per block(or block size). +// TODO(typhoonzero): need to benchmark against setting this value +// to 1024. +constexpr int PADDLE_CUDA_NUM_THREADS = 512; + // For atomicAdd. USE_CUDA_ATOMIC(Add, float); diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py index 43d60eb90d5edbd6944a11f7555f0291720dd2be..b6f3a35d6f58ba90b39e3f6296ae635220a2e965 100644 --- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -6,16 +6,17 @@ from op_test import OpTest class TestAccuracyOp(OpTest): def setUp(self): self.op_type = "accuracy" - infer = np.random.randint(0, 2, (32, 1)).astype("int") - label = np.random.randint(0, 2, (32, )).astype("int") + n = 8192 + infer = np.random.randint(0, 2, (n, 1)).astype("int") + label = np.random.randint(0, 2, (n, )).astype("int") self.inputs = {'Inference': infer, "Label": label} num_correct = 0 - for rowid in xrange(32): + for rowid in xrange(n): for ele in infer[rowid]: if ele == label[rowid]: num_correct += 1 break - self.outputs = {'Accuracy': [num_correct / 32.0]} + self.outputs = {'Accuracy': [num_correct / float(n)]} def test_check_output(self): self.check_output()