diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index b808e1561b24afdb7926db6e8e5c6514f84535b2..478b71745e4ac6dc737e66bd88717c38f4461010 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -141,7 +141,9 @@ if(WITH_GPU OR WITH_ROCM) target_link_libraries(device_context gpu_info gpu_context pten_gpu_info) target_link_libraries(device_context gpu_resource_pool) endif() - +if (WITH_CUSTOM_DEVICE) + target_link_libraries(device_context custom_context) +endif() if(WITH_ASCEND_CL) target_link_libraries(device_context npu_resource_pool) endif() diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 6452f6f7984e376ab686c7a417d2431af1045410..e5e369efd6bb428265922945163d4ec8d7e2ade6 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -897,21 +897,13 @@ MKLDNNDeviceContext::BlobPtr_t MKLDNNDeviceContext::GetBlob( #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE -CustomDeviceContext::CustomDeviceContext(CustomPlace place) : place_(place) { - DeviceGuard guard(place_); - stream_.reset(new stream::Stream()); - stream_->Init(place_); +CustomDeviceContext::CustomDeviceContext(CustomPlace place) + : phi::CustomContext(place) { + Init(); + stream_.reset(new platform::stream::Stream(place, stream())); } CustomDeviceContext::~CustomDeviceContext() {} - -const Place& CustomDeviceContext::GetPlace() const { return place_; } - -void CustomDeviceContext::Wait() const { - // platform::RecordEvent record_event("NPUDeviceContext/wait"); - VLOG(4) << "CustomDevice context(" << this << ") Wait"; - stream_->Wait(); -} #endif } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 0101286f0dfa87f3bc3b9ff0aae1e6f7342bace7..17288b354a2806ce56114335592b5103f9bf0ac4 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/core/device_context.h" @@ -819,17 +820,12 @@ class MKLDNNDeviceContext : public CPUDeviceContext { #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE -class CustomDeviceContext : public DeviceContext { +class CustomDeviceContext : public phi::CustomContext { public: explicit CustomDeviceContext(CustomPlace place); virtual ~CustomDeviceContext(); - const Place& GetPlace() const override; - void Wait() const override; Eigen::DefaultDevice* eigen_device() const { return nullptr; } - C_Stream stream() const { - return reinterpret_cast(stream_->raw_stream()); - } template void AddStreamCallback(Callback&& callback) const { @@ -839,13 +835,7 @@ class CustomDeviceContext : public DeviceContext { void WaitStreamCallback() const { return stream_->WaitCallback(); } private: - std::string device_type_; - - CustomPlace place_; - std::shared_ptr stream_; - - CustomDeviceContext(); }; template <> struct DefaultDeviceContextType { diff --git a/paddle/phi/common/backend.h b/paddle/phi/common/backend.h index 9a2ec093119fdbebfd2ea0eba0952b2236ab12e6..1d3e4369c69489fc13ec6938fbb9377e93765bb9 100644 --- a/paddle/phi/common/backend.h +++ b/paddle/phi/common/backend.h @@ -135,9 +135,6 @@ inline Backend StringToBackend(const char* backend_cstr) { if (s == std::string("Undefined")) { return Backend::UNDEFINED; } - for (size_t i = 0; i < s.size(); ++i) { - s[i] = toupper(s[i]); - } if (s == std::string("CPU")) { return Backend::CPU; } else if (s == std::string("GPU")) { diff --git a/paddle/phi/common/place.h b/paddle/phi/common/place.h index b6adb1c2932bff5842ef74947c149f23b8b79a02..36fb910cad6c705952a0e3858eb09810d1ea6f5f 100644 --- a/paddle/phi/common/place.h +++ b/paddle/phi/common/place.h @@ -188,6 +188,7 @@ class MLUPlace : public Place { class CustomPlace : public Place { public: + CustomPlace() : Place(AllocationType::CUSTOM, 0, "") {} explicit CustomPlace(const std::string dev_type) : Place(AllocationType::CUSTOM, 0, dev_type) {} CustomPlace(const std::string dev_type, int device_id)