未验证 提交 dd480273 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] refine custom device api (#50152)

上级 d8643cb6
......@@ -148,38 +148,41 @@ class CustomDevice : public DeviceInterface {
stream::Stream::Flag::kDefaultFlag) override {
const auto device = &devices_pool[dev_id];
C_Stream c_stream;
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->create_stream(device, &c_stream));
if (pimpl_->create_stream) {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->create_stream(device, &c_stream));
} else {
c_stream = nullptr;
}
stream->set_stream(c_stream);
}
void DestroyStream(size_t dev_id, stream::Stream* stream) override {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())));
if (pimpl_->destroy_stream) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())));
}
}
void SynchronizeStream(size_t dev_id, const stream::Stream* stream) override {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())));
if (pimpl_->synchronize_stream) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())));
}
}
bool QueryStream(size_t dev_id, const stream::Stream* stream) override {
const auto device = &devices_pool[dev_id];
if (!pimpl_->query_stream) {
SynchronizeStream(dev_id, stream);
return true;
} else {
const auto device = &devices_pool[dev_id];
return pimpl_->query_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())) ==
C_SUCCESS;
}
if (pimpl_->query_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())) ==
C_SUCCESS) {
return true;
}
return false;
}
void AddCallback(size_t dev_id,
......@@ -259,12 +262,14 @@ class CustomDevice : public DeviceInterface {
void StreamWaitEvent(size_t dev_id,
const stream::Stream* stream,
const event::Event* event) override {
const auto device = &devices_pool[dev_id];
if (pimpl_->stream_wait_event) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event(
device,
reinterpret_cast<C_Stream>(stream->raw_stream()),
reinterpret_cast<C_Event>(event->raw_event())));
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event(
device,
reinterpret_cast<C_Stream>(stream->raw_stream()),
reinterpret_cast<C_Event>(event->raw_event())));
}
}
void MemoryCopyH2D(size_t dev_id,
......@@ -279,7 +284,7 @@ class CustomDevice : public DeviceInterface {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_h2d(device, c_stream, dst, src, size));
} else {
} else if (pimpl_->memory_copy_h2d) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait();
......@@ -300,7 +305,7 @@ class CustomDevice : public DeviceInterface {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_d2h(device, c_stream, dst, src, size));
} else {
} else if (pimpl_->memory_copy_d2h) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait();
......@@ -321,7 +326,7 @@ class CustomDevice : public DeviceInterface {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_d2d(device, c_stream, dst, src, size));
} else {
} else if (pimpl_->memory_copy_d2d) {
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait();
......@@ -455,24 +460,33 @@ class CustomDevice : public DeviceInterface {
}
void MemoryStats(size_t dev_id, size_t* total, size_t* free) override {
const auto device = &devices_pool[dev_id];
if (pimpl_->device_memory_stats) {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->device_memory_stats(device, total, free));
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->device_memory_stats(device, total, free));
size_t used = *total - *free;
VLOG(10) << Type() << " memory usage " << (used >> 20) << "M/"
<< (*total >> 20) << "M, " << (*free >> 20)
<< "M available to allocate";
size_t used = *total - *free;
VLOG(10) << Type() << " memory usage " << (used >> 20) << "M/"
<< (*total >> 20) << "M, " << (*free >> 20)
<< "M available to allocate";
} else {
*total = 0;
*free = 0;
}
}
size_t GetMinChunkSize(size_t dev_id) override {
const auto device = &devices_pool[dev_id];
if (pimpl_->device_min_chunk_size) {
const auto device = &devices_pool[dev_id];
size_t size = 0;
pimpl_->device_min_chunk_size(device, &size);
VLOG(10) << Type() << " min chunk size " << size << "B";
return size;
size_t size = 0;
pimpl_->device_min_chunk_size(device, &size);
VLOG(10) << Type() << " min chunk size " << size << "B";
return size;
} else {
return 1;
}
}
size_t GetMaxChunkSize(size_t dev_id) override {
......@@ -911,8 +925,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(get_device, true);
CHECK_INTERFACE(deinit_device, false);
CHECK_INTERFACE(create_stream, true);
CHECK_INTERFACE(destroy_stream, true);
CHECK_INTERFACE(create_stream, false);
CHECK_INTERFACE(destroy_stream, false);
CHECK_INTERFACE(query_stream, false);
CHECK_INTERFACE(stream_add_callback, false);
......@@ -922,9 +936,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(query_event, false);
CHECK_INTERFACE(synchronize_device, false);
CHECK_INTERFACE(synchronize_stream, true);
CHECK_INTERFACE(synchronize_stream, false);
CHECK_INTERFACE(synchronize_event, true);
CHECK_INTERFACE(stream_wait_event, true);
CHECK_INTERFACE(stream_wait_event, false);
CHECK_INTERFACE(device_memory_allocate, true);
CHECK_INTERFACE(device_memory_deallocate, true);
......@@ -932,9 +946,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(host_memory_deallocate, false);
CHECK_INTERFACE(unified_memory_allocate, false);
CHECK_INTERFACE(unified_memory_deallocate, false);
CHECK_INTERFACE(memory_copy_h2d, true);
CHECK_INTERFACE(memory_copy_d2h, true);
CHECK_INTERFACE(memory_copy_d2d, true);
CHECK_INTERFACE(memory_copy_h2d, false);
CHECK_INTERFACE(memory_copy_d2h, false);
CHECK_INTERFACE(memory_copy_d2d, false);
CHECK_INTERFACE(memory_copy_p2p, false);
CHECK_INTERFACE(async_memory_copy_h2d, false);
CHECK_INTERFACE(async_memory_copy_d2h, false);
......@@ -943,9 +957,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE(get_device_count, true);
CHECK_INTERFACE(get_device_list, true);
CHECK_INTERFACE(device_memory_stats, true);
CHECK_INTERFACE(device_memory_stats, false);
CHECK_INTERFACE(device_min_chunk_size, true);
CHECK_INTERFACE(device_min_chunk_size, false);
CHECK_INTERFACE(device_max_chunk_size, false);
CHECK_INTERFACE(device_max_alloc_size, false);
CHECK_INTERFACE(device_extra_padding_size, false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册