未验证 提交 14dba636 编写于 作者: Q Qi Li 提交者: GitHub

[ROCm] fix dcu error in device event base, test=develop (#41521)

* [ROCm] fix dcu error in device event base, test=develop

* fix, test=develop
上级 770ce7cf
......@@ -29,7 +29,7 @@ using ::paddle::platform::kCPU;
USE_EVENT(kCPU)
USE_EVENT_WAIT(kCPU, kCPU)
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
USE_EVENT(kCUDA);
USE_EVENT_WAIT(kCUDA, kCUDA)
USE_EVENT_WAIT(kCPU, kCUDA)
......
......@@ -15,7 +15,7 @@
#include "paddle/fluid/platform/device_event_base.h"
#include "paddle/fluid/platform/event.h"
#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
namespace paddle {
namespace platform {
struct CUDADeviceEventWrapper {
......
......@@ -75,6 +75,58 @@ TEST(DeviceEvent, CUDA) {
}
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
TEST(DeviceEvent, CUDA) {
VLOG(1) << "In Test";
using paddle::platform::CUDAPlace;
auto& pool = DeviceContextPool::Instance();
auto place = CUDAPlace(0);
auto* context =
static_cast<paddle::platform::CUDADeviceContext*>(pool.Get(place));
ASSERT_NE(context, nullptr);
// case 1. test for event_creator
DeviceEvent event(place);
ASSERT_NE(event.GetEvent().get(), nullptr);
bool status = event.Query();
ASSERT_EQ(status, true);
// case 2. test for event_recorder
event.Record(context);
status = event.Query();
ASSERT_EQ(status, false);
// case 3. test for event_finisher
event.Finish();
status = event.Query();
ASSERT_EQ(status, true);
// case 4. test for event_waiter
float *src_fp32, *dst_fp32;
int size = 1000000 * sizeof(float);
hipMallocHost(reinterpret_cast<void**>(&src_fp32), size);
hipMalloc(reinterpret_cast<void**>(&dst_fp32), size);
hipMemcpyAsync(dst_fp32, src_fp32, size, hipMemcpyHostToDevice,
context->stream());
event.Record(context); // step 1. record it
status = event.Query();
ASSERT_EQ(status, false);
event.Wait(kCUDA, context); // step 2. add streamWaitEvent
status = event.Query();
ASSERT_EQ(status, false); // async
event.Wait(kCPU, context); // step 3. EventSynchornize
status = event.Query();
ASSERT_EQ(status, true); // sync
// release resource
hipFree(dst_fp32);
hipFreeHost(src_fp32);
}
#endif
TEST(DeviceEvent, CPU) {
using paddle::platform::CPUPlace;
auto place = CPUPlace();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册