From 6308ccc265247974c9ab253948fbb7b90c77d087 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 8 Nov 2017 13:03:57 +0800 Subject: [PATCH] fix accuracy cudamemset --- paddle/operators/accuracy_op.cu | 4 +++- python/paddle/v2/framework/tests/test_accuracy_op.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index d0c4c0d25d..ccb2c06c22 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include "paddle/operators/accuracy_op.h" #include "paddle/platform/cuda_helper.h" @@ -65,7 +66,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { size_t num_samples = inference->dims()[0]; size_t infer_width = inference->dims()[1]; - cudaMemset((void**)&accuracy_data, 0, sizeof(float)); + cudaError_t e = cudaMemset(accuracy_data, 0, sizeof(float)); + PADDLE_ENFORCE_EQ(0, e, "cudaMemset error"); if (num_samples == 0) { return; diff --git a/python/paddle/v2/framework/tests/test_accuracy_op.py b/python/paddle/v2/framework/tests/test_accuracy_op.py index 85eabdcfb8..6536c297e8 100644 --- a/python/paddle/v2/framework/tests/test_accuracy_op.py +++ b/python/paddle/v2/framework/tests/test_accuracy_op.py @@ -26,5 +26,4 @@ class TestAccuracyOp(OpTest): if __name__ == '__main__': - exit(0) unittest.main() -- GitLab