未验证 提交 60fc555e 编写于 作者: R ronnywang 提交者: GitHub

[CustomRuntime] fix CustomDeviceContext (#39766)

上级 c5d15655
...@@ -141,7 +141,9 @@ if(WITH_GPU OR WITH_ROCM) ...@@ -141,7 +141,9 @@ if(WITH_GPU OR WITH_ROCM)
target_link_libraries(device_context gpu_info gpu_context pten_gpu_info) target_link_libraries(device_context gpu_info gpu_context pten_gpu_info)
target_link_libraries(device_context gpu_resource_pool) target_link_libraries(device_context gpu_resource_pool)
endif() endif()
if (WITH_CUSTOM_DEVICE)
target_link_libraries(device_context custom_context)
endif()
if(WITH_ASCEND_CL) if(WITH_ASCEND_CL)
target_link_libraries(device_context npu_resource_pool) target_link_libraries(device_context npu_resource_pool)
endif() endif()
......
...@@ -897,21 +897,13 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( ...@@ -897,21 +897,13 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place) : place_(place) { CustomDeviceContext::CustomDeviceContext(CustomPlace place)
DeviceGuard guard(place_); : phi::CustomContext(place) {
stream_.reset(new stream::Stream()); Init();
stream_->Init(place_); stream_.reset(new platform::stream::Stream(place, stream()));
} }
CustomDeviceContext::~CustomDeviceContext() {} CustomDeviceContext::~CustomDeviceContext() {}
const Place& CustomDeviceContext::GetPlace() const { return place_; }
void CustomDeviceContext::Wait() const {
// platform::RecordEvent record_event("NPUDeviceContext/wait");
VLOG(4) << "CustomDevice context(" << this << ") Wait";
stream_->Wait();
}
#endif #endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
...@@ -819,17 +820,12 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -819,17 +820,12 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
class CustomDeviceContext : public DeviceContext { class CustomDeviceContext : public phi::CustomContext {
public: public:
explicit CustomDeviceContext(CustomPlace place); explicit CustomDeviceContext(CustomPlace place);
virtual ~CustomDeviceContext(); virtual ~CustomDeviceContext();
const Place& GetPlace() const override;
void Wait() const override;
Eigen::DefaultDevice* eigen_device() const { return nullptr; } Eigen::DefaultDevice* eigen_device() const { return nullptr; }
C_Stream stream() const {
return reinterpret_cast<C_Stream>(stream_->raw_stream());
}
template <typename Callback> template <typename Callback>
void AddStreamCallback(Callback&& callback) const { void AddStreamCallback(Callback&& callback) const {
...@@ -839,13 +835,7 @@ class CustomDeviceContext : public DeviceContext { ...@@ -839,13 +835,7 @@ class CustomDeviceContext : public DeviceContext {
void WaitStreamCallback() const { return stream_->WaitCallback(); } void WaitStreamCallback() const { return stream_->WaitCallback(); }
private: private:
std::string device_type_;
CustomPlace place_;
std::shared_ptr<platform::stream::Stream> stream_; std::shared_ptr<platform::stream::Stream> stream_;
CustomDeviceContext();
}; };
template <> template <>
struct DefaultDeviceContextType<platform::CustomPlace> { struct DefaultDeviceContextType<platform::CustomPlace> {
......
...@@ -135,9 +135,6 @@ inline Backend StringToBackend(const char* backend_cstr) { ...@@ -135,9 +135,6 @@ inline Backend StringToBackend(const char* backend_cstr) {
if (s == std::string("Undefined")) { if (s == std::string("Undefined")) {
return Backend::UNDEFINED; return Backend::UNDEFINED;
} }
for (size_t i = 0; i < s.size(); ++i) {
s[i] = toupper(s[i]);
}
if (s == std::string("CPU")) { if (s == std::string("CPU")) {
return Backend::CPU; return Backend::CPU;
} else if (s == std::string("GPU")) { } else if (s == std::string("GPU")) {
......
...@@ -188,6 +188,7 @@ class MLUPlace : public Place { ...@@ -188,6 +188,7 @@ class MLUPlace : public Place {
class CustomPlace : public Place { class CustomPlace : public Place {
public: public:
CustomPlace() : Place(AllocationType::CUSTOM, 0, "") {}
explicit CustomPlace(const std::string dev_type) explicit CustomPlace(const std::string dev_type)
: Place(AllocationType::CUSTOM, 0, dev_type) {} : Place(AllocationType::CUSTOM, 0, dev_type) {}
CustomPlace(const std::string dev_type, int device_id) CustomPlace(const std::string dev_type, int device_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册