Unverified commit e97b8987 authored by dzhwinter, committed by GitHub

"fix accuracy kernel bug" (#5673)

* "fix accuracy kernel bug"

* "relauch ci"
Parent f95c291b
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <thrust/reduce.h>
 #include "paddle/operators/accuracy_op.h"
 #include "paddle/platform/cuda_helper.h"
+#include "paddle/platform/gpu_info.h"
 
 namespace paddle {
 namespace operators {
@@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
     int num_samples = static_cast<int>(inference->dims()[0]);
     size_t infer_width = inference->dims()[1];
-    PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
-    // cudaMemset((void**)&correct_data, 0, sizeof(float));
+    auto stream = ctx.cuda_device_context().stream();
+    platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);
 
     if (num_samples == 0) {
       return;
     }
-    cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
+    platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
+                             cudaMemcpyHostToDevice, stream);
 
-    AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
-        1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
+    AccuracyCudaKernel<
+        PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         num_samples, infer_width, indices_data, label_data, correct_data,
         accuracy_data);
 
     int d_num_samples, d_num_correct;
     float d_accuracy;
-    cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
-               cudaMemcpyDeviceToHost);
-    cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
-    cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
-               cudaMemcpyDeviceToHost);
+    platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
+                             cudaMemcpyDeviceToHost, stream);
+    platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
+                             cudaMemcpyDeviceToHost, stream);
+    platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
+                             cudaMemcpyDeviceToHost, stream);
   }
 };
......
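For orientation, below is a minimal standalone CUDA sketch, not Paddle code, of the pattern the patched kernel follows: the memset, the copies, and the kernel launch are all enqueued on a single stream, as in the new code above. The CountNonZero kernel and every name in the sketch are hypothetical stand-ins; the sketch adds an explicit cudaStreamSynchronize before the host reads the result, which the hunk above does not show for the operator itself.

// Minimal sketch of the "everything on one stream" pattern used above.
// Compile with: nvcc -o stream_demo stream_demo.cu
#include <cuda_runtime.h>
#include <cstdio>

__global__ void CountNonZero(const int* data, int n, int* correct) {
  // Hypothetical stand-in for AccuracyCudaKernel: count non-zero entries.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    if (data[i] != 0) atomicAdd(correct, 1);
  }
}

int main() {
  const int n = 8;
  int host_data[n] = {1, 0, 2, 0, 3, 0, 4, 5};
  int *dev_data = nullptr, *dev_correct = nullptr;
  cudaMalloc(&dev_data, n * sizeof(int));
  cudaMalloc(&dev_correct, sizeof(int));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Same ordering as the patched AccuracyOpCUDAKernel: zero the result,
  // copy inputs in, launch the kernel, and copy the result back, all on
  // the same stream (error checks omitted for brevity).
  cudaMemsetAsync(dev_correct, 0, sizeof(int), stream);
  cudaMemcpyAsync(dev_data, host_data, n * sizeof(int),
                  cudaMemcpyHostToDevice, stream);
  CountNonZero<<<1, 32, 0, stream>>>(dev_data, n, dev_correct);

  int correct = 0;
  cudaMemcpyAsync(&correct, dev_correct, sizeof(int), cudaMemcpyDeviceToHost,
                  stream);
  cudaStreamSynchronize(stream);  // wait before the host reads the result

  printf("non-zero elements: %d\n", correct);  // prints 5
  cudaFree(dev_data);
  cudaFree(dev_correct);
  cudaStreamDestroy(stream);
  return 0;
}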
@@ -109,5 +109,10 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
       cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
       "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer");
 }
+
+void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
+  PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream),
+                 "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
+}
 }  // namespace platform
 }  // namespace paddle
@@ -60,6 +60,9 @@ void GpuMemcpySync(void *dst, const void *src, size_t count,
 void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
                    size_t count, cudaStream_t stream);
 
+//! Set memory dst with value count size asynchronously
+void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
+
 }  // namespace platform
 }  // namespace paddle
......
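A hedged usage fragment (illustrative only, not part of the commit) showing how operator code such as the accuracy kernel above calls the new helper, assuming ctx is the usual framework::ExecutionContext and accuracy_data is a device float*:

  // Zero the one-float accumulator on the operator's stream; GpuMemsetAsync
  // wraps cudaMemsetAsync and enforces its return code (see gpu_info.cc above).
  auto stream = ctx.cuda_device_context().stream();
  platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);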