未验证 提交 b4665d23 编写于 作者: R ronnywang 提交者: GitHub

[CustomRuntime] migrate CustomRuntime into phi (#39908)

上级 756af9ff
......@@ -440,11 +440,10 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper phi_tensor op_meta_info phi_api)
cc_library(custom_kernel SRCS custom_kernel.cc DEPS op_registry phi_custom_kernel phi_tensor_raw)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator custom_kernel)
set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator phi_custom_kernel)
cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES})
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/fluid/framework/custom_kernel.h"
#include "paddle/phi/core/custom_kernel.h"
namespace paddle {
namespace framework {
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX
typedef phi::CustomKernelMap& get_custom_kernel_map_t();
auto* func = reinterpret_cast<get_custom_kernel_map_t*>(
dlsym(dso_handle, "PD_GetCustomKernelMap"));
if (func == nullptr) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetCustomKernelMap symbol in this lib.";
return;
}
auto& custom_kernel_map = func();
phi::RegisterCustomKernels(custom_kernel_map);
LOG(INFO) << "Successed in loading custom kernels in lib: " << dso_lib_path;
#else
VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux.";
#endif
return;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
namespace paddle {
namespace framework {
// Load custom kernel lib and register
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
} // namespace framework
} // namespace paddle
......@@ -231,19 +231,19 @@ void CustomDeviceUnsafeFastGarbageCollector::ClearCallback(
CustomStreamGarbageCollector::CustomStreamGarbageCollector(
const platform::CustomPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {
platform::DeviceGuard guard(place);
stream_.reset(new platform::stream::Stream);
phi::DeviceGuard guard(place);
stream_.reset(new phi::stream::Stream);
stream_->Init(place);
callback_manager_.reset(new platform::CallbackManager(stream_.get()));
callback_manager_.reset(new phi::CallbackManager(stream_.get()));
}
CustomStreamGarbageCollector::~CustomStreamGarbageCollector() {
platform::DeviceGuard guard(this->dev_ctx_->GetPlace());
phi::DeviceGuard guard(this->dev_ctx_->GetPlace());
stream_->Synchronize();
stream_->Destroy();
}
platform::stream::Stream *CustomStreamGarbageCollector::stream() const {
phi::stream::Stream *CustomStreamGarbageCollector::stream() const {
return stream_.get();
}
......
......@@ -230,14 +230,14 @@ class CustomStreamGarbageCollector : public GarbageCollector {
void Wait() const override;
platform::stream::Stream *stream() const;
phi::stream::Stream *stream() const;
protected:
void ClearCallback(const std::function<void()> &callback) override;
private:
std::unique_ptr<platform::stream::Stream> stream_;
std::unique_ptr<platform::CallbackManager> callback_manager_;
std::unique_ptr<phi::stream::Stream> stream_;
std::unique_ptr<phi::CallbackManager> callback_manager_;
};
#endif
......
......@@ -254,7 +254,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
"reinstall Paddle with CustomDevice support.",
place));
#else
platform::DeviceManager::SetDevice(place);
phi::DeviceManager::SetDevice(place);
#endif
}
......
......@@ -253,7 +253,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::DeviceManager::SetDevice(place);
phi::DeviceManager::SetDevice(place);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CustomDevice if use "
......
......@@ -31,7 +31,7 @@ cc_library(paddle_infer_contrib SRCS paddle_infer_contrib.cc DEPS zero_copy_tens
cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
set(paddle_inference_api_deps lod_tensor scope reset_tensor_array
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator custom_kernel)
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator phi_custom_kernel)
if(WITH_CRYPTO)
list(APPEND paddle_inference_api_deps paddle_crypto)
......
......@@ -193,10 +193,10 @@ class AllocatorFacadePrivate {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < platform::DeviceManager::GetDeviceCount(dev_type);
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
++dev_id) {
InitNaiveBestFitCustomDeviceAllocator(
platform::CustomPlace(dev_type, dev_id));
......@@ -240,10 +240,10 @@ class AllocatorFacadePrivate {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < platform::DeviceManager::GetDeviceCount(dev_type);
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
++dev_id) {
InitAutoGrowthCustomDeviceAllocator(
platform::CustomPlace(dev_type, dev_id), allow_free_idle_chunk);
......@@ -738,7 +738,7 @@ class AllocatorFacadePrivate {
auto custom_allocator =
std::make_shared<paddle::memory::allocation::CustomAllocator>(p);
allocators_[p] = std::make_shared<AutoGrowthBestFitAllocator>(
custom_allocator, platform::DeviceManager::GetMinChunkSize(p),
custom_allocator, phi::DeviceManager::GetMinChunkSize(p),
allow_free_idle_chunk);
}
#endif
......@@ -814,11 +814,10 @@ class AllocatorFacadePrivate {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < platform::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
dev_id < phi::DeviceManager::GetDeviceCount(dev_type); dev_id++) {
places.emplace_back(platform::CustomPlace(dev_type, dev_id));
}
}
......
......@@ -32,17 +32,16 @@ void CustomAllocator::FreeImpl(phi::Allocation* allocation) {
}
phi::Allocation* CustomAllocator::AllocateImpl(size_t size) {
std::call_once(once_flag_,
[this] { platform::DeviceManager::SetDevice(place_); });
std::call_once(once_flag_, [this] { phi::DeviceManager::SetDevice(place_); });
void* ptr =
platform::DeviceManager::GetDeviceWithPlace(place_)->MemoryAllocate(size);
phi::DeviceManager::GetDeviceWithPlace(place_)->MemoryAllocate(size);
if (LIKELY(ptr)) {
return new Allocation(ptr, size, place_);
}
size_t avail, total;
platform::DeviceManager::MemoryStats(place_, &total, &avail);
phi::DeviceManager::MemoryStats(place_, &total, &avail);
auto dev_type = platform::PlaceHelper::GetDeviceType(place_);
auto dev_id = platform::PlaceHelper::GetDeviceId(place_);
......
......@@ -739,7 +739,7 @@ class BuddyAllocatorList {
private:
explicit BuddyAllocatorList(const std::string &device_type)
: device_type_(device_type) {
auto devices = platform::DeviceManager::GetDeviceList(device_type);
auto devices = phi::DeviceManager::GetDeviceList(device_type);
for (auto dev_id : devices) {
init_flags_[dev_id].reset(new std::once_flag());
}
......@@ -766,15 +766,15 @@ class BuddyAllocatorList {
device_type_, dev_id));
std::call_once(*init_flags_[dev_id], [this, dev_id] {
platform::DeviceManager::SetDevice(device_type_, dev_id);
phi::DeviceManager::SetDevice(device_type_, dev_id);
platform::CustomPlace place(device_type_, dev_id);
allocators_[dev_id].reset(new BuddyAllocator(
std::unique_ptr<detail::SystemAllocator>(
new detail::CustomAllocator(device_type_, dev_id)),
platform::DeviceManager::GetMinChunkSize(place),
platform::DeviceManager::GetMaxChunkSize(place),
platform::DeviceManager::GetExtraPaddingSize(place), device_type_));
phi::DeviceManager::GetMinChunkSize(place),
phi::DeviceManager::GetMaxChunkSize(place),
phi::DeviceManager::GetExtraPaddingSize(place), device_type_));
});
return allocators_[dev_id].get();
......@@ -808,9 +808,9 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
auto *ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) {
platform::DeviceGuard guard(place);
phi::DeviceGuard guard(place);
size_t avail, total;
platform::DeviceManager::MemoryStats(place, &total, &avail);
phi::DeviceManager::MemoryStats(place, &total, &avail);
PADDLE_THROW(platform::errors::ResourceExhausted(
"Cannot allocate %s in %s:%d, avaliable %s, total %s, used "
"%s. ",
......@@ -819,8 +819,7 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
string::HumanReadableSize(total - avail)));
} else {
if (FLAGS_init_allocated_mem) {
platform::DeviceManager::GetDeviceWithPlace(place)->MemorySet(ptr, 0xEF,
size);
phi::DeviceManager::GetDeviceWithPlace(place)->MemorySet(ptr, 0xEF, size);
}
}
VLOG(10) << " pointer=" << ptr;
......
......@@ -43,11 +43,11 @@ BuddyAllocator::BuddyAllocator(
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (!dev_type.empty()) {
init_allocate_size_func_ = [dev_type]() {
return platform::DeviceManager::GetInitAllocSize(
return phi::DeviceManager::GetInitAllocSize(
platform::PlaceHelper::CreatePlace(dev_type));
};
re_allocate_size_func_ = [dev_type]() {
return platform::DeviceManager::GetReallocSize(
return phi::DeviceManager::GetReallocSize(
platform::PlaceHelper::CreatePlace(dev_type));
};
} else {
......
......@@ -438,7 +438,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
void* p;
auto place = platform::CustomPlace(dev_type_, dev_id_);
auto device = platform::DeviceManager::GetDeviceWithPlace(place);
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
p = device->MemoryAllocate(size);
if (LIKELY(p)) {
VLOG(4) << "CustomAllocator::Alloc " << p << " size " << size;
......@@ -447,7 +447,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
} else {
size_t avail, total;
platform::DeviceManager::MemoryStats(place, &total, &avail);
phi::DeviceManager::MemoryStats(place, &total, &avail);
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
"\n\nOut of memory error on %s %d. "
"total memory is %s, used memory is %s, "
......@@ -470,7 +470,7 @@ void CustomAllocator::Free(void* p, size_t size, size_t index) {
size, plug_alloc_size));
plug_alloc_size -= size;
auto place = platform::CustomPlace(dev_type_, dev_id_);
auto device = platform::DeviceManager::GetDeviceWithPlace(place);
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
device->MemoryDeallocate(p, size);
}
......
......@@ -44,9 +44,9 @@ void Copy<platform::CPUPlace, platform::CustomPlace>(
VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << ", stream=" << stream;
platform::DeviceManager::SetDevice(src_place);
platform::stream::Stream stream_wrapper(src_place, stream);
platform::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2H(
phi::DeviceManager::SetDevice(src_place);
phi::stream::Stream stream_wrapper(src_place, stream);
phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2H(
dst, src, num, &stream_wrapper);
}
......@@ -62,9 +62,9 @@ void Copy<platform::CustomPlace, platform::CPUPlace>(
VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
<< dst_place << ", stream=" << stream;
platform::DeviceManager::SetDevice(dst_place);
platform::stream::Stream stream_wrapper(dst_place, stream);
platform::DeviceManager::GetDeviceWithPlace(dst_place)->MemoryCopyH2D(
phi::DeviceManager::SetDevice(dst_place);
phi::stream::Stream stream_wrapper(dst_place, stream);
phi::DeviceManager::GetDeviceWithPlace(dst_place)->MemoryCopyH2D(
dst, src, num, &stream_wrapper);
}
......@@ -82,16 +82,16 @@ void Copy<platform::CustomPlace, platform::CustomPlace>(
<< dst_place << ", stream=" << stream;
if (src_type == dst_type) {
platform::DeviceManager::SetDevice(src_place);
platform::stream::Stream stream_wrapper(src_place, stream);
phi::DeviceManager::SetDevice(src_place);
phi::stream::Stream stream_wrapper(src_place, stream);
auto src_id = platform::PlaceHelper::GetDeviceId(src_place);
auto dst_id = platform::PlaceHelper::GetDeviceId(dst_place);
if (src_id == dst_id) {
platform::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2D(
phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyD2D(
dst, src, num, &stream_wrapper);
} else {
platform::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyP2P(
phi::DeviceManager::GetDeviceWithPlace(src_place)->MemoryCopyP2P(
dst_place, dst, src, num, &stream_wrapper);
}
} else {
......
......@@ -117,7 +117,7 @@ endif()
cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
# seperate init from device_context to avoid cycle dependencies
cc_library(init SRCS init.cc DEPS device_context custom_kernel)
cc_library(init SRCS init.cc DEPS device_context phi_custom_kernel)
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
......
IF(WITH_CUSTOM_DEVICE)
cc_library(callback_manager SRCS callback_manager.cc DEPS enforce place)
cc_library(device_guard SRCS device_guard.cc DEPS enforce place)
cc_library(stream SRCS stream.cc DEPS callback_manager)
cc_library(event SRCS event.cc DEPS enforce place)
cc_library(device_base SRCS device_base.cc DEPS stream event callback_manager device_guard device_context flags)
ENDIF()
set(DEV_LIBS custom_device)
......@@ -37,11 +25,3 @@ ENDIF()
IF(WITH_MLU)
add_subdirectory(mlu)
ENDIF()
# CUSTOM
IF(WITH_CUSTOM_DEVICE)
add_subdirectory(custom)
cc_library(device_manager SRCS device_manager.cc DEPS custom_device)
set(GLOB_DEV_LIB device_manager custom_device CACHE INTERNAL "Global DEV library")
ENDIF()
IF(WITH_CUSTOM_DEVICE)
cc_library(custom_device SRCS custom_device.cc DEPS device_base device_context)
cc_test(custom_device_test SRCS custom_device_test.cc DEPS device_manager device_context )
ENDIF()
......@@ -14,7 +14,10 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_ext.h"
#include <string>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_ext.h"
namespace paddle {
namespace platform {
......
......@@ -40,10 +40,10 @@ limitations under the License. */
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
#endif
......@@ -903,7 +903,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
: phi::CustomContext(place) {
Init();
stream_.reset(new platform::stream::Stream(place, stream()));
stream_.reset(new phi::stream::Stream(place, stream()));
}
CustomDeviceContext::~CustomDeviceContext() {}
......
......@@ -72,8 +72,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#endif
#include "paddle/fluid/platform/device/device_ext.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/stream.h"
#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
#include "unsupported/Eigen/CXX11/Tensor"
......@@ -838,7 +838,7 @@ class CustomDeviceContext : public phi::CustomContext {
void WaitStreamCallback() const { return stream_->WaitCallback(); }
private:
std::shared_ptr<platform::stream::Stream> stream_;
std::shared_ptr<phi::stream::Stream> stream_;
};
template <>
struct DefaultDeviceContextType<platform::CustomPlace> {
......
......@@ -55,7 +55,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
#endif
#include "paddle/fluid/framework/custom_kernel.h"
#include "paddle/phi/core/custom_kernel.h"
DECLARE_int32(paddle_num_threads);
PADDLE_DEFINE_EXPORTED_int32(
......@@ -145,7 +145,7 @@ void InitCupti() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void LoadCustomDevice(const std::string &library_dir) {
LOG(INFO) << "Try loading custom device libs from: [" << library_dir << "]";
std::vector<std::string> libs = platform::ListAllLibraries(library_dir);
std::vector<std::string> libs = phi::ListAllLibraries(library_dir);
for (const auto &lib_path : libs) {
auto dso_handle = dlopen(lib_path.c_str(), RTLD_NOW);
PADDLE_ENFORCE_NOT_NULL(
......@@ -153,8 +153,8 @@ void LoadCustomDevice(const std::string &library_dir) {
platform::errors::InvalidArgument(
"Fail to open library: %s with error: %s", lib_path, dlerror()));
platform::LoadCustomRuntimeLib(lib_path, dso_handle);
framework::LoadCustomKernelLib(lib_path, dso_handle);
phi::LoadCustomRuntimeLib(lib_path, dso_handle);
phi::LoadCustomKernelLib(lib_path, dso_handle);
}
LOG(INFO) << "Finished in LoadCustomDevice with libs_path: [" << library_dir
<< "]";
......@@ -259,9 +259,9 @@ void InitDevices(const std::vector<int> devices) {
LOG(INFO) << "ENV [CUSTOM_DEVICE_ROOT]=" << custom_kernel_root;
LoadCustomDevice(custom_kernel_root);
auto device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (auto &dev_type : device_types) {
auto device_count = platform::DeviceManager::GetDeviceCount(dev_type);
auto device_count = phi::DeviceManager::GetDeviceCount(dev_type);
LOG(INFO) << "CustomDevice: " << dev_type
<< ", visible devices count: " << device_count;
for (size_t i = 0; i < device_count; i++) {
......
......@@ -1668,7 +1668,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_all_device_type", []() {
std::vector<std::string> device_types;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
device_types = platform::DeviceManager::GetAllDeviceTypes();
device_types = phi::DeviceManager::GetAllDeviceTypes();
#else
LOG(WARNING) << string::Sprintf(
"Cannot use get_all_device_type because you have installed"
......@@ -1682,7 +1682,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_all_custom_device_type", []() {
std::vector<std::string> device_types;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
device_types = platform::DeviceManager::GetAllCustomDeviceTypes();
device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
#else
LOG(WARNING) << string::Sprintf(
"Cannot use get_all_custom_device_type because you have installed"
......@@ -1696,7 +1696,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_available_device", [] {
std::vector<std::string> devices;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
devices = platform::DeviceManager::GetAllDeviceList();
devices = phi::DeviceManager::GetAllDeviceList();
#else
LOG(WARNING) << string::Sprintf(
"Cannot use get_available_device because you have installed"
......@@ -1710,7 +1710,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_available_custom_device", [] {
std::vector<std::string> devices;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
devices = platform::DeviceManager::GetAllCustomDeviceList();
devices = phi::DeviceManager::GetAllCustomDeviceList();
#else
LOG(WARNING) << string::Sprintf(
"Cannot use get_available_custom_device because you have "
......@@ -1747,10 +1747,10 @@ All parameter, weight, gradient are variables in Paddle.
std::exit(-1);
}
if (LIKELY(platform::DeviceManager::HasDeviceType(device_type) &&
platform::DeviceManager::IsCustom(device_type))) {
if (LIKELY(phi::DeviceManager::HasDeviceType(device_type) &&
phi::DeviceManager::IsCustom(device_type))) {
int dev_count = static_cast<int>(
platform::DeviceManager::GetDeviceCount(device_type));
phi::DeviceManager::GetDeviceCount(device_type));
if (UNLIKELY(dev_id >= dev_count)) {
if (dev_count == 0) {
LOG(ERROR) << "Cannot use " << device_type
......
......@@ -393,10 +393,10 @@ void SetTensorFromPyArrayT(
} else if (paddle::platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::Place tmp_place = place;
platform::DeviceGuard guard(tmp_place);
phi::DeviceGuard guard(tmp_place);
auto dst = self->mutable_data<T>(place);
platform::DeviceManager::GetDeviceWithPlace(tmp_place)->MemoryCopyH2D(
phi::DeviceManager::GetDeviceWithPlace(tmp_place)->MemoryCopyH2D(
reinterpret_cast<void *>(dst),
const_cast<void *>(reinterpret_cast<const void *>(array.data())),
array.nbytes());
......
......@@ -24,4 +24,11 @@ endif()
if(WITH_CUSTOM_DEVICE)
add_dependencies(phi_context custom_context)
cc_library(callback_manager SRCS callback_manager.cc DEPS enforce place)
cc_library(device_guard SRCS device_guard.cc DEPS enforce place)
cc_library(stream SRCS stream.cc DEPS callback_manager)
cc_library(event SRCS event.cc DEPS enforce place)
cc_library(device_base SRCS device_base.cc DEPS stream event callback_manager device_guard device_context flags)
cc_library(device_manager SRCS device_manager.cc DEPS custom_device)
set(GLOB_DEV_LIB device_manager custom_device CACHE INTERNAL "Global DEV library")
endif()
......@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace phi {
CallbackManager::CallbackManager(stream::Stream *stream)
: stream_(stream), thread_pool_(1) {}
......@@ -32,12 +31,12 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
});
});
platform::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->AddCallback(stream_, func);
}
void CallbackManager::Wait() const {
platform::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
phi::DeviceManager::GetDeviceWithPlace(stream_->GetPlace())
->SynchronizeStream(stream_);
{
......@@ -48,5 +47,4 @@ void CallbackManager::Wait() const {
}
}
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -32,8 +32,7 @@
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace phi {
namespace stream {
class Stream;
......@@ -58,5 +57,4 @@ class CallbackManager {
mutable std::future<void> last_future_;
};
} // namespace platform
} // namespace paddle
} // namespace phi
if (WITH_CUSTOM_DEVICE)
cc_library(custom_context SRCS custom_context.cc DEPS phi_device_context device_manager)
cc_library(custom_device SRCS custom_device.cc DEPS device_base device_context)
cc_test(custom_device_test SRCS custom_device_test.cc DEPS device_manager device_context)
endif()
......@@ -14,8 +14,8 @@ limitations under the License. */
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/stream.h"
namespace phi {
......@@ -25,8 +25,8 @@ struct CustomContext::Impl {
~Impl() {}
void Init() {
paddle::platform::DeviceGuard guard(place_);
stream_.reset(new paddle::platform::stream::Stream());
phi::DeviceGuard guard(place_);
stream_.reset(new phi::stream::Stream());
stream_->Init(place_);
}
......@@ -40,7 +40,7 @@ struct CustomContext::Impl {
Place place_;
std::shared_ptr<paddle::platform::stream::Stream> stream_;
std::shared_ptr<phi::stream::Stream> stream_;
};
void CustomContext::Init() { impl_->Init(); }
......
......@@ -12,23 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/device_base.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
static bool operator==(const C_Device_st& d1, const C_Device_st& d2) {
return d1.id == d2.id;
}
namespace paddle {
namespace platform {
namespace phi {
class CustomDevice : public DeviceInterface {
public:
CustomDevice(const std::string& type, int priority, bool is_custom,
std::unique_ptr<C_DeviceInterface> pimpl, void* dso_handle)
CustomDevice(const std::string& type,
int priority,
bool is_custom,
std::unique_ptr<C_DeviceInterface> pimpl,
void* dso_handle)
: DeviceInterface(type, priority, is_custom),
pimpl_(std::move(pimpl)),
dso_handle_(dso_handle) {
......@@ -122,14 +127,15 @@ class CustomDevice : public DeviceInterface {
return device.id;
}
void CreateStream(size_t dev_id, stream::Stream* stream,
void CreateStream(size_t dev_id,
stream::Stream* stream,
const stream::Stream::Priority& priority =
stream::Stream::Priority::kNormal,
const stream::Stream::Flag& flag =
stream::Stream::Flag::kDefaultFlag) override {
if (priority != stream::Stream::Priority::kNormal ||
flag != stream::Stream::Flag::kDefaultFlag) {
PADDLE_THROW(platform::errors::Unavailable(
PADDLE_THROW(phi::errors::Unavailable(
"priority != stream::Stream::Priority::kNormal || flag != "
"stream::Stream::Flag::kDefaultFlag is not allowed on "
"CustomDevice."));
......@@ -162,23 +168,28 @@ class CustomDevice : public DeviceInterface {
SynchronizeStream(dev_id, stream);
return true;
}
if (pimpl_->query_stream(device, reinterpret_cast<C_Stream>(
stream->raw_stream())) == C_SUCCESS) {
if (pimpl_->query_stream(
device, reinterpret_cast<C_Stream>(stream->raw_stream())) ==
C_SUCCESS) {
return true;
}
return false;
}
void AddCallback(size_t dev_id, stream::Stream* stream,
void AddCallback(size_t dev_id,
stream::Stream* stream,
stream::Stream::Callback* callback) override {
if (!pimpl_->stream_add_callback) {
PADDLE_THROW(platform::errors::Unavailable(
PADDLE_THROW(phi::errors::Unavailable(
"AddCallback is not supported on %s.", Type()));
} else {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_add_callback(
device, reinterpret_cast<C_Stream>(stream->raw_stream()),
[](C_Device device, C_Stream stream, void* user_data,
device,
reinterpret_cast<C_Stream>(stream->raw_stream()),
[](C_Device device,
C_Stream stream,
void* user_data,
C_Status* status) {
std::unique_ptr<std::function<void()>> func(
reinterpret_cast<std::function<void()>*>(user_data));
......@@ -188,7 +199,8 @@ class CustomDevice : public DeviceInterface {
}
}
void CreateEvent(size_t dev_id, event::Event* event,
void CreateEvent(size_t dev_id,
event::Event* event,
event::Event::Flag flags) override {
const auto device = &devices_pool[dev_id];
C_Event c_event;
......@@ -205,13 +217,15 @@ class CustomDevice : public DeviceInterface {
device, reinterpret_cast<C_Event>(event->raw_event())));
}
void RecordEvent(size_t dev_id, const event::Event* event,
void RecordEvent(size_t dev_id,
const event::Event* event,
const stream::Stream* stream) override {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->record_event(
device, reinterpret_cast<C_Stream>(stream->raw_stream()),
reinterpret_cast<C_Event>(event->raw_event())));
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->record_event(device,
reinterpret_cast<C_Stream>(stream->raw_stream()),
reinterpret_cast<C_Event>(event->raw_event())));
}
void SynchronizeEvent(size_t dev_id, const event::Event* event) override {
......@@ -228,78 +242,93 @@ class CustomDevice : public DeviceInterface {
SynchronizeEvent(dev_id, event);
return true;
}
if (pimpl_->query_event(device, reinterpret_cast<C_Event>(
event->raw_event())) == C_SUCCESS) {
if (pimpl_->query_event(device,
reinterpret_cast<C_Event>(event->raw_event())) ==
C_SUCCESS) {
return true;
}
return false;
}
void StreamWaitEvent(size_t dev_id, const stream::Stream* stream,
void StreamWaitEvent(size_t dev_id,
const stream::Stream* stream,
const event::Event* event) override {
const auto device = &devices_pool[dev_id];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->stream_wait_event(
device, reinterpret_cast<C_Stream>(stream->raw_stream()),
device,
reinterpret_cast<C_Stream>(stream->raw_stream()),
reinterpret_cast<C_Event>(event->raw_event())));
}
void MemoryCopyH2D(size_t dev_id, void* dst, const void* src, size_t size,
void MemoryCopyH2D(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr) override {
const auto device = &devices_pool[dev_id];
auto place = platform::CustomPlace(Type(), dev_id);
auto place = CustomPlace(Type(), dev_id);
if (stream && stream->raw_stream() && pimpl_->async_memory_copy_h2d) {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_h2d(device, c_stream, dst, src, size));
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->memory_copy_h2d(device, dst, src, size));
}
}
void MemoryCopyD2H(size_t dev_id, void* dst, const void* src, size_t size,
void MemoryCopyD2H(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr) override {
const auto device = &devices_pool[dev_id];
auto place = platform::CustomPlace(Type(), dev_id);
auto place = CustomPlace(Type(), dev_id);
if (stream && stream->raw_stream() && pimpl_->async_memory_copy_d2h) {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_d2h(device, c_stream, dst, src, size));
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->memory_copy_d2h(device, dst, src, size));
}
}
void MemoryCopyD2D(size_t dev_id, void* dst, const void* src, size_t size,
void MemoryCopyD2D(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr) override {
const auto device = &devices_pool[dev_id];
auto place = platform::CustomPlace(Type(), dev_id);
auto place = CustomPlace(Type(), dev_id);
if (stream && stream->raw_stream() && pimpl_->async_memory_copy_d2d) {
C_Stream c_stream = reinterpret_cast<C_Stream>(stream->raw_stream());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->async_memory_copy_d2d(device, c_stream, dst, src, size));
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(place)->Wait();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->memory_copy_d2d(device, dst, src, size));
}
}
void MemoryCopyP2P(const Place& dst_place, void* dst, size_t src_dev_id,
const void* src, size_t size,
void MemoryCopyP2P(const Place& dst_place,
void* dst,
size_t src_dev_id,
const void* src,
size_t size,
const stream::Stream* stream = nullptr) override {
int dst_dev_id = PlaceToId(dst_place);
auto dst_device = &devices_pool[dst_dev_id];
......@@ -310,8 +339,12 @@ class CustomDevice : public DeviceInterface {
MemoryCopyP2P(dst_place, dst, src_dev_id, src, size);
} else {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->async_memory_copy_p2p(
dst_device, src_device,
reinterpret_cast<C_Stream>(stream->raw_stream()), dst, src, size));
dst_device,
src_device,
reinterpret_cast<C_Stream>(stream->raw_stream()),
dst,
src,
size));
}
} else {
if (!pimpl_->memory_copy_p2p) {
......@@ -319,9 +352,9 @@ class CustomDevice : public DeviceInterface {
MemoryCopyD2H(src_dev_id, tmp.get(), src, size);
MemoryCopyH2D(dst_dev_id, dst, tmp.get(), size);
} else {
auto src_place = platform::CustomPlace(Type(), src_dev_id);
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto src_place = CustomPlace(Type(), src_dev_id);
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
pool.Get(src_place)->Wait();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->memory_copy_p2p(dst_device, src_device, dst, src, size));
......@@ -350,8 +383,8 @@ class CustomDevice : public DeviceInterface {
const auto device = &devices_pool[dev_id];
if (!pimpl_->unified_memory_allocate) {
PADDLE_THROW(platform::errors::Unavailable(
"MemoryAllocKind::Host is not supported on %s.", Type()));
PADDLE_THROW(phi::errors::Unavailable(
"MemoryAllocateHost is not supported on %s.", Type()));
} else {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->host_memory_allocate(device, &ptr, size));
......@@ -363,8 +396,8 @@ class CustomDevice : public DeviceInterface {
const auto device = &devices_pool[dev_id];
if (!pimpl_->host_memory_deallocate) {
PADDLE_THROW(platform::errors::Unavailable(
"MemoryAllocKind::Host is not supported on %s.", Type()));
PADDLE_THROW(phi::errors::Unavailable(
"MemoryDeallocateHost is not supported on %s.", Type()));
} else {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->host_memory_deallocate(device, ptr, size));
......@@ -376,8 +409,8 @@ class CustomDevice : public DeviceInterface {
const auto device = &devices_pool[dev_id];
if (!pimpl_->unified_memory_allocate) {
PADDLE_THROW(platform::errors::Unavailable(
"MemoryAllocKind::Unified is not supported on %s.", Type()));
PADDLE_THROW(phi::errors::Unavailable(
"MemoryAllocateUnified is not supported on %s.", Type()));
} else {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->unified_memory_allocate(device, &ptr, size));
......@@ -389,15 +422,17 @@ class CustomDevice : public DeviceInterface {
const auto device = &devices_pool[dev_id];
if (!pimpl_->unified_memory_deallocate) {
PADDLE_THROW(platform::errors::Unavailable(
"MemoryAllocKind::Host is not supported on %s.", Type()));
PADDLE_THROW(phi::errors::Unavailable(
"MemoryDeallocateUnified is not supported on %s.", Type()));
} else {
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(
pimpl_->unified_memory_deallocate(device, ptr, size));
}
}
void MemorySet(size_t dev_id, void* ptr, uint8_t value,
void MemorySet(size_t dev_id,
void* ptr,
uint8_t value,
size_t size) override {
const auto device = &devices_pool[dev_id];
......@@ -532,10 +567,12 @@ class CustomDevice : public DeviceInterface {
inline int PlaceToId(const Place& place) {
int dev_id = PlaceToIdNoCheck(place);
PADDLE_ENFORCE_NE(devices_pool.find(dev_id), devices_pool.end(),
platform::errors::NotFound(
PADDLE_ENFORCE_NE(devices_pool.find(dev_id),
devices_pool.end(),
phi::errors::NotFound(
"Cannot found %s %d, please check visible devices",
Type(), dev_id));
Type(),
dev_id));
return dev_id;
}
......@@ -623,11 +660,14 @@ typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
const std::string& dso_lib_path, void* dso_handle) {
const std::string& dso_lib_path,
void* dso_handle) {
if (ValidCustomCustomRuntimeParams(&runtime_params)) {
auto device =
std::make_unique<CustomDevice>(runtime_params.device_type, 255, true,
std::move(device_interface), dso_handle);
auto device = std::make_unique<CustomDevice>(runtime_params.device_type,
255,
true,
std::move(device_interface),
dso_handle);
if (false == DeviceManager::Register(std::move(device))) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path
<< "]. Register failed!!! there may be a "
......@@ -665,10 +705,9 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
"compatibility between PaddlePaddle and Custom Runtime.";
return;
}
LoadCustomRuntimeLib(runtime_params, std::move(device_interface),
dso_lib_path, dso_handle);
LoadCustomRuntimeLib(
runtime_params, std::move(device_interface), dso_lib_path, dso_handle);
LOG(INFO) << "Successed in loading custom runtime in lib: " << dso_lib_path;
}
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -17,9 +17,9 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/custom/fake_cpu_device.h"
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h"
void RegisterDevice() {
CustomRuntimeParams runtime_params;
......@@ -30,23 +30,22 @@ void RegisterDevice() {
runtime_params.interface->size = sizeof(C_DeviceInterface);
InitFakeCPUDevice(&runtime_params);
paddle::platform::LoadCustomRuntimeLib(
phi::LoadCustomRuntimeLib(
runtime_params, std::move(device_interface), "", nullptr);
}
void InitDevice() {
RegisterDevice();
EXPECT_GT(static_cast<int>(
paddle::platform::DeviceManager::GetAllDeviceTypes().size()),
EXPECT_GT(static_cast<int>(phi::DeviceManager::GetAllDeviceTypes().size()),
0);
auto place = paddle::platform::CustomPlace(DEVICE_TYPE, 0);
auto device = paddle::platform::DeviceManager::GetDeviceWithPlace(place);
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
EXPECT_NE(device, nullptr);
std::vector<paddle::platform::Place> places;
auto device_types = paddle::platform::DeviceManager::GetAllDeviceTypes();
auto device_types = phi::DeviceManager::GetAllDeviceTypes();
for (auto dev_type : device_types) {
auto devices = paddle::platform::DeviceManager::GetDeviceList(dev_type);
auto devices = phi::DeviceManager::GetDeviceList(dev_type);
for (auto dev_id : devices) {
places.push_back(
paddle::platform::PlaceHelper::CreatePlace(dev_type, dev_id));
......@@ -60,14 +59,14 @@ void InitDevice() {
void TestDeviceInterface(const paddle::platform::Place& place) {
std::cout << "TestDeviceInterface on " << place << std::endl;
if (paddle::platform::is_custom_place(place)) {
auto device = paddle::platform::DeviceManager::GetDeviceWithPlace(place);
auto device = phi::DeviceManager::GetDeviceWithPlace(place);
auto dev_type = paddle::platform::PlaceHelper::GetDeviceType(place);
auto p1 = device->MemoryAllocate(
paddle::platform::DeviceManager::GetMinChunkSize(place));
auto p1 =
device->MemoryAllocate(phi::DeviceManager::GetMinChunkSize(place));
EXPECT_NE(p1, nullptr);
paddle::platform::DeviceManager::SetDevice(place);
auto dev_id = paddle::platform::DeviceManager::GetDevice(dev_type);
phi::DeviceManager::SetDevice(place);
auto dev_id = phi::DeviceManager::GetDevice(dev_type);
EXPECT_EQ(dev_id, place.GetDeviceId());
}
}
......@@ -168,11 +167,10 @@ void TestTensorUtils(const paddle::platform::Place& place) {
TEST(CustomDevice, Tensor) {
InitDevice();
auto dev_types = paddle::platform::DeviceManager::GetAllDeviceTypes();
auto dev_types = phi::DeviceManager::GetAllDeviceTypes();
for (const auto& dev_type : dev_types) {
std::cout << "Test on " << dev_type << std::endl;
EXPECT_GT(static_cast<int>(
paddle::platform::DeviceManager::GetDeviceCount(dev_type)),
EXPECT_GT(static_cast<int>(phi::DeviceManager::GetDeviceCount(dev_type)),
0);
auto place = paddle::platform::PlaceHelper::CreatePlace(dev_type);
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/device/device_ext.h"
#include "paddle/phi/backends/device_ext.h"
constexpr size_t global_total_memory = 1024 * 1024UL;
static size_t global_free_memory = global_total_memory;
......@@ -43,14 +43,19 @@ C_Status GetDevicesList(size_t *device) {
return C_SUCCESS;
}
C_Status MemCpy(const C_Device device, void *dst, const void *src,
C_Status MemCpy(const C_Device device,
void *dst,
const void *src,
size_t size) {
memcpy(dst, src, size);
return C_SUCCESS;
}
C_Status AsyncMemCpy(const C_Device device, C_Stream stream, void *dst,
const void *src, size_t size) {
C_Status AsyncMemCpy(const C_Device device,
C_Stream stream,
void *dst,
const void *src,
size_t size) {
memcpy(dst, src, size);
return C_SUCCESS;
}
......@@ -100,14 +105,16 @@ C_Status SyncStream(const C_Device device, C_Stream stream) {
C_Status SyncEvent(const C_Device device, C_Event event) { return C_SUCCESS; }
C_Status StreamWaitEvent(const C_Device device, C_Stream stream,
C_Status StreamWaitEvent(const C_Device device,
C_Stream stream,
C_Event event) {
return C_SUCCESS;
}
C_Status VisibleDevices(size_t *devices) { return C_SUCCESS; }
C_Status DeviceMemStats(const C_Device device, size_t *total_memory,
C_Status DeviceMemStats(const C_Device device,
size_t *total_memory,
size_t *free_memory) {
*total_memory = global_total_memory;
*free_memory = global_free_memory;
......@@ -139,7 +146,8 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params->version.minor = PADDLE_CUSTOM_RUNTIME_MINOR_VERSION;
params->version.patch = PADDLE_CUSTOM_RUNTIME_PATCH_VERSION;
memset(reinterpret_cast<void *>(params->interface), 0,
memset(reinterpret_cast<void *>(params->interface),
0,
sizeof(C_DeviceInterface));
params->interface->initialize = Init;
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/device_base.h"
#include "paddle/phi/backends/device_base.h"
#include "gflags/gflags.h"
DECLARE_double(fraction_of_gpu_memory_to_use);
......@@ -21,26 +21,25 @@ DECLARE_uint64(reallocate_gpu_memory_in_mb);
constexpr static float fraction_reserve_gpu_memory = 0.05f;
namespace paddle {
namespace platform {
namespace phi {
#define INTERFACE_UNIMPLEMENT \
PADDLE_THROW(platform::errors::Unimplemented( \
#define INTERFACE_UNIMPLEMENT \
PADDLE_THROW(phi::errors::Unimplemented( \
"%s is not implemented on %s device.", __func__, Type()));
// info
size_t DeviceInterface::GetComputeCapability() {
VLOG(10) << Type() + " get compute capability " << 0;
VLOG(10) << Type() << " get compute capability " << 0;
return 0;
}
size_t DeviceInterface::GetRuntimeVersion() {
VLOG(10) << Type() + " get runtime version " << 0;
VLOG(10) << Type() << " get runtime version " << 0;
return 0;
}
size_t DeviceInterface::GetDriverVersion() {
VLOG(10) << Type() + " get driver version " << 0;
VLOG(10) << Type() << " get driver version " << 0;
return 0;
}
......@@ -62,7 +61,8 @@ void DeviceInterface::SetDevice(size_t dev_id) { INTERFACE_UNIMPLEMENT; }
int DeviceInterface::GetDevice() { INTERFACE_UNIMPLEMENT; }
// stream manage
void DeviceInterface::CreateStream(size_t dev_id, stream::Stream* stream,
void DeviceInterface::CreateStream(size_t dev_id,
stream::Stream* stream,
const stream::Stream::Priority& priority,
const stream::Stream::Flag& flag) {
INTERFACE_UNIMPLEMENT;
......@@ -82,7 +82,8 @@ bool DeviceInterface::QueryStream(size_t dev_id, const stream::Stream* stream) {
return true;
}
void DeviceInterface::AddCallback(size_t dev_id, stream::Stream* stream,
void DeviceInterface::AddCallback(size_t dev_id,
stream::Stream* stream,
stream::Stream::Callback* callback) {
INTERFACE_UNIMPLEMENT;
}
......@@ -94,7 +95,8 @@ void DeviceInterface::StreamWaitEvent(size_t dev_id,
}
// event manage
void DeviceInterface::CreateEvent(size_t dev_id, event::Event* event,
void DeviceInterface::CreateEvent(size_t dev_id,
event::Event* event,
event::Event::Flag flags) {
INTERFACE_UNIMPLEMENT;
}
......@@ -103,7 +105,8 @@ void DeviceInterface::DestroyEvent(size_t dev_id, event::Event* event) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::RecordEvent(size_t dev_id, const event::Event* event,
void DeviceInterface::RecordEvent(size_t dev_id,
const event::Event* event,
const stream::Stream* stream) {
INTERFACE_UNIMPLEMENT;
}
......@@ -119,23 +122,35 @@ bool DeviceInterface::QueryEvent(size_t dev_id, const event::Event* event) {
}
// memery manage
void DeviceInterface::MemoryCopyH2D(size_t dev_id, void* dst, const void* src,
size_t size, const stream::Stream* stream) {
void DeviceInterface::MemoryCopyH2D(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::MemoryCopyD2H(size_t dev_id, void* dst, const void* src,
size_t size, const stream::Stream* stream) {
void DeviceInterface::MemoryCopyD2H(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::MemoryCopyD2D(size_t dev_id, void* dst, const void* src,
size_t size, const stream::Stream* stream) {
void DeviceInterface::MemoryCopyD2D(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::MemoryCopyP2P(const Place& dst_place, void* dst,
size_t src_id, const void* src, size_t size,
void DeviceInterface::MemoryCopyP2P(const Place& dst_place,
void* dst,
size_t src_id,
const void* src,
size_t size,
const stream::Stream* stream) {
INTERFACE_UNIMPLEMENT;
}
......@@ -154,7 +169,8 @@ void* DeviceInterface::MemoryAllocateHost(size_t dev_id, size_t size) {
return nullptr;
}
void DeviceInterface::MemoryDeallocateHost(size_t dev_id, void* ptr,
void DeviceInterface::MemoryDeallocateHost(size_t dev_id,
void* ptr,
size_t size) {
INTERFACE_UNIMPLEMENT;
}
......@@ -164,12 +180,15 @@ void* DeviceInterface::MemoryAllocateUnified(size_t dev_id, size_t size) {
return nullptr;
}
void DeviceInterface::MemoryDeallocateUnified(size_t dev_id, void* ptr,
void DeviceInterface::MemoryDeallocateUnified(size_t dev_id,
void* ptr,
size_t size) {
INTERFACE_UNIMPLEMENT;
}
void DeviceInterface::MemorySet(size_t dev_id, void* ptr, uint8_t value,
void DeviceInterface::MemorySet(size_t dev_id,
void* ptr,
uint8_t value,
size_t size) {
INTERFACE_UNIMPLEMENT;
}
......@@ -184,8 +203,9 @@ size_t DeviceInterface::GetMinChunkSize(size_t dev_id) {
size_t DeviceInterface::AllocSize(size_t dev_id, bool realloc) {
size_t available_to_alloc = AvailableAllocSize(dev_id);
PADDLE_ENFORCE_GT(available_to_alloc, 0,
platform::errors::ResourceExhausted(
PADDLE_ENFORCE_GT(available_to_alloc,
0,
phi::errors::ResourceExhausted(
"Not enough available %s memory.", Type()));
// If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
// allocated by fraction
......@@ -194,8 +214,9 @@ size_t DeviceInterface::AllocSize(size_t dev_id, bool realloc) {
size_t alloc_bytes =
(flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
FLAGS_fraction_of_gpu_memory_to_use);
PADDLE_ENFORCE_GE(available_to_alloc, alloc_bytes,
platform::errors::ResourceExhausted(
PADDLE_ENFORCE_GE(available_to_alloc,
alloc_bytes,
phi::errors::ResourceExhausted(
"Not enough available %s memory.", Type()));
return alloc_bytes;
}
......@@ -217,33 +238,32 @@ size_t DeviceInterface::AvailableAllocSize(size_t dev_id) {
size_t DeviceInterface::GetInitAllocSize(size_t dev_id) {
size_t init_alloc_size = AllocSize(dev_id, false);
VLOG(10) << Type() + " init alloc size " << (init_alloc_size >> 20) << "M";
VLOG(10) << Type() << " init alloc size " << (init_alloc_size >> 20) << "M";
return init_alloc_size;
}
size_t DeviceInterface::GetReallocSize(size_t dev_id) {
size_t realloc_size = AllocSize(dev_id, true);
VLOG(10) << Type() + " realloc size " << (realloc_size >> 20) << "M";
VLOG(10) << Type() << " realloc size " << (realloc_size >> 20) << "M";
return realloc_size;
}
size_t DeviceInterface::GetMaxAllocSize(size_t dev_id) {
size_t max_alloc_size =
std::max(GetInitAllocSize(dev_id), GetReallocSize(dev_id));
VLOG(10) << Type() + " max alloc size " << (max_alloc_size >> 20) << "M";
VLOG(10) << Type() << " max alloc size " << (max_alloc_size >> 20) << "M";
return max_alloc_size;
}
size_t DeviceInterface::GetMaxChunkSize(size_t dev_id) {
size_t max_chunk_size = GetMaxAllocSize(dev_id);
VLOG(10) << Type() + " max chunk size " << (max_chunk_size >> 20) << "M";
VLOG(10) << Type() << " max chunk size " << (max_chunk_size >> 20) << "M";
return max_chunk_size;
}
size_t DeviceInterface::GetExtraPaddingSize(size_t dev_id) {
VLOG(10) << Type() + " extra padding size " << 0;
VLOG(10) << Type() << " extra padding size " << 0;
return 0;
}
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -14,11 +14,10 @@
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
namespace paddle {
namespace platform {
namespace phi {
class DeviceInterface { // Driver / Runtime
public:
......@@ -66,7 +65,8 @@ class DeviceInterface { // Driver / Runtime
// Stream
// ! Create an asynchronous stream
virtual void CreateStream(
size_t dev_id, stream::Stream* stream,
size_t dev_id,
stream::Stream* stream,
const stream::Stream::Priority& priority =
stream::Stream::Priority::kNormal,
const stream::Stream::Flag& flag = stream::Stream::Flag::kDefaultFlag);
......@@ -81,19 +81,22 @@ class DeviceInterface { // Driver / Runtime
virtual bool QueryStream(size_t dev_id, const stream::Stream* stream);
// ! Add a callback to a compute stream.
virtual void AddCallback(size_t dev_id, stream::Stream* stream,
virtual void AddCallback(size_t dev_id,
stream::Stream* stream,
stream::Stream::Callback* callback);
// Event
// ! Create an event.
virtual void CreateEvent(size_t dev_id, event::Event* event,
virtual void CreateEvent(size_t dev_id,
event::Event* event,
event::Event::Flag flags);
// ! Destroy an event.
virtual void DestroyEvent(size_t dev_id, event::Event* event);
// ! Records an event.
virtual void RecordEvent(size_t dev_id, const event::Event* event,
virtual void RecordEvent(size_t dev_id,
const event::Event* event,
const stream::Stream* stream);
// ! Waits for event to complete.
......@@ -102,24 +105,34 @@ class DeviceInterface { // Driver / Runtime
virtual bool QueryEvent(size_t dev_id, const event::Event* event);
// ! Make a compute stream wait on an event
virtual void StreamWaitEvent(size_t dev_id, const stream::Stream* stream,
virtual void StreamWaitEvent(size_t dev_id,
const stream::Stream* stream,
const event::Event* event);
// Memory
virtual void MemoryCopyH2D(size_t dev_id, void* dst, const void* src,
virtual void MemoryCopyH2D(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
virtual void MemoryCopyD2H(size_t dev_id, void* dst, const void* src,
virtual void MemoryCopyD2H(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
virtual void MemoryCopyD2D(size_t dev_id, void* dst, const void* src,
virtual void MemoryCopyD2D(size_t dev_id,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
virtual void MemoryCopyP2P(const Place& dst_place, void* dst, size_t src_id,
const void* src, size_t size,
virtual void MemoryCopyP2P(const Place& dst_place,
void* dst,
size_t src_id,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
virtual void* MemoryAllocate(size_t dev_id, size_t size);
......@@ -160,7 +173,6 @@ class DeviceInterface { // Driver / Runtime
size_t AvailableAllocSize(size_t dev_id);
};
} // namespace platform
} // namespace paddle
} // namespace phi
#endif
......@@ -40,7 +40,9 @@ typedef struct C_Stream_st* C_Stream;
typedef struct C_Event_st* C_Event;
typedef void (*C_Callback)(C_Device device, C_Stream stream, void* user_data,
typedef void (*C_Callback)(C_Device device,
C_Stream stream,
void* user_data,
C_Status* status);
struct C_DeviceInterface {
......@@ -124,8 +126,10 @@ struct C_DeviceInterface {
* @param[C_Callback] callback
* @param[void*] user_data
*/
C_Status (*stream_add_callback)(const C_Device device, C_Stream stream,
C_Callback callback, void* user_data);
C_Status (*stream_add_callback)(const C_Device device,
C_Stream stream,
C_Callback callback,
void* user_data);
/**
* @brief Create an event
......@@ -142,7 +146,8 @@ struct C_DeviceInterface {
* @param[C_Stream] stream
* @param[C_Event] event
*/
C_Status (*record_event)(const C_Device device, C_Stream stream,
C_Status (*record_event)(const C_Device device,
C_Stream stream,
C_Event event);
/**
......@@ -191,7 +196,8 @@ struct C_DeviceInterface {
* @param[C_Stream] stream
* @param[C_Event] event
*/
C_Status (*stream_wait_event)(const C_Device device, C_Stream stream,
C_Status (*stream_wait_event)(const C_Device device,
C_Stream stream,
C_Event event);
void* reserved_dev_api[8];
......@@ -207,7 +213,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
*/
C_Status (*device_memory_allocate)(const C_Device device, void** ptr,
C_Status (*device_memory_allocate)(const C_Device device,
void** ptr,
size_t size);
/**
......@@ -217,7 +224,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[size_t] size
*/
C_Status (*device_memory_deallocate)(const C_Device device, void* ptr,
C_Status (*device_memory_deallocate)(const C_Device device,
void* ptr,
size_t size);
/**
......@@ -228,8 +236,10 @@ struct C_DeviceInterface {
* @param[unsigned char] value
* @param[size_t] size
*/
C_Status (*device_memory_set)(const C_Device device, void* ptr,
unsigned char value, size_t size);
C_Status (*device_memory_set)(const C_Device device,
void* ptr,
unsigned char value,
size_t size);
/**
* @brief Host memory allocate
......@@ -238,7 +248,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
*/
C_Status (*host_memory_allocate)(const C_Device device, void** ptr,
C_Status (*host_memory_allocate)(const C_Device device,
void** ptr,
size_t size);
/**
......@@ -248,7 +259,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[size_t] size
*/
C_Status (*host_memory_deallocate)(const C_Device device, void* ptr,
C_Status (*host_memory_deallocate)(const C_Device device,
void* ptr,
size_t size);
/**
......@@ -258,7 +270,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
*/
C_Status (*unified_memory_allocate)(const C_Device device, void** ptr,
C_Status (*unified_memory_allocate)(const C_Device device,
void** ptr,
size_t size);
/**
......@@ -268,7 +281,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[size_t] size
*/
C_Status (*unified_memory_deallocate)(const C_Device device, void* ptr,
C_Status (*unified_memory_deallocate)(const C_Device device,
void* ptr,
size_t size);
/**
......@@ -279,7 +293,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status (*memory_copy_h2d)(const C_Device device, void* dst, const void* src,
C_Status (*memory_copy_h2d)(const C_Device device,
void* dst,
const void* src,
size_t size);
/**
......@@ -290,7 +306,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status (*memory_copy_d2h)(const C_Device device, void* dst, const void* src,
C_Status (*memory_copy_d2h)(const C_Device device,
void* dst,
const void* src,
size_t size);
/**
......@@ -301,7 +319,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status (*memory_copy_d2d)(const C_Device device, void* dst, const void* src,
C_Status (*memory_copy_d2d)(const C_Device device,
void* dst,
const void* src,
size_t size);
/**
......@@ -314,8 +334,10 @@ struct C_DeviceInterface {
* @param[size_t] size
*/
C_Status (*memory_copy_p2p)(const C_Device dst_device,
const C_Device src_device, void* dst,
const void* src, size_t size);
const C_Device src_device,
void* dst,
const void* src,
size_t size);
/**
* @brief Asynchonrize memory copy from host to device
......@@ -326,8 +348,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status (*async_memory_copy_h2d)(const C_Device device, C_Stream stream,
void* dst, const void* src, size_t size);
C_Status (*async_memory_copy_h2d)(const C_Device device,
C_Stream stream,
void* dst,
const void* src,
size_t size);
/**
* @brief Asynchonrize memory copy from device to host
......@@ -338,8 +363,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status (*async_memory_copy_d2h)(const C_Device device, C_Stream stream,
void* dst, const void* src, size_t size);
C_Status (*async_memory_copy_d2h)(const C_Device device,
C_Stream stream,
void* dst,
const void* src,
size_t size);
/**
* @brief Asynchonrize memory copy from device to device
......@@ -350,8 +378,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status (*async_memory_copy_d2d)(const C_Device device, C_Stream stream,
void* dst, const void* src, size_t size);
C_Status (*async_memory_copy_d2d)(const C_Device device,
C_Stream stream,
void* dst,
const void* src,
size_t size);
/**
* @brief Peer asynchonrize memory copy from host to device
......@@ -363,8 +394,11 @@ struct C_DeviceInterface {
* @param[size_t] size
*/
C_Status (*async_memory_copy_p2p)(const C_Device dst_device,
const C_Device src_device, C_Stream stream,
void* dst, const void* src, size_t size);
const C_Device src_device,
C_Stream stream,
void* dst,
const void* src,
size_t size);
void* reserved_mem_api[8];
......@@ -394,7 +428,8 @@ struct C_DeviceInterface {
* @param[size_t*] free_memory
* @param[size_t*] used_memory
*/
C_Status (*device_memory_stats)(const C_Device device, size_t* total_memory,
C_Status (*device_memory_stats)(const C_Device device,
size_t* total_memory,
size_t* free_memory);
/**
......
......@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/phi/backends/device_guard.h"
namespace paddle {
namespace platform {
namespace phi {
// Even this source file does not contains any code, it is better to keep this
// source file for cmake dependency.
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -13,17 +13,16 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/phi/backends/device_manager.h"
namespace paddle {
namespace platform {
namespace phi {
class DeviceGuard {
public:
explicit inline DeviceGuard(const Place& place)
: dev_type_(PlaceHelper::GetDeviceType(place)) {
: dev_type_(place.GetDeviceType()) {
prev_id = DeviceManager::GetDevice(dev_type_);
cur_id = PlaceHelper::GetDeviceId(place);
cur_id = place.GetDeviceId();
if (cur_id != prev_id) {
DeviceManager::SetDevice(dev_type_, cur_id);
......@@ -44,5 +43,4 @@ class DeviceGuard {
std::string dev_type_;
};
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -13,7 +13,7 @@
// limitations under the License.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/phi/backends/device_manager.h"
#if !defined(_WIN32)
#include <dirent.h>
......@@ -24,8 +24,7 @@
#include <functional>
#include <regex>
namespace paddle {
namespace platform {
namespace phi {
void Device::CreateStream(stream::Stream* stream,
const stream::Stream::Priority& priority,
......@@ -76,23 +75,32 @@ void Device::StreamWaitEvent(const stream::Stream* stream,
impl_->StreamWaitEvent(dev_id_, stream, event);
}
void Device::MemoryCopyH2D(void* dst, const void* src, size_t size,
void Device::MemoryCopyH2D(void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
impl_->MemoryCopyH2D(dev_id_, dst, src, size, stream);
}
void Device::MemoryCopyD2H(void* dst, const void* src, size_t size,
void Device::MemoryCopyD2H(void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
impl_->MemoryCopyD2H(dev_id_, dst, src, size, stream);
}
void Device::MemoryCopyD2D(void* dst, const void* src, size_t size,
void Device::MemoryCopyD2D(void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
impl_->MemoryCopyD2D(dev_id_, dst, src, size, stream);
}
void Device::MemoryCopyP2P(const Place& dst_place, void* dst, const void* src,
size_t size, const stream::Stream* stream) {
void Device::MemoryCopyP2P(const Place& dst_place,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream) {
impl_->MemoryCopyP2P(dst_place, dst, dev_id_, src, size, stream);
}
......@@ -173,7 +181,7 @@ DeviceInterface* DeviceManager::GetDeviceInterfaceWithType(
} else {
LOG(ERROR) << "GetDeviceInterfaceWithType - " << device_type << " Failed\n";
PADDLE_THROW(
platform::errors::Fatal("Unregistered device type %s.", device_type));
phi::errors::Fatal("Unregistered device type %s.", device_type));
return nullptr;
}
}
......@@ -182,17 +190,21 @@ Device* DeviceManager::GetDeviceWithPlace(const Place& place) {
phi::AutoRDLock lock(&_global_device_manager_rw_lock);
auto& dev_map = Instance().device_map_;
auto dev_type = PlaceHelper::GetDeviceType(place);
auto dev_id = PlaceHelper::GetDeviceId(place);
PADDLE_ENFORCE_NE(dev_map.find(dev_type), dev_map.end(),
platform::errors::NotFound(
"Unable to find Device with type %s.", dev_type));
auto dev_type = place.GetDeviceType();
auto dev_id = place.GetDeviceId();
PADDLE_ENFORCE_NE(
dev_map.find(dev_type),
dev_map.end(),
phi::errors::NotFound("Unable to find Device with type %s.", dev_type));
auto& dev_vec = dev_map[dev_type];
PADDLE_ENFORCE_LT(
dev_id, dev_vec.size(),
platform::errors::OutOfRange(
dev_id,
dev_vec.size(),
phi::errors::OutOfRange(
"The visible devices count of type %s is %d, but dev_id is %d.",
dev_type, dev_vec.size(), dev_id));
dev_type,
dev_vec.size(),
dev_id));
return dev_vec[dev_id].get();
}
......@@ -277,22 +289,22 @@ void DeviceManager::Finalize(const std::string& device_type) {
}
void DeviceManager::SynchronizeDevice(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->SynchronizeDevice(device_id);
}
void DeviceManager::InitDevice(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->InitDevice(device_id);
}
void DeviceManager::DeInitDevice(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->DeInitDevice(device_id);
}
......@@ -304,8 +316,8 @@ void DeviceManager::SetDevice(const std::string& device_type,
}
void DeviceManager::SetDevice(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
DeviceManager::SetDevice(device_type, device_id);
}
......@@ -315,51 +327,52 @@ int DeviceManager::GetDevice(const std::string& device_type) {
}
size_t DeviceManager::GetMinChunkSize(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->GetMinChunkSize(device_id);
}
size_t DeviceManager::GetMaxChunkSize(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->GetMaxChunkSize(device_id);
}
size_t DeviceManager::GetMaxAllocSize(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->GetMaxAllocSize(device_id);
}
size_t DeviceManager::GetInitAllocSize(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->GetInitAllocSize(device_id);
}
size_t DeviceManager::GetReallocSize(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->GetReallocSize(device_id);
}
size_t DeviceManager::GetExtraPaddingSize(const Place& place) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->GetExtraPaddingSize(device_id);
}
void DeviceManager::MemoryStats(const Place& place, size_t* total,
void DeviceManager::MemoryStats(const Place& place,
size_t* total,
size_t* free) {
auto device_type = PlaceHelper::GetDeviceType(place);
auto device_id = PlaceHelper::GetDeviceId(place);
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
dev_impl->MemoryStats(device_id, total, free);
}
......@@ -393,8 +406,8 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
} else {
while ((ptr = readdir(dir)) != nullptr) {
std::string filename(ptr->d_name);
if (std::regex_match(filename.begin(), filename.end(), results,
express)) {
if (std::regex_match(
filename.begin(), filename.end(), results, express)) {
libraries.push_back(library_dir + '/' + filename);
VLOG(4) << "Found lib: " << libraries.back();
}
......@@ -405,6 +418,5 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
return libraries;
}
} // namespace platform
} // namespace paddle
} // namespace phi
#endif
......@@ -15,17 +15,16 @@
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_base.h"
#include "paddle/fluid/platform/device/device_ext.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_ext.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace paddle {
namespace platform {
namespace phi {
class Device final {
public:
Device(size_t dev_id, DeviceInterface* impl) : dev_id_(dev_id), impl_(impl) {}
......@@ -33,8 +32,9 @@ class Device final {
// Stream
// ! Create an asynchronous stream
void CreateStream(
stream::Stream* stream, const stream::Stream::Priority& priority =
stream::Stream::Priority::kNormal,
stream::Stream* stream,
const stream::Stream::Priority& priority =
stream::Stream::Priority::kNormal,
const stream::Stream::Flag& flag = stream::Stream::Flag::kDefaultFlag);
// ! Destroys an asynchronous stream.
......@@ -69,17 +69,26 @@ class Device final {
void StreamWaitEvent(const stream::Stream* stream, const event::Event* event);
// Memory
void MemoryCopyH2D(void* dst, const void* src, size_t size,
void MemoryCopyH2D(void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
void MemoryCopyD2H(void* dst, const void* src, size_t size,
void MemoryCopyD2H(void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
void MemoryCopyD2D(void* dst, const void* src, size_t size,
void MemoryCopyD2D(void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
void MemoryCopyP2P(const Place& dst_place, void* dst, const void* src,
size_t size, const stream::Stream* stream = nullptr);
void MemoryCopyP2P(const Place& dst_place,
void* dst,
const void* src,
size_t size,
const stream::Stream* stream = nullptr);
void* MemoryAllocate(size_t size);
......@@ -168,7 +177,8 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
void LoadCustomRuntimeLib(const CustomRuntimeParams& runtime_params,
std::unique_ptr<C_DeviceInterface> device_interface,
const std::string& dso_lib_path, void* dso_handle);
const std::string& dso_lib_path,
void* dso_handle);
class Registrar {
public:
......@@ -180,7 +190,6 @@ class Registrar {
void Touch() {}
};
} // namespace platform
} // namespace paddle
} // namespace phi
#endif
......@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/phi/backends/event.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/stream.h"
namespace paddle {
namespace platform {
namespace phi {
namespace event {
event_t Event::raw_event() const { return event_; }
......@@ -27,7 +26,7 @@ void Event::set_event(event_t event) { event_ = event; }
Event::Event(const Place& place, event_t event)
: place_(place),
device_(platform::DeviceManager::GetDeviceWithPlace(place)),
device_(phi::DeviceManager::GetDeviceWithPlace(place)),
event_(event),
own_data_(false) {}
......@@ -60,5 +59,4 @@ void Event::Synchonrize() const { device_->SynchronizeEvent(this); }
const Place& Event::GetPlace() const { return place_; }
} // namespace event
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -15,8 +15,7 @@
#pragma once
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
namespace phi {
class Device;
......@@ -57,5 +56,4 @@ class Event {
};
} // namespace event
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/event.h"
namespace paddle {
namespace platform {
namespace phi {
namespace stream {
Stream::~Stream() { Destroy(); }
......@@ -30,15 +29,16 @@ void Stream::set_stream(stream_t stream) { stream_ = stream; }
// For compatiable
Stream::Stream(const Place& place, stream_t stream)
: place_(place),
device_(platform::DeviceManager::GetDeviceWithPlace(place)),
device_(phi::DeviceManager::GetDeviceWithPlace(place)),
stream_(stream),
callback_manager_(new CallbackManager(this)),
own_data_(false) {}
bool Stream::Init(const Place& place, const Priority& priority,
bool Stream::Init(const Place& place,
const Priority& priority,
const Flag& flag) {
place_ = place;
device_ = platform::DeviceManager::GetDeviceWithPlace(place);
device_ = phi::DeviceManager::GetDeviceWithPlace(place);
DeviceGuard guard(place_);
device_->CreateStream(this, priority, flag);
......@@ -92,5 +92,4 @@ void Stream::Synchronize() const { device_->SynchronizeStream(this); }
const Place& Stream::GetPlace() const { return place_; }
} // namespace stream
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -14,11 +14,10 @@
#pragma once
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/callback_manager.h"
namespace paddle {
namespace platform {
namespace phi {
class Device;
......@@ -49,7 +48,8 @@ class Stream {
~Stream();
const stream_t& raw_stream() const;
void set_stream(stream_t stream);
bool Init(const Place& place, const Priority& priority = Priority::kNormal,
bool Init(const Place& place,
const Priority& priority = Priority::kNormal,
const Flag& flag = Flag::kDefaultFlag);
template <typename Callback>
void AddCallback(Callback&& callback) const {
......@@ -75,5 +75,4 @@ class Stream {
};
} // namespace stream
} // namespace platform
} // namespace paddle
} // namespace phi
......@@ -25,7 +25,7 @@ cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy)
cc_library(phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows)
cc_library(phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils)
cc_library(phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils op_registry phi_tensor_raw)
# Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN)
......
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/phi/backends/device_manager.h"
#endif
namespace phi {
......@@ -83,9 +83,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
if (!device_type.empty()) {
return phi::CustomPlace(
device_type,
set_device_id
? paddle::platform::DeviceManager::GetDevice(device_type)
: 0);
set_device_id ? phi::DeviceManager::GetDevice(device_type) : 0);
}
#endif
PADDLE_THROW(phi::errors::Unimplemented(
......
......@@ -12,6 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/phi/core/custom_kernel.h"
namespace phi {
......@@ -50,6 +55,25 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
}
}
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX
typedef phi::CustomKernelMap& get_custom_kernel_map_t();
auto* func = reinterpret_cast<get_custom_kernel_map_t*>(
dlsym(dso_handle, "PD_GetCustomKernelMap"));
if (func == nullptr) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetCustomKernelMap symbol in this lib.";
return;
}
auto& custom_kernel_map = func();
phi::RegisterCustomKernels(custom_kernel_map);
LOG(INFO) << "Successed in loading custom kernels in lib: " << dso_lib_path;
#else
VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux.";
#endif
return;
}
} // namespace phi
#ifdef __cplusplus
......
......@@ -46,4 +46,6 @@ class CustomKernelMap {
*/
void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map);
// Load custom kernel lib and register
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
} // namespace phi
......@@ -579,8 +579,7 @@ headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # phi core headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/backends', recursive=True)) + # phi backends headers
# utila api headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True)) + # paddle utils headers
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/device/device_ext.h'])
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True))) # paddle utils headers
if '${WITH_MKLDNN}' == 'ON':
headers += list(find_files('*', '${MKLDNN_INSTALL_DIR}/include')) # mkldnn
......@@ -625,8 +624,6 @@ class InstallHeaders(Command):
elif 'third_party' not in header:
# paddle headers
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
if 'device_ext.h' in header:
install_dir = "paddle/"
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册