未验证 提交 d6d0820e 编写于 作者: R ronnywang 提交者: GitHub

[CustomRuntime] add pten::Backend support (#39606)

上级 46161679
...@@ -237,6 +237,17 @@ static void RunKernelFunc(pten::KernelContext* ctx, ...@@ -237,6 +237,17 @@ static void RunKernelFunc(pten::KernelContext* ctx,
if (backend == pten::Backend::CPU) { if (backend == pten::Backend::CPU) {
// do nothing // do nothing
} else { } else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(pten::Backend::ALL_BACKEND);
std::string device_type = pten::GetGlobalDeviceType(device_type_id_);
if (!device_type.empty()) {
auto custom_ctx =
ctx->GetDeviceContext<paddle::platform::CustomDeviceContext>();
dev_ctx.set_stream(custom_ctx.stream());
return;
}
#endif
LOG(ERROR) << "[CUSTOM KERNEL] Unsupported kernel backend: " << backend LOG(ERROR) << "[CUSTOM KERNEL] Unsupported kernel backend: " << backend
<< " with compiled Paddle."; << " with compiled Paddle.";
return; return;
......
...@@ -846,7 +846,6 @@ class CustomDeviceContext : public DeviceContext { ...@@ -846,7 +846,6 @@ class CustomDeviceContext : public DeviceContext {
std::shared_ptr<platform::stream::Stream> stream_; std::shared_ptr<platform::stream::Stream> stream_;
CustomDeviceContext(); CustomDeviceContext();
DISABLE_COPY_AND_ASSIGN(CustomDeviceContext);
}; };
template <> template <>
struct DefaultDeviceContextType<platform::CustomPlace> { struct DefaultDeviceContextType<platform::CustomPlace> {
......
...@@ -2,6 +2,8 @@ add_subdirectory(dynload) ...@@ -2,6 +2,8 @@ add_subdirectory(dynload)
add_subdirectory(cpu) add_subdirectory(cpu)
add_subdirectory(custom)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
add_subdirectory(gpu) add_subdirectory(gpu)
endif() endif()
......
if (WITH_CUSTOM_DEVICE)
cc_library(custom_context SRCS custom_context.cc DEPS pten_device_context device_manager)
endif()
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/backends/custom/custom_context.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/fluid/platform/device/stream.h"
namespace pten {
struct CustomContext::Impl {
explicit Impl(const CustomPlace& place) : place_(place) {}
~Impl() {}
void Init() {
paddle::platform::DeviceGuard guard(place_);
stream_.reset(new paddle::platform::stream::Stream());
stream_->Init(place_);
}
const Place& GetPlace() const { return place_; }
C_Stream stream() const {
return reinterpret_cast<C_Stream>(stream_->raw_stream());
}
void Wait() const { stream_->Wait(); }
Place place_;
std::shared_ptr<paddle::platform::stream::Stream> stream_;
};
void CustomContext::Init() { impl_->Init(); }
const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); }
C_Stream CustomContext::stream() const { return impl_->stream(); }
void CustomContext::Wait() const { return impl_->Wait(); }
CustomContext::CustomContext(const CustomPlace& place)
: DeviceContext(), impl_(std::make_unique<Impl>(place)) {}
CustomContext::~CustomContext() {}
} // namespace pten
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/platform/device/device_ext.h"
#include "paddle/pten/common/place.h"
#include "paddle/pten/core/device_context.h"
namespace pten {
class CustomContext : public DeviceContext {
public:
explicit CustomContext(const CustomPlace&);
virtual ~CustomContext();
const Place& GetPlace() const override;
/*! \brief Return stream in the device context. */
C_Stream stream() const;
// Wait for all operations completion in the stream.
void Wait() const override;
public:
// NOTE: DeviceContext hold resources. Used in training scenarios.
// The interface used by the training scene, DeviceContext will initialize
// all resources and delete them when destructing.
void Init();
private:
CustomContext();
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace pten
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <ostream> #include <ostream>
#include "paddle/pten/api/ext/exception.h" #include "paddle/pten/api/ext/exception.h"
#include "paddle/pten/common/place.h"
namespace paddle { namespace paddle {
namespace experimental { namespace experimental {
...@@ -114,8 +115,17 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { ...@@ -114,8 +115,17 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
case Backend::CUDNN: case Backend::CUDNN:
os << "CUDNN"; os << "CUDNN";
break; break;
default: default: {
PD_THROW("Invalid enum backend type `", static_cast<int>(backend), "`."); size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS);
std::string device_type = pten::GetGlobalDeviceType(device_type_id_);
if (!device_type.empty()) {
os << device_type;
} else {
PD_THROW(
"Invalid enum backend type `", static_cast<int>(backend), "`.");
}
}
} }
return os; return os;
} }
......
...@@ -86,7 +86,9 @@ size_t GetOrRegisterGlobalDeviceTypeId(const std::string &device_type) { ...@@ -86,7 +86,9 @@ size_t GetOrRegisterGlobalDeviceTypeId(const std::string &device_type) {
} }
std::string GetGlobalDeviceType(size_t device_type_id) { std::string GetGlobalDeviceType(size_t device_type_id) {
if (device_type_id == 0) return ""; if (global_registered_device_type.find(device_type_id) ==
global_registered_device_type.end())
return "";
return global_registered_device_type[device_type_id]; return global_registered_device_type[device_type_id];
} }
......
cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce) cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce) cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce)
set(convert_utils_deps data_type place op_utils)
if(WITH_GPU) if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils pten_gpu_info) set(convert_utils_deps ${convert_utils_deps} pten_gpu_info)
elseif(WITH_ROCM) elseif(WITH_ROCM)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils pten_gpu_info) set(convert_utils_deps ${convert_utils_deps} pten_gpu_info)
elseif(WITH_XPU) elseif(WITH_XPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils pten_xpu_info) set(convert_utils_deps ${convert_utils_deps} pten_xpu_info)
else()
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils)
endif() endif()
if(WITH_CUSTOM_DEVICE)
set(convert_utils_deps ${convert_utils_deps} device_manager)
endif()
cc_library(convert_utils SRCS convert_utils.cc DEPS ${convert_utils_deps})
...@@ -19,6 +19,10 @@ limitations under the License. */ ...@@ -19,6 +19,10 @@ limitations under the License. */
#include "paddle/pten/common/place.h" #include "paddle/pten/common/place.h"
#include "paddle/pten/core/compat/op_utils.h" #include "paddle/pten/core/compat/op_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_manager.h"
#endif
namespace pten { namespace pten {
Backend TransToPtenBackend(const pten::Place& place) { Backend TransToPtenBackend(const pten::Place& place) {
...@@ -26,6 +30,10 @@ Backend TransToPtenBackend(const pten::Place& place) { ...@@ -26,6 +30,10 @@ Backend TransToPtenBackend(const pten::Place& place) {
return Backend::CPU; return Backend::CPU;
} else if (place.GetType() == pten::AllocationType::GPU) { } else if (place.GetType() == pten::AllocationType::GPU) {
return Backend::GPU; return Backend::GPU;
} else if (place.GetType() == pten::AllocationType::CUSTOM) {
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
} else { } else {
return Backend::UNDEFINED; return Backend::UNDEFINED;
} }
...@@ -57,10 +65,23 @@ pten::Place TransToPtenPlace(const Backend& backend, bool set_device_id) { ...@@ -57,10 +65,23 @@ pten::Place TransToPtenPlace(const Backend& backend, bool set_device_id) {
return pten::XPUPlace( return pten::XPUPlace(
set_device_id ? pten::backends::xpu::GetXPUCurrentDeviceId() : 0); set_device_id ? pten::backends::xpu::GetXPUCurrentDeviceId() : 0);
#endif #endif
default: default: {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS);
std::string device_type = pten::GetGlobalDeviceType(device_type_id_);
if (!device_type.empty()) {
return pten::CustomPlace(
device_type,
set_device_id
? paddle::platform::DeviceManager::GetDevice(device_type)
: 0);
}
#endif
PADDLE_THROW(pten::errors::Unimplemented( PADDLE_THROW(pten::errors::Unimplemented(
"Unsupported backend `%s` when casting it to paddle place type.", "Unsupported backend `%s` when casting it to paddle place type.",
backend)); backend));
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册