未验证 提交 dd480273 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] refine custom device api (#50152)

上级 d8643cb6
...@@ -148,38 +148,41 @@ class CustomDevice : public DeviceInterface { ...@@ -148,38 +148,41 @@ class CustomDevice : public DeviceInterface {
stream::Stream::Flag::kDefaultFlag) override { stream::Stream::Flag::kDefaultFlag) override {
const auto device = &devices_pool[dev_id]; const auto device = &devices_pool[dev_id];
C_Stream c_stream; C_Stream c_stream;
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( if (pimpl_->create_stream) {
pimpl_->create_stream(device, &c_stream)); PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->create_stream(device, &c_stream));
} else {
c_stream = nullptr;
}
stream->set_stream(c_stream); stream->set_stream(c_stream);
} }
void DestroyStream(size_t dev_id, stream::Stream* stream) override { void DestroyStream(size_t dev_id, stream::Stream* stream) override {
const auto device = &devices_pool[dev_id]; if (pimpl_->destroy_stream) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream()))); device, reinterpret_cast<C_Stream>(stream->raw_stream())));
}
} }
void SynchronizeStream(size_t dev_id, const stream::Stream* stream) override { void SynchronizeStream(size_t dev_id, const stream::Stream* stream) override {
const auto device = &devices_pool[dev_id]; if (pimpl_->synchronize_stream) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream()))); device, reinterpret_cast<C_Stream>(stream->raw_stream())));
}
} }
bool QueryStream(size_t dev_id, const stream::Stream* stream) override { bool QueryStream(size_t dev_id, const stream::Stream* stream) override {
const auto device = &devices_pool[dev_id];
if (!pimpl_->query_stream) { if (!pimpl_->query_stream) {
SynchronizeStream(dev_id, stream); SynchronizeStream(dev_id, stream);
return true; return true;
} else {
const auto device = &devices_pool[dev_id];
return pimpl_->query_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())) ==
C_SUCCESS;
} }
if (pimpl_->query_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())) ==
C_SUCCESS) {
return true;
}
return false;
} }
void AddCallback(size_t dev_id, void AddCallback(size_t dev_id,
...@@ -259,12 +262,14 @@ class CustomDevice : public DeviceInterface { ...@@ -259,12 +262,14 @@ class CustomDevice : public DeviceInterface {
void StreamWaitEvent(size_t dev_id, void StreamWaitEvent(size_t dev_id,
const stream::Stream* stream, const stream::Stream* stream,
const event::Event* event) override { const event::Event* event) override {
const auto device = &devices_pool[dev_id]; if (pimpl_->stream_wait_event) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event(
device, device,
reinterpret_cast<C_Stream>(stream->raw_stream()), reinterpret_cast<C_Stream>(stream->raw_stream()),
reinterpret_cast<C_Event>(event->raw_event()))); reinterpret_cast<C_Event>(event->raw_event())));
}
} }
void MemoryCopyH2D(size_t dev_id, void MemoryCopyH2D(size_t dev_id,
...@@ -279,7 +284,7 @@ class CustomDevice : public DeviceInterface { ...@@ -279,7 +284,7 @@ class CustomDevice : public DeviceInterface {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream()); C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_h2d(device, c_stream, dst, src, size)); pimpl_->async_memory_copy_h2d(device, c_stream, dst, src, size));
} else { } else if (pimpl_->memory_copy_h2d) {
paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance(); paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait(); pool.Get(place)->Wait();
...@@ -300,7 +305,7 @@ class CustomDevice : public DeviceInterface { ...@@ -300,7 +305,7 @@ class CustomDevice : public DeviceInterface {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream()); C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_d2h(device, c_stream, dst, src, size)); pimpl_->async_memory_copy_d2h(device, c_stream, dst, src, size));
} else { } else if (pimpl_->memory_copy_d2h) {
paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance(); paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait(); pool.Get(place)->Wait();
...@@ -321,7 +326,7 @@ class CustomDevice : public DeviceInterface { ...@@ -321,7 +326,7 @@ class CustomDevice : public DeviceInterface {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream()); C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_d2d(device, c_stream, dst, src, size)); pimpl_->async_memory_copy_d2d(device, c_stream, dst, src, size));
} else { } else if (pimpl_->memory_copy_d2d) {
paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance(); paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait(); pool.Get(place)->Wait();
...@@ -455,24 +460,33 @@ class CustomDevice : public DeviceInterface { ...@@ -455,24 +460,33 @@ class CustomDevice : public DeviceInterface {
} }
void MemoryStats(size_t dev_id, size_t* total, size_t* free) override { void MemoryStats(size_t dev_id, size_t* total, size_t* free) override {
const auto device = &devices_pool[dev_id]; if (pimpl_->device_memory_stats) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->device_memory_stats(device, total, free)); pimpl_->device_memory_stats(device, total, free));
size_t used = *total - *free; size_t used = *total - *free;
VLOG(10) << Type() << " memory usage " << (used >> 20) << "M/" VLOG(10) << Type() << " memory usage " << (used >> 20) << "M/"
<< (*total >> 20) << "M, " << (*free >> 20) << (*total >> 20) << "M, " << (*free >> 20)
<< "M available to allocate"; << "M available to allocate";
} else {
*total = 0;
*free = 0;
}
} }
size_t GetMinChunkSize(size_t dev_id) override { size_t GetMinChunkSize(size_t dev_id) override {
const auto device = &devices_pool[dev_id]; if (pimpl_->device_min_chunk_size) {
const auto device = &devices_pool[dev_id];
size_t size = 0; size_t size = 0;
pimpl_->device_min_chunk_size(device, &size); pimpl_->device_min_chunk_size(device, &size);
VLOG(10) << Type() << " min chunk size " << size << "B"; VLOG(10) << Type() << " min chunk size " << size << "B";
return size; return size;
} else {
return 1;
}
} }
size_t GetMaxChunkSize(size_t dev_id) override { size_t GetMaxChunkSize(size_t dev_id) override {
...@@ -911,8 +925,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { ...@@ -911,8 +925,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(get_device, true); CHECK_INTERFACE(get_device, true);
CHECK_INTERFACE(deinit_device, false); CHECK_INTERFACE(deinit_device, false);
CHECK_INTERFACE(create_stream, true); CHECK_INTERFACE(create_stream, false);
CHECK_INTERFACE(destroy_stream, true); CHECK_INTERFACE(destroy_stream, false);
CHECK_INTERFACE(query_stream, false); CHECK_INTERFACE(query_stream, false);
CHECK_INTERFACE(stream_add_callback, false); CHECK_INTERFACE(stream_add_callback, false);
...@@ -922,9 +936,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { ...@@ -922,9 +936,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(query_event, false); CHECK_INTERFACE(query_event, false);
CHECK_INTERFACE(synchronize_device, false); CHECK_INTERFACE(synchronize_device, false);
CHECK_INTERFACE(synchronize_stream, true); CHECK_INTERFACE(synchronize_stream, false);
CHECK_INTERFACE(synchronize_event, true); CHECK_INTERFACE(synchronize_event, true);
CHECK_INTERFACE(stream_wait_event, true); CHECK_INTERFACE(stream_wait_event, false);
CHECK_INTERFACE(device_memory_allocate, true); CHECK_INTERFACE(device_memory_allocate, true);
CHECK_INTERFACE(device_memory_deallocate, true); CHECK_INTERFACE(device_memory_deallocate, true);
...@@ -932,9 +946,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { ...@@ -932,9 +946,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(host_memory_deallocate, false); CHECK_INTERFACE(host_memory_deallocate, false);
CHECK_INTERFACE(unified_memory_allocate, false); CHECK_INTERFACE(unified_memory_allocate, false);
CHECK_INTERFACE(unified_memory_deallocate, false); CHECK_INTERFACE(unified_memory_deallocate, false);
CHECK_INTERFACE(memory_copy_h2d, true); CHECK_INTERFACE(memory_copy_h2d, false);
CHECK_INTERFACE(memory_copy_d2h, true); CHECK_INTERFACE(memory_copy_d2h, false);
CHECK_INTERFACE(memory_copy_d2d, true); CHECK_INTERFACE(memory_copy_d2d, false);
CHECK_INTERFACE(memory_copy_p2p, false); CHECK_INTERFACE(memory_copy_p2p, false);
CHECK_INTERFACE(async_memory_copy_h2d, false); CHECK_INTERFACE(async_memory_copy_h2d, false);
CHECK_INTERFACE(async_memory_copy_d2h, false); CHECK_INTERFACE(async_memory_copy_d2h, false);
...@@ -943,9 +957,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { ...@@ -943,9 +957,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(get_device_count, true); CHECK_INTERFACE(get_device_count, true);
CHECK_INTERFACE(get_device_list, true); CHECK_INTERFACE(get_device_list, true);
CHECK_INTERFACE(device_memory_stats, true); CHECK_INTERFACE(device_memory_stats, false);
CHECK_INTERFACE(device_min_chunk_size, true); CHECK_INTERFACE(device_min_chunk_size, false);
CHECK_INTERFACE(device_max_chunk_size, false); CHECK_INTERFACE(device_max_chunk_size, false);
CHECK_INTERFACE(device_max_alloc_size, false); CHECK_INTERFACE(device_max_alloc_size, false);
CHECK_INTERFACE(device_extra_padding_size, false); CHECK_INTERFACE(device_extra_padding_size, false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册