提交 6308ccc2 编写于 作者: T typhoonzero

fix accuracy cudamemset

上级 ee82e660
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <iostream>
#include "paddle/operators/accuracy_op.h"
#include "paddle/platform/cuda_helper.h"
......@@ -65,7 +66,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
size_t num_samples = inference->dims()[0];
size_t infer_width = inference->dims()[1];
cudaMemset((void**)&accuracy_data, 0, sizeof(float));
cudaError_t e = cudaMemset(accuracy_data, 0, sizeof(float));
PADDLE_ENFORCE_EQ(0, e, "cudaMemset error");
if (num_samples == 0) {
return;
......
......@@ -26,5 +26,4 @@ class TestAccuracyOp(OpTest):
if __name__ == '__main__':
exit(0)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册