From 60fc555e2c33d9fdafde4145dfad9e789d89274c Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Tue, 22 Feb 2022 14:28:56 +0800 Subject: [PATCH] [CustomRuntime] fix CustomDeviceContext (#39766) --- paddle/fluid/platform/CMakeLists.txt | 4 +++- paddle/fluid/platform/device_context.cc | 16 ++++------------ paddle/fluid/platform/device_context.h | 14 ++------------ paddle/phi/common/backend.h | 3 --- paddle/phi/common/place.h | 1 + 5 files changed, 10 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index b808e1561b..478b71745e 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -141,7 +141,9 @@ if(WITH_GPU OR WITH_ROCM) target_link_libraries(device_context gpu_info gpu_context pten_gpu_info) target_link_libraries(device_context gpu_resource_pool) endif() - +if (WITH_CUSTOM_DEVICE) + target_link_libraries(device_context custom_context) +endif() if(WITH_ASCEND_CL) target_link_libraries(device_context npu_resource_pool) endif() diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 6452f6f798..e5e369efd6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -897,21 +897,13 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE -CustomDeviceContext::CustomDeviceContext(CustomPlace place) : place_(place) { - DeviceGuard guard(place_); - stream_.reset(new stream::Stream()); - stream_->Init(place_); +CustomDeviceContext::CustomDeviceContext(CustomPlace place) + : phi::CustomContext(place) { + Init(); + stream_.reset(new platform::stream::Stream(place, stream())); } CustomDeviceContext::~CustomDeviceContext() {} - -const Place& CustomDeviceContext::GetPlace() const { return place_; } - -void CustomDeviceContext::Wait() const { - // platform::RecordEvent record_event("NPUDeviceContext/wait"); - VLOG(4) << "CustomDevice context(" << this << ") Wait"; - stream_->Wait(); -} #endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 0101286f0d..17288b354a 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/core/device_context.h" @@ -819,17 +820,12 @@ class MKLDNNDeviceContext : public CPUDeviceContext { #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE -class CustomDeviceContext : public DeviceContext { +class CustomDeviceContext : public phi::CustomContext { public: explicit CustomDeviceContext(CustomPlace place); virtual ~CustomDeviceContext(); - const Place& GetPlace() const override; - void Wait() const override; Eigen::DefaultDevice* eigen_device() const { return nullptr; } - C_Stream stream() const { - return reinterpret_cast(stream_->raw_stream()); - } template void AddStreamCallback(Callback&& callback) const { @@ -839,13 +835,7 @@ class CustomDeviceContext : public DeviceContext { void WaitStreamCallback() const { return stream_->WaitCallback(); } private: - std::string device_type_; - - CustomPlace place_; - std::shared_ptr stream_; - - CustomDeviceContext(); }; template <> struct DefaultDeviceContextType { diff --git a/paddle/phi/common/backend.h b/paddle/phi/common/backend.h index 9a2ec09311..1d3e4369c6 100644 --- a/paddle/phi/common/backend.h +++ b/paddle/phi/common/backend.h @@ -135,9 +135,6 @@ inline Backend StringToBackend(const char* backend_cstr) { if (s == std::string("Undefined")) { return Backend::UNDEFINED; } - for (size_t i = 0; i < s.size(); ++i) { - s[i] = toupper(s[i]); - } if (s == std::string("CPU")) { return Backend::CPU; } else if (s == std::string("GPU")) { diff --git a/paddle/phi/common/place.h b/paddle/phi/common/place.h index b6adb1c293..36fb910cad 100644 --- a/paddle/phi/common/place.h +++ b/paddle/phi/common/place.h @@ -188,6 +188,7 @@ class MLUPlace : public Place { class CustomPlace : public Place { public: + CustomPlace() : Place(AllocationType::CUSTOM, 0, "") {} explicit CustomPlace(const std::string dev_type) : Place(AllocationType::CUSTOM, 0, dev_type) {} CustomPlace(const std::string dev_type, int device_id) -- GitLab