diff --git a/paddle/fluid/framework/custom_kernel.cc b/paddle/fluid/framework/custom_kernel.cc index 058be98928227f74630389738939f39fee6c87e8..6bcae738cc421523d23b0769830177eb43f98f7c 100644 --- a/paddle/fluid/framework/custom_kernel.cc +++ b/paddle/fluid/framework/custom_kernel.cc @@ -237,6 +237,17 @@ static void RunKernelFunc(pten::KernelContext* ctx, if (backend == pten::Backend::CPU) { // do nothing } else { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + size_t device_type_id_ = static_cast(backend) - + static_cast(pten::Backend::ALL_BACKEND); + std::string device_type = pten::GetGlobalDeviceType(device_type_id_); + if (!device_type.empty()) { + auto custom_ctx = + ctx->GetDeviceContext(); + dev_ctx.set_stream(custom_ctx.stream()); + return; + } +#endif LOG(ERROR) << "[CUSTOM KERNEL] Unsupported kernel backend: " << backend << " with compiled Paddle."; return; diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 1d51383f6833b584f77bce9e865ad5d229590421..48a933ef0e219bc10c4278825531ef04b6be6100 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -846,7 +846,6 @@ class CustomDeviceContext : public DeviceContext { std::shared_ptr stream_; CustomDeviceContext(); - DISABLE_COPY_AND_ASSIGN(CustomDeviceContext); }; template <> struct DefaultDeviceContextType { diff --git a/paddle/pten/backends/CMakeLists.txt b/paddle/pten/backends/CMakeLists.txt index cc9352892030a5d86773c0131b0081fb425a4e9e..441bd0a8c303b5e45f173f20e78ca2e65b9fc314 100644 --- a/paddle/pten/backends/CMakeLists.txt +++ b/paddle/pten/backends/CMakeLists.txt @@ -2,6 +2,8 @@ add_subdirectory(dynload) add_subdirectory(cpu) +add_subdirectory(custom) + if(WITH_GPU OR WITH_ROCM) add_subdirectory(gpu) endif() diff --git a/paddle/pten/backends/custom/CMakeLists.txt b/paddle/pten/backends/custom/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a7de35dd4e66c687bf501845d7b079f90f42464 --- /dev/null +++ b/paddle/pten/backends/custom/CMakeLists.txt @@ -0,0 +1,3 @@ +if (WITH_CUSTOM_DEVICE) + cc_library(custom_context SRCS custom_context.cc DEPS pten_device_context device_manager) +endif() diff --git a/paddle/pten/backends/custom/custom_context.cc b/paddle/pten/backends/custom/custom_context.cc new file mode 100644 index 0000000000000000000000000000000000000000..12e13609ebe1acac3da1d5400ab00d469f178e93 --- /dev/null +++ b/paddle/pten/backends/custom/custom_context.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/pten/backends/custom/custom_context.h" + +#include "paddle/fluid/platform/device/device_guard.h" +#include "paddle/fluid/platform/device/stream.h" + +namespace pten { + +struct CustomContext::Impl { + explicit Impl(const CustomPlace& place) : place_(place) {} + + ~Impl() {} + + void Init() { + paddle::platform::DeviceGuard guard(place_); + stream_.reset(new paddle::platform::stream::Stream()); + stream_->Init(place_); + } + + const Place& GetPlace() const { return place_; } + + C_Stream stream() const { + return reinterpret_cast(stream_->raw_stream()); + } + + void Wait() const { stream_->Wait(); } + + Place place_; + + std::shared_ptr stream_; +}; + +void CustomContext::Init() { impl_->Init(); } + +const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); } + +C_Stream CustomContext::stream() const { return impl_->stream(); } + +void CustomContext::Wait() const { return impl_->Wait(); } + +CustomContext::CustomContext(const CustomPlace& place) + : DeviceContext(), impl_(std::make_unique(place)) {} + +CustomContext::~CustomContext() {} + +} // namespace pten diff --git a/paddle/pten/backends/custom/custom_context.h b/paddle/pten/backends/custom/custom_context.h new file mode 100644 index 0000000000000000000000000000000000000000..86fa44ce8dca413056c074082b8f5486094a0e97 --- /dev/null +++ b/paddle/pten/backends/custom/custom_context.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/platform/device/device_ext.h" +#include "paddle/pten/common/place.h" +#include "paddle/pten/core/device_context.h" + +namespace pten { + +class CustomContext : public DeviceContext { + public: + explicit CustomContext(const CustomPlace&); + + virtual ~CustomContext(); + + const Place& GetPlace() const override; + + /*! \brief Return stream in the device context. */ + C_Stream stream() const; + + // Wait for all operations completion in the stream. + void Wait() const override; + + public: + // NOTE: DeviceContext hold resources. Used in training scenarios. + // The interface used by the training scene, DeviceContext will initialize + // all resources and delete them when destructing. + void Init(); + + private: + CustomContext(); + + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace pten diff --git a/paddle/pten/common/backend.h b/paddle/pten/common/backend.h index 9944083248c4c7b718d31d3ccc4797cafbc09557..f5457807ad3b960e00c4354bd0dcb7779afdd30c 100644 --- a/paddle/pten/common/backend.h +++ b/paddle/pten/common/backend.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/pten/api/ext/exception.h" +#include "paddle/pten/common/place.h" namespace paddle { namespace experimental { @@ -114,8 +115,17 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { case Backend::CUDNN: os << "CUDNN"; break; - default: - PD_THROW("Invalid enum backend type `", static_cast(backend), "`."); + default: { + size_t device_type_id_ = static_cast(backend) - + static_cast(Backend::NUM_BACKENDS); + std::string device_type = pten::GetGlobalDeviceType(device_type_id_); + if (!device_type.empty()) { + os << device_type; + } else { + PD_THROW( + "Invalid enum backend type `", static_cast(backend), "`."); + } + } } return os; } diff --git a/paddle/pten/common/place.cc b/paddle/pten/common/place.cc index 0a3bfccb16a4b2aa83425ddc41ae141251842bac..46c0f92b85eda65e6e852945385f7d1339980612 100644 --- a/paddle/pten/common/place.cc +++ b/paddle/pten/common/place.cc @@ -86,7 +86,9 @@ size_t GetOrRegisterGlobalDeviceTypeId(const std::string &device_type) { } std::string GetGlobalDeviceType(size_t device_type_id) { - if (device_type_id == 0) return ""; + if (global_registered_device_type.find(device_type_id) == + global_registered_device_type.end()) + return ""; return global_registered_device_type[device_type_id]; } diff --git a/paddle/pten/core/compat/CMakeLists.txt b/paddle/pten/core/compat/CMakeLists.txt index efdb53c512a0d69dde682070ff3f58a6dc829a8d..c6bc9e15a535b52def1caef463a8a9228ab51e4a 100644 --- a/paddle/pten/core/compat/CMakeLists.txt +++ b/paddle/pten/core/compat/CMakeLists.txt @@ -1,11 +1,16 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS pten_enforce) cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce) + +set(convert_utils_deps data_type place op_utils) + if(WITH_GPU) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils pten_gpu_info) + set(convert_utils_deps ${convert_utils_deps} pten_gpu_info) elseif(WITH_ROCM) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils pten_gpu_info) + set(convert_utils_deps ${convert_utils_deps} pten_gpu_info) elseif(WITH_XPU) - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils pten_xpu_info) -else() - cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place op_utils) + set(convert_utils_deps ${convert_utils_deps} pten_xpu_info) endif() +if(WITH_CUSTOM_DEVICE) + set(convert_utils_deps ${convert_utils_deps} device_manager) +endif() +cc_library(convert_utils SRCS convert_utils.cc DEPS ${convert_utils_deps}) diff --git a/paddle/pten/core/compat/convert_utils.cc b/paddle/pten/core/compat/convert_utils.cc index 2cd286677462de8507285de3221e0c7ea41a8bb6..47126512503f3d4f701cc98331f3803f7cdc48ee 100644 --- a/paddle/pten/core/compat/convert_utils.cc +++ b/paddle/pten/core/compat/convert_utils.cc @@ -19,6 +19,10 @@ limitations under the License. */ #include "paddle/pten/common/place.h" #include "paddle/pten/core/compat/op_utils.h" +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/fluid/platform/device/device_manager.h" +#endif + namespace pten { Backend TransToPtenBackend(const pten::Place& place) { @@ -26,6 +30,10 @@ Backend TransToPtenBackend(const pten::Place& place) { return Backend::CPU; } else if (place.GetType() == pten::AllocationType::GPU) { return Backend::GPU; + } else if (place.GetType() == pten::AllocationType::CUSTOM) { + return static_cast( + static_cast(Backend::NUM_BACKENDS) + + GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType())); } else { return Backend::UNDEFINED; } @@ -57,10 +65,23 @@ pten::Place TransToPtenPlace(const Backend& backend, bool set_device_id) { return pten::XPUPlace( set_device_id ? pten::backends::xpu::GetXPUCurrentDeviceId() : 0); #endif - default: + default: { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + size_t device_type_id_ = static_cast(backend) - + static_cast(Backend::NUM_BACKENDS); + std::string device_type = pten::GetGlobalDeviceType(device_type_id_); + if (!device_type.empty()) { + return pten::CustomPlace( + device_type, + set_device_id + ? paddle::platform::DeviceManager::GetDevice(device_type) + : 0); + } +#endif PADDLE_THROW(pten::errors::Unimplemented( "Unsupported backend `%s` when casting it to paddle place type.", backend)); + } } }