Unverified commit 60fc555e, authored by ronnywang, committed by GitHub

[CustomRuntime] fix CustomDeviceContext (#39766)

Parent commit: c5d15655
......@@ -141,7 +141,9 @@ if(WITH_GPU OR WITH_ROCM)
target_link_libraries(device_context gpu_info gpu_context pten_gpu_info)
target_link_libraries(device_context gpu_resource_pool)
endif()
# Link the CustomRuntime context library into device_context so the
# plug-in device backend resolves its symbols.
if(WITH_CUSTOM_DEVICE)
  target_link_libraries(device_context custom_context)
endif()
if(WITH_ASCEND_CL)
target_link_libraries(device_context npu_resource_pool)
endif()
......
......@@ -897,21 +897,13 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place) : place_(place) {
DeviceGuard guard(place_);
stream_.reset(new stream::Stream());
stream_->Init(place_);
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
: phi::CustomContext(place) {
Init();
stream_.reset(new platform::stream::Stream(place, stream()));
}
// Out-of-line empty destructor; stream_ is a shared_ptr, so the underlying
// stream is released automatically when the last owner goes away.
CustomDeviceContext::~CustomDeviceContext() {}
// Accessor for the placement (device) this context was constructed with.
const Place& CustomDeviceContext::GetPlace() const {
  return place_;
}
// Blocks the calling thread until all work previously enqueued on this
// context's stream has completed (delegates to stream_->Wait()).
void CustomDeviceContext::Wait() const {
// platform::RecordEvent record_event("NPUDeviceContext/wait");
VLOG(4) << "CustomDevice context(" << this << ") Wait";
stream_->Wait();
}
#endif
} // namespace platform
} // namespace paddle
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/core/device_context.h"
......@@ -819,17 +820,12 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
class CustomDeviceContext : public DeviceContext {
class CustomDeviceContext : public phi::CustomContext {
public:
explicit CustomDeviceContext(CustomPlace place);
virtual ~CustomDeviceContext();
const Place& GetPlace() const override;
void Wait() const override;
Eigen::DefaultDevice* eigen_device() const { return nullptr; }
C_Stream stream() const {
return reinterpret_cast<C_Stream>(stream_->raw_stream());
}
template <typename Callback>
void AddStreamCallback(Callback&& callback) const {
......@@ -839,13 +835,7 @@ class CustomDeviceContext : public DeviceContext {
void WaitStreamCallback() const { return stream_->WaitCallback(); }
private:
std::string device_type_;
CustomPlace place_;
std::shared_ptr<platform::stream::Stream> stream_;
CustomDeviceContext();
};
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
......
......@@ -135,9 +135,6 @@ inline Backend StringToBackend(const char* backend_cstr) {
if (s == std::string("Undefined")) {
return Backend::UNDEFINED;
}
for (size_t i = 0; i < s.size(); ++i) {
s[i] = toupper(s[i]);
}
if (s == std::string("CPU")) {
return Backend::CPU;
} else if (s == std::string("GPU")) {
......
......@@ -188,6 +188,7 @@ class MLUPlace : public Place {
class CustomPlace : public Place {
public:
CustomPlace() : Place(AllocationType::CUSTOM, 0, "") {}
explicit CustomPlace(const std::string dev_type)
: Place(AllocationType::CUSTOM, 0, dev_type) {}
CustomPlace(const std::string dev_type, int device_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please sign up.