From 903d5c7ec04e4074191875abaf58b418219561f2 Mon Sep 17 00:00:00 2001 From: He Date: Wed, 7 Sep 2016 16:01:36 +0800 Subject: [PATCH] bug fix for hl_matrix_classification_error --- paddle/cuda/src/hl_cuda_matrix.cu | 33 ++++++++++-------------- paddle/math/tests/test_matrixCompare.cpp | 33 +++++++++++++++++++++--- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 15799919fa1..fc003b7d637 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -266,25 +266,21 @@ template __global__ void KeMatrixClassificationError(real* in_A, int* in_B, real* out_C, - int dimM, int dimN) { __shared__ real max_s[blockSize]; __shared__ int max_l[blockSize]; - int cnt = (dimN + blockSize -1) / blockSize; - int tid = threadIdx.x; - int lmt = tid; - int index = 0; - real t; + const int tid = threadIdx.x; + const int rowId = blockIdx.x; max_s[tid] = -1e30f; - for (int ii = 0; ii < cnt && lmt < dimN; ii++) { - index = blockIdx.y*dimN + lmt; - t = in_A[index]; - if (max_s[tid] < t) { - max_s[tid] = t; - max_l[tid] = lmt; + in_A += rowId * dimN; + real tmp; + for (int colId = tid; colId < dimN; colId += blockSize) { + tmp = in_A[colId]; + if (max_s[tid] < tmp) { + max_s[tid] = tmp; + max_l[tid] = colId; } - lmt += blockSize; } __syncthreads(); @@ -300,7 +296,7 @@ __global__ void KeMatrixClassificationError(real* in_A, __syncthreads(); if (tid == 0) { - out_C[blockIdx.y] = (max_l[0] == in_B[blockIdx.y] ? 0 : 1.0f); + out_C[rowId] = (max_l[0] == in_B[rowId] ? 0 : 1.0f); } } @@ -313,12 +309,9 @@ void hl_matrix_classification_error(real* A_d, CHECK_NOTNULL(B_d); CHECK_NOTNULL(C_d); - int blocksX = 1; - int blocksY = dimM; - dim3 threads(1024, 1); - dim3 grid(blocksX, blocksY); - KeMatrixClassificationError<1024><<< grid, threads, 0, STREAM_DEFAULT >>> - (A_d, B_d, C_d, dimM, dimN); + // each sample is calculated by one block + KeMatrixClassificationError<1024><<< dimM, 1024, 0, STREAM_DEFAULT >>> + (A_d, B_d, C_d, dimN); CHECK_SYNC("hl_matrix_classification_error"); } diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 7caade444b8..fe8eacc2efb 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1697,7 +1697,6 @@ TEST(Matrix, cosSimDerivate) { } } - void testParamReluForward(int height, int width, int w_height, int w_width) { MatrixPtr output = CpuMatrix::create(height, width, false, false); @@ -1736,7 +1735,6 @@ TEST(Matrix, paramReluForward) { } } - void testParamReluBackwardW(int height, int width, int w_height, int w_width) { MatrixPtr oGrad = CpuMatrix::create(height, width, false, false); @@ -1775,7 +1773,6 @@ TEST(Matrix, paramReluBackwardW) { } } - void testParamReluBackwardDiff(int height, int width, int w_height, int w_width) { MatrixPtr oGrad = CpuMatrix::create(height, width, false, false); @@ -1819,6 +1816,36 @@ TEST(Matrix, paramReluBackwardDiff) { } } +void testClassificationError(int numSamples, int dim) { + MatrixPtr cpuError = std::make_shared(numSamples, 1); + MatrixPtr gpuError = std::make_shared(numSamples, 1); + MatrixPtr cpuOutput = std::make_shared(numSamples, dim); + MatrixPtr gpuOutput = std::make_shared(numSamples, dim); + IVectorPtr cpuLabel = std::make_shared(numSamples); + IVectorPtr gpuLabel = std::make_shared(numSamples); + + cpuOutput->randomizeUniform(); + cpuLabel->rand(dim); + gpuOutput->copyFrom(*cpuOutput); + gpuLabel->copyFrom(*cpuLabel); + + cpuError->classificationError(cpuOutput, cpuLabel); + gpuError->classificationError(gpuOutput, gpuLabel); + + MatrixPtr check = std::make_shared(numSamples, 1); + check->copyFrom(*gpuError); + MatrixCheckEqual(*cpuError, *check); +} + +TEST(Matrix, classificationError) { + for (auto numSamples : {1, 10, 100, 1000, 70000}) { + for (auto dim : {1, 10, 100, 1000}) { + VLOG(3) << " numSamples=" << numSamples << " dim=" << dim; + testClassificationError(numSamples, dim); + } + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); -- GitLab