未验证 提交 e97b8987 编写于 作者: D dzhwinter 提交者: GitHub

"fix accuracy kernel bug" (#5673)

* "fix accuracy kernel bug"

* "relauch ci"
上级 f95c291b
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <thrust/reduce.h>
#include "paddle/operators/accuracy_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace operators {
......@@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
int num_samples = static_cast<int>(inference->dims()[0]);
size_t infer_width = inference->dims()[1];
PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float)));
// cudaMemset((void**)&correct_data, 0, sizeof(float));
auto stream = ctx.cuda_device_context().stream();
platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream);
if (num_samples == 0) {
return;
}
cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice);
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
cudaMemcpyHostToDevice, stream);
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
AccuracyCudaKernel<
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_samples, infer_width, indices_data, label_data, correct_data,
accuracy_data);
int d_num_samples, d_num_correct;
float d_accuracy;
cudaMemcpy(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost);
cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost);
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost, stream);
}
};
......
......@@ -109,5 +109,10 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer");
}
void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream),
"cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
}
} // namespace platform
} // namespace paddle
......@@ -60,6 +60,9 @@ void GpuMemcpySync(void *dst, const void *src, size_t count,
void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
size_t count, cudaStream_t stream);
//! Set memory dst with value count size asynchronously
void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
} // namespace platform
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册