From dd4802735cf5634004187d3ff403eeb773e9eeb0 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Thu, 2 Feb 2023 19:30:52 +0800 Subject: [PATCH] [CustomDevice] refine custom device api (#50152) --- paddle/phi/backends/custom/custom_device.cc | 108 +++++++++++--------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index c0e28a90e9f..164b425e393 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -148,38 +148,41 @@ class CustomDevice : public DeviceInterface { stream::Stream::Flag::kDefaultFlag) override { const auto device = &devices_pool[dev_id]; C_Stream c_stream; - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( - pimpl_->create_stream(device, &c_stream)); + if (pimpl_->create_stream) { + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->create_stream(device, &c_stream)); + } else { + c_stream = nullptr; + } stream->set_stream(c_stream); } void DestroyStream(size_t dev_id, stream::Stream* stream) override { - const auto device = &devices_pool[dev_id]; - - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream( - device, reinterpret_cast(stream->raw_stream()))); + if (pimpl_->destroy_stream) { + const auto device = &devices_pool[dev_id]; + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->destroy_stream( + device, reinterpret_cast(stream->raw_stream()))); + } } void SynchronizeStream(size_t dev_id, const stream::Stream* stream) override { - const auto device = &devices_pool[dev_id]; - - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream( - device, reinterpret_cast(stream->raw_stream()))); + if (pimpl_->synchronize_stream) { + const auto device = &devices_pool[dev_id]; + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->synchronize_stream( + device, reinterpret_cast(stream->raw_stream()))); + } } bool QueryStream(size_t dev_id, const stream::Stream* stream) override { - const auto device = &devices_pool[dev_id]; - if (!pimpl_->query_stream) { SynchronizeStream(dev_id, stream); return true; + } else { + const auto device = &devices_pool[dev_id]; + return pimpl_->query_stream( + device, reinterpret_cast(stream->raw_stream())) == + C_SUCCESS; } - if (pimpl_->query_stream( - device, reinterpret_cast(stream->raw_stream())) == - C_SUCCESS) { - return true; - } - return false; } void AddCallback(size_t dev_id, @@ -259,12 +262,14 @@ class CustomDevice : public DeviceInterface { void StreamWaitEvent(size_t dev_id, const stream::Stream* stream, const event::Event* event) override { - const auto device = &devices_pool[dev_id]; + if (pimpl_->stream_wait_event) { + const auto device = &devices_pool[dev_id]; - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event( - device, - reinterpret_cast(stream->raw_stream()), - reinterpret_cast(event->raw_event()))); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event( + device, + reinterpret_cast(stream->raw_stream()), + reinterpret_cast(event->raw_event()))); + } } void MemoryCopyH2D(size_t dev_id, @@ -279,7 +284,7 @@ class CustomDevice : public DeviceInterface { C_Stream c_stream = reinterpret_cast(stream->raw_stream()); PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->async_memory_copy_h2d(device, c_stream, dst, src, size)); - } else { + } else if (pimpl_->memory_copy_h2d) { paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool::Instance(); pool.Get(place)->Wait(); @@ -300,7 +305,7 @@ class CustomDevice : public DeviceInterface { C_Stream c_stream = reinterpret_cast(stream->raw_stream()); PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->async_memory_copy_d2h(device, c_stream, dst, src, size)); - } else { + } else if (pimpl_->memory_copy_d2h) { paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool::Instance(); pool.Get(place)->Wait(); @@ -321,7 +326,7 @@ class CustomDevice : public DeviceInterface { C_Stream c_stream = reinterpret_cast(stream->raw_stream()); PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->async_memory_copy_d2d(device, c_stream, dst, src, size)); - } else { + } else if (pimpl_->memory_copy_d2d) { paddle::platform::DeviceContextPool& pool = paddle::platform::DeviceContextPool::Instance(); pool.Get(place)->Wait(); @@ -455,24 +460,33 @@ class CustomDevice : public DeviceInterface { } void MemoryStats(size_t dev_id, size_t* total, size_t* free) override { - const auto device = &devices_pool[dev_id]; + if (pimpl_->device_memory_stats) { + const auto device = &devices_pool[dev_id]; - PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( - pimpl_->device_memory_stats(device, total, free)); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( + pimpl_->device_memory_stats(device, total, free)); - size_t used = *total - *free; - VLOG(10) << Type() << " memory usage " << (used >> 20) << "M/" - << (*total >> 20) << "M, " << (*free >> 20) - << "M available to allocate"; + size_t used = *total - *free; + VLOG(10) << Type() << " memory usage " << (used >> 20) << "M/" + << (*total >> 20) << "M, " << (*free >> 20) + << "M available to allocate"; + } else { + *total = 0; + *free = 0; + } } size_t GetMinChunkSize(size_t dev_id) override { - const auto device = &devices_pool[dev_id]; + if (pimpl_->device_min_chunk_size) { + const auto device = &devices_pool[dev_id]; - size_t size = 0; - pimpl_->device_min_chunk_size(device, &size); - VLOG(10) << Type() << " min chunk size " << size << "B"; - return size; + size_t size = 0; + pimpl_->device_min_chunk_size(device, &size); + VLOG(10) << Type() << " min chunk size " << size << "B"; + return size; + } else { + return 1; + } } size_t GetMaxChunkSize(size_t dev_id) override { @@ -911,8 +925,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { CHECK_INTERFACE(get_device, true); CHECK_INTERFACE(deinit_device, false); - CHECK_INTERFACE(create_stream, true); - CHECK_INTERFACE(destroy_stream, true); + CHECK_INTERFACE(create_stream, false); + CHECK_INTERFACE(destroy_stream, false); CHECK_INTERFACE(query_stream, false); CHECK_INTERFACE(stream_add_callback, false); @@ -922,9 +936,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { CHECK_INTERFACE(query_event, false); CHECK_INTERFACE(synchronize_device, false); - CHECK_INTERFACE(synchronize_stream, true); + CHECK_INTERFACE(synchronize_stream, false); CHECK_INTERFACE(synchronize_event, true); - CHECK_INTERFACE(stream_wait_event, true); + CHECK_INTERFACE(stream_wait_event, false); CHECK_INTERFACE(device_memory_allocate, true); CHECK_INTERFACE(device_memory_deallocate, true); @@ -932,9 +946,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { CHECK_INTERFACE(host_memory_deallocate, false); CHECK_INTERFACE(unified_memory_allocate, false); CHECK_INTERFACE(unified_memory_deallocate, false); - CHECK_INTERFACE(memory_copy_h2d, true); - CHECK_INTERFACE(memory_copy_d2h, true); - CHECK_INTERFACE(memory_copy_d2d, true); + CHECK_INTERFACE(memory_copy_h2d, false); + CHECK_INTERFACE(memory_copy_d2h, false); + CHECK_INTERFACE(memory_copy_d2d, false); CHECK_INTERFACE(memory_copy_p2p, false); CHECK_INTERFACE(async_memory_copy_h2d, false); CHECK_INTERFACE(async_memory_copy_d2h, false); @@ -943,9 +957,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) { CHECK_INTERFACE(get_device_count, true); CHECK_INTERFACE(get_device_list, true); - CHECK_INTERFACE(device_memory_stats, true); + CHECK_INTERFACE(device_memory_stats, false); - CHECK_INTERFACE(device_min_chunk_size, true); + CHECK_INTERFACE(device_min_chunk_size, false); CHECK_INTERFACE(device_max_chunk_size, false); CHECK_INTERFACE(device_max_alloc_size, false); CHECK_INTERFACE(device_extra_padding_size, false); -- GitLab