Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b4665d23
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b4665d23
编写于
3月 03, 2022
作者:
R
ronnywang
提交者:
GitHub
3月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomRuntime] migrate CustomRuntime into phi (#39908)
上级
756af9ff
变更
48
隐藏空白更改
内联
并排
Showing
48 changed file
with
513 addition
and
462 deletion
+513
-462
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-2
paddle/fluid/framework/custom_kernel.cc
paddle/fluid/framework/custom_kernel.cc
+0
-47
paddle/fluid/framework/custom_kernel.h
paddle/fluid/framework/custom_kernel.h
+0
-26
paddle/fluid/framework/garbage_collector.cc
paddle/fluid/framework/garbage_collector.cc
+5
-5
paddle/fluid/framework/garbage_collector.h
paddle/fluid/framework/garbage_collector.h
+3
-3
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+1
-1
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+1
-1
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+1
-1
paddle/fluid/memory/allocation/allocator_facade.cc
paddle/fluid/memory/allocation/allocator_facade.cc
+7
-8
paddle/fluid/memory/allocation/custom_allocator.cc
paddle/fluid/memory/allocation/custom_allocator.cc
+3
-4
paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
+8
-9
paddle/fluid/memory/detail/buddy_allocator.cc
paddle/fluid/memory/detail/buddy_allocator.cc
+2
-2
paddle/fluid/memory/detail/system_allocator.cc
paddle/fluid/memory/detail/system_allocator.cc
+3
-3
paddle/fluid/memory/memcpy.cc
paddle/fluid/memory/memcpy.cc
+10
-10
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+1
-1
paddle/fluid/platform/device/CMakeLists.txt
paddle/fluid/platform/device/CMakeLists.txt
+0
-20
paddle/fluid/platform/device/custom/CMakeLists.txt
paddle/fluid/platform/device/custom/CMakeLists.txt
+0
-4
paddle/fluid/platform/device/custom/enforce_custom.h
paddle/fluid/platform/device/custom/enforce_custom.h
+4
-1
paddle/fluid/platform/device/device_wrapper.h
paddle/fluid/platform/device/device_wrapper.h
+5
-5
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+1
-1
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+3
-3
paddle/fluid/platform/init.cc
paddle/fluid/platform/init.cc
+6
-6
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+7
-7
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+2
-2
paddle/phi/backends/CMakeLists.txt
paddle/phi/backends/CMakeLists.txt
+7
-0
paddle/phi/backends/callback_manager.cc
paddle/phi/backends/callback_manager.cc
+5
-7
paddle/phi/backends/callback_manager.h
paddle/phi/backends/callback_manager.h
+2
-4
paddle/phi/backends/custom/CMakeLists.txt
paddle/phi/backends/custom/CMakeLists.txt
+2
-0
paddle/phi/backends/custom/custom_context.cc
paddle/phi/backends/custom/custom_context.cc
+5
-5
paddle/phi/backends/custom/custom_device.cc
paddle/phi/backends/custom/custom_device.cc
+103
-64
paddle/phi/backends/custom/custom_device_test.cc
paddle/phi/backends/custom/custom_device_test.cc
+14
-16
paddle/phi/backends/custom/fake_cpu_device.h
paddle/phi/backends/custom/fake_cpu_device.h
+15
-7
paddle/phi/backends/device_base.cc
paddle/phi/backends/device_base.cc
+54
-34
paddle/phi/backends/device_base.h
paddle/phi/backends/device_base.h
+28
-16
paddle/phi/backends/device_ext.h
paddle/phi/backends/device_ext.h
+62
-27
paddle/phi/backends/device_guard.cc
paddle/phi/backends/device_guard.cc
+3
-5
paddle/phi/backends/device_guard.h
paddle/phi/backends/device_guard.h
+5
-7
paddle/phi/backends/device_manager.cc
paddle/phi/backends/device_manager.cc
+56
-44
paddle/phi/backends/device_manager.h
paddle/phi/backends/device_manager.h
+26
-17
paddle/phi/backends/event.cc
paddle/phi/backends/event.cc
+6
-8
paddle/phi/backends/event.h
paddle/phi/backends/event.h
+2
-4
paddle/phi/backends/stream.cc
paddle/phi/backends/stream.cc
+9
-10
paddle/phi/backends/stream.h
paddle/phi/backends/stream.h
+5
-6
paddle/phi/core/CMakeLists.txt
paddle/phi/core/CMakeLists.txt
+1
-1
paddle/phi/core/compat/convert_utils.cc
paddle/phi/core/compat/convert_utils.cc
+2
-4
paddle/phi/core/custom_kernel.cc
paddle/phi/core/custom_kernel.cc
+24
-0
paddle/phi/core/custom_kernel.h
paddle/phi/core/custom_kernel.h
+2
-0
python/setup.py.in
python/setup.py.in
+1
-4
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
b4665d23
...
...
@@ -440,11 +440,10 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file
(
commit.h.in commit.h
)
cc_library
(
custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper phi_tensor op_meta_info phi_api
)
cc_library
(
custom_kernel SRCS custom_kernel.cc DEPS op_registry phi_custom_kernel phi_tensor_raw
)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator custom_kernel
)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator
phi_
custom_kernel
)
cc_library
(
paddle_framework DEPS
${
FLUID_FRAMEWORK_MODULES
}
)
...
...
paddle/fluid/framework/custom_kernel.cc
已删除
100644 → 0
浏览文件 @
756af9ff
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/fluid/framework/custom_kernel.h"
#include "paddle/phi/core/custom_kernel.h"
namespace
paddle
{
namespace
framework
{
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
#ifdef _LINUX
typedef
phi
::
CustomKernelMap
&
get_custom_kernel_map_t
();
auto
*
func
=
reinterpret_cast
<
get_custom_kernel_map_t
*>
(
dlsym
(
dso_handle
,
"PD_GetCustomKernelMap"
));
if
(
func
==
nullptr
)
{
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
<<
"]: fail to find "
<<
"PD_GetCustomKernelMap symbol in this lib."
;
return
;
}
auto
&
custom_kernel_map
=
func
();
phi
::
RegisterCustomKernels
(
custom_kernel_map
);
LOG
(
INFO
)
<<
"Successed in loading custom kernels in lib: "
<<
dso_lib_path
;
#else
VLOG
(
3
)
<<
"Unsupported: Custom kernel is only implemented on Linux."
;
#endif
return
;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/custom_kernel.h
已删除
100644 → 0
浏览文件 @
756af9ff
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
namespace
paddle
{
namespace
framework
{
// Load custom kernel lib and register
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/garbage_collector.cc
浏览文件 @
b4665d23
...
...
@@ -231,19 +231,19 @@ void CustomDeviceUnsafeFastGarbageCollector::ClearCallback(
CustomStreamGarbageCollector
::
CustomStreamGarbageCollector
(
const
platform
::
CustomPlace
&
place
,
size_t
max_memory_size
)
:
GarbageCollector
(
place
,
max_memory_size
)
{
p
latform
::
DeviceGuard
guard
(
place
);
stream_
.
reset
(
new
p
latform
::
stream
::
Stream
);
p
hi
::
DeviceGuard
guard
(
place
);
stream_
.
reset
(
new
p
hi
::
stream
::
Stream
);
stream_
->
Init
(
place
);
callback_manager_
.
reset
(
new
p
latform
::
CallbackManager
(
stream_
.
get
()));
callback_manager_
.
reset
(
new
p
hi
::
CallbackManager
(
stream_
.
get
()));
}
CustomStreamGarbageCollector
::~
CustomStreamGarbageCollector
()
{
p
latform
::
DeviceGuard
guard
(
this
->
dev_ctx_
->
GetPlace
());
p
hi
::
DeviceGuard
guard
(
this
->
dev_ctx_
->
GetPlace
());
stream_
->
Synchronize
();
stream_
->
Destroy
();
}
p
latform
::
stream
::
Stream
*
CustomStreamGarbageCollector
::
stream
()
const
{
p
hi
::
stream
::
Stream
*
CustomStreamGarbageCollector
::
stream
()
const
{
return
stream_
.
get
();
}
...
...
paddle/fluid/framework/garbage_collector.h
浏览文件 @
b4665d23
...
...
@@ -230,14 +230,14 @@ class CustomStreamGarbageCollector : public GarbageCollector {
void
Wait
()
const
override
;
p
latform
::
stream
::
Stream
*
stream
()
const
;
p
hi
::
stream
::
Stream
*
stream
()
const
;
protected:
void
ClearCallback
(
const
std
::
function
<
void
()
>
&
callback
)
override
;
private:
std
::
unique_ptr
<
p
latform
::
stream
::
Stream
>
stream_
;
std
::
unique_ptr
<
p
latform
::
CallbackManager
>
callback_manager_
;
std
::
unique_ptr
<
p
hi
::
stream
::
Stream
>
stream_
;
std
::
unique_ptr
<
p
hi
::
CallbackManager
>
callback_manager_
;
};
#endif
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
b4665d23
...
...
@@ -254,7 +254,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
"reinstall Paddle with CustomDevice support."
,
place
));
#else
p
latform
::
DeviceManager
::
SetDevice
(
place
);
p
hi
::
DeviceManager
::
SetDevice
(
place
);
#endif
}
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
b4665d23
...
...
@@ -253,7 +253,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
#endif
}
else
if
(
platform
::
is_custom_place
(
place
))
{
#ifdef PADDLE_WITH_CUSTOM_DEVICE
p
latform
::
DeviceManager
::
SetDevice
(
place
);
p
hi
::
DeviceManager
::
SetDevice
(
place
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should compile with CustomDevice if use "
...
...
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
b4665d23
...
...
@@ -31,7 +31,7 @@ cc_library(paddle_infer_contrib SRCS paddle_infer_contrib.cc DEPS zero_copy_tens
cc_library
(
paddle_pass_builder SRCS paddle_pass_builder.cc
)
set
(
paddle_inference_api_deps lod_tensor scope reset_tensor_array
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator custom_kernel
)
analysis_config paddle_infer_contrib zero_copy_tensor trainer_desc_proto custom_operator
phi_
custom_kernel
)
if
(
WITH_CRYPTO
)
list
(
APPEND paddle_inference_api_deps paddle_crypto
)
...
...
paddle/fluid/memory/allocation/allocator_facade.cc
浏览文件 @
b4665d23
...
...
@@ -193,10 +193,10 @@ class AllocatorFacadePrivate {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
p
latform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
<
p
hi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
++
dev_id
)
{
InitNaiveBestFitCustomDeviceAllocator
(
platform
::
CustomPlace
(
dev_type
,
dev_id
));
...
...
@@ -240,10 +240,10 @@ class AllocatorFacadePrivate {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
p
latform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
<
p
hi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
++
dev_id
)
{
InitAutoGrowthCustomDeviceAllocator
(
platform
::
CustomPlace
(
dev_type
,
dev_id
),
allow_free_idle_chunk
);
...
...
@@ -738,7 +738,7 @@ class AllocatorFacadePrivate {
auto
custom_allocator
=
std
::
make_shared
<
paddle
::
memory
::
allocation
::
CustomAllocator
>
(
p
);
allocators_
[
p
]
=
std
::
make_shared
<
AutoGrowthBestFitAllocator
>
(
custom_allocator
,
p
latform
::
DeviceManager
::
GetMinChunkSize
(
p
),
custom_allocator
,
p
hi
::
DeviceManager
::
GetMinChunkSize
(
p
),
allow_free_idle_chunk
);
}
#endif
...
...
@@ -814,11 +814,10 @@ class AllocatorFacadePrivate {
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
const
auto
&
dev_type
:
device_types
)
{
for
(
size_t
dev_id
=
0
;
dev_id
<
platform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
++
)
{
dev_id
<
phi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
dev_id
++
)
{
places
.
emplace_back
(
platform
::
CustomPlace
(
dev_type
,
dev_id
));
}
}
...
...
paddle/fluid/memory/allocation/custom_allocator.cc
浏览文件 @
b4665d23
...
...
@@ -32,17 +32,16 @@ void CustomAllocator::FreeImpl(phi::Allocation* allocation) {
}
phi
::
Allocation
*
CustomAllocator
::
AllocateImpl
(
size_t
size
)
{
std
::
call_once
(
once_flag_
,
[
this
]
{
platform
::
DeviceManager
::
SetDevice
(
place_
);
});
std
::
call_once
(
once_flag_
,
[
this
]
{
phi
::
DeviceManager
::
SetDevice
(
place_
);
});
void
*
ptr
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place_
)
->
MemoryAllocate
(
size
);
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place_
)
->
MemoryAllocate
(
size
);
if
(
LIKELY
(
ptr
))
{
return
new
Allocation
(
ptr
,
size
,
place_
);
}
size_t
avail
,
total
;
p
latform
::
DeviceManager
::
MemoryStats
(
place_
,
&
total
,
&
avail
);
p
hi
::
DeviceManager
::
MemoryStats
(
place_
,
&
total
,
&
avail
);
auto
dev_type
=
platform
::
PlaceHelper
::
GetDeviceType
(
place_
);
auto
dev_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
place_
);
...
...
paddle/fluid/memory/allocation/naive_best_fit_allocator.cc
浏览文件 @
b4665d23
...
...
@@ -739,7 +739,7 @@ class BuddyAllocatorList {
private:
explicit
BuddyAllocatorList
(
const
std
::
string
&
device_type
)
:
device_type_
(
device_type
)
{
auto
devices
=
p
latform
::
DeviceManager
::
GetDeviceList
(
device_type
);
auto
devices
=
p
hi
::
DeviceManager
::
GetDeviceList
(
device_type
);
for
(
auto
dev_id
:
devices
)
{
init_flags_
[
dev_id
].
reset
(
new
std
::
once_flag
());
}
...
...
@@ -766,15 +766,15 @@ class BuddyAllocatorList {
device_type_
,
dev_id
));
std
::
call_once
(
*
init_flags_
[
dev_id
],
[
this
,
dev_id
]
{
p
latform
::
DeviceManager
::
SetDevice
(
device_type_
,
dev_id
);
p
hi
::
DeviceManager
::
SetDevice
(
device_type_
,
dev_id
);
platform
::
CustomPlace
place
(
device_type_
,
dev_id
);
allocators_
[
dev_id
].
reset
(
new
BuddyAllocator
(
std
::
unique_ptr
<
detail
::
SystemAllocator
>
(
new
detail
::
CustomAllocator
(
device_type_
,
dev_id
)),
p
latform
::
DeviceManager
::
GetMinChunkSize
(
place
),
p
latform
::
DeviceManager
::
GetMaxChunkSize
(
place
),
p
latform
::
DeviceManager
::
GetExtraPaddingSize
(
place
),
device_type_
));
p
hi
::
DeviceManager
::
GetMinChunkSize
(
place
),
p
hi
::
DeviceManager
::
GetMaxChunkSize
(
place
),
p
hi
::
DeviceManager
::
GetExtraPaddingSize
(
place
),
device_type_
));
});
return
allocators_
[
dev_id
].
get
();
...
...
@@ -808,9 +808,9 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
auto
*
ptr
=
buddy_allocator
->
Alloc
(
size
);
if
(
ptr
==
nullptr
)
{
p
latform
::
DeviceGuard
guard
(
place
);
p
hi
::
DeviceGuard
guard
(
place
);
size_t
avail
,
total
;
p
latform
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
p
hi
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
PADDLE_THROW
(
platform
::
errors
::
ResourceExhausted
(
"Cannot allocate %s in %s:%d, avaliable %s, total %s, used "
"%s. "
,
...
...
@@ -819,8 +819,7 @@ void *Alloc<platform::CustomPlace>(const platform::CustomPlace &place,
string
::
HumanReadableSize
(
total
-
avail
)));
}
else
{
if
(
FLAGS_init_allocated_mem
)
{
platform
::
DeviceManager
::
GetDeviceWithPlace
(
place
)
->
MemorySet
(
ptr
,
0xEF
,
size
);
phi
::
DeviceManager
::
GetDeviceWithPlace
(
place
)
->
MemorySet
(
ptr
,
0xEF
,
size
);
}
}
VLOG
(
10
)
<<
" pointer="
<<
ptr
;
...
...
paddle/fluid/memory/detail/buddy_allocator.cc
浏览文件 @
b4665d23
...
...
@@ -43,11 +43,11 @@ BuddyAllocator::BuddyAllocator(
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if
(
!
dev_type
.
empty
())
{
init_allocate_size_func_
=
[
dev_type
]()
{
return
p
latform
::
DeviceManager
::
GetInitAllocSize
(
return
p
hi
::
DeviceManager
::
GetInitAllocSize
(
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
));
};
re_allocate_size_func_
=
[
dev_type
]()
{
return
p
latform
::
DeviceManager
::
GetReallocSize
(
return
p
hi
::
DeviceManager
::
GetReallocSize
(
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
));
};
}
else
{
...
...
paddle/fluid/memory/detail/system_allocator.cc
浏览文件 @
b4665d23
...
...
@@ -438,7 +438,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
void
*
p
;
auto
place
=
platform
::
CustomPlace
(
dev_type_
,
dev_id_
);
auto
device
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
p
=
device
->
MemoryAllocate
(
size
);
if
(
LIKELY
(
p
))
{
VLOG
(
4
)
<<
"CustomAllocator::Alloc "
<<
p
<<
" size "
<<
size
;
...
...
@@ -447,7 +447,7 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) {
}
else
{
size_t
avail
,
total
;
p
latform
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
p
hi
::
DeviceManager
::
MemoryStats
(
place
,
&
total
,
&
avail
);
PADDLE_THROW_BAD_ALLOC
(
platform
::
errors
::
ResourceExhausted
(
"
\n\n
Out of memory error on %s %d. "
"total memory is %s, used memory is %s, "
...
...
@@ -470,7 +470,7 @@ void CustomAllocator::Free(void* p, size_t size, size_t index) {
size
,
plug_alloc_size
));
plug_alloc_size
-=
size
;
auto
place
=
platform
::
CustomPlace
(
dev_type_
,
dev_id_
);
auto
device
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
device
->
MemoryDeallocate
(
p
,
size
);
}
...
...
paddle/fluid/memory/memcpy.cc
浏览文件 @
b4665d23
...
...
@@ -44,9 +44,9 @@ void Copy<platform::CPUPlace, platform::CustomPlace>(
VLOG
(
4
)
<<
"memory::Copy "
<<
num
<<
" Bytes from "
<<
src_place
<<
" to "
<<
dst_place
<<
", stream="
<<
stream
;
p
latform
::
DeviceManager
::
SetDevice
(
src_place
);
p
latform
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2H
(
p
hi
::
DeviceManager
::
SetDevice
(
src_place
);
p
hi
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2H
(
dst
,
src
,
num
,
&
stream_wrapper
);
}
...
...
@@ -62,9 +62,9 @@ void Copy<platform::CustomPlace, platform::CPUPlace>(
VLOG
(
4
)
<<
"memory::Copy "
<<
num
<<
" Bytes from "
<<
src_place
<<
" to "
<<
dst_place
<<
", stream="
<<
stream
;
p
latform
::
DeviceManager
::
SetDevice
(
dst_place
);
p
latform
::
stream
::
Stream
stream_wrapper
(
dst_place
,
stream
);
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
dst_place
)
->
MemoryCopyH2D
(
p
hi
::
DeviceManager
::
SetDevice
(
dst_place
);
p
hi
::
stream
::
Stream
stream_wrapper
(
dst_place
,
stream
);
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
dst_place
)
->
MemoryCopyH2D
(
dst
,
src
,
num
,
&
stream_wrapper
);
}
...
...
@@ -82,16 +82,16 @@ void Copy<platform::CustomPlace, platform::CustomPlace>(
<<
dst_place
<<
", stream="
<<
stream
;
if
(
src_type
==
dst_type
)
{
p
latform
::
DeviceManager
::
SetDevice
(
src_place
);
p
latform
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
p
hi
::
DeviceManager
::
SetDevice
(
src_place
);
p
hi
::
stream
::
Stream
stream_wrapper
(
src_place
,
stream
);
auto
src_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
src_place
);
auto
dst_id
=
platform
::
PlaceHelper
::
GetDeviceId
(
dst_place
);
if
(
src_id
==
dst_id
)
{
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2D
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyD2D
(
dst
,
src
,
num
,
&
stream_wrapper
);
}
else
{
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyP2P
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
src_place
)
->
MemoryCopyP2P
(
dst_place
,
dst
,
src
,
num
,
&
stream_wrapper
);
}
}
else
{
...
...
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
b4665d23
...
...
@@ -117,7 +117,7 @@ endif()
cc_library
(
cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost
)
# seperate init from device_context to avoid cycle dependencies
cc_library
(
init SRCS init.cc DEPS device_context custom_kernel
)
cc_library
(
init SRCS init.cc DEPS device_context
phi_
custom_kernel
)
# memcpy depends on device_context, here add deps individually for
# avoiding cycle dependencies
...
...
paddle/fluid/platform/device/CMakeLists.txt
浏览文件 @
b4665d23
IF
(
WITH_CUSTOM_DEVICE
)
cc_library
(
callback_manager SRCS callback_manager.cc DEPS enforce place
)
cc_library
(
device_guard SRCS device_guard.cc DEPS enforce place
)
cc_library
(
stream SRCS stream.cc DEPS callback_manager
)
cc_library
(
event SRCS event.cc DEPS enforce place
)
cc_library
(
device_base SRCS device_base.cc DEPS stream event callback_manager device_guard device_context flags
)
ENDIF
()
set
(
DEV_LIBS custom_device
)
...
...
@@ -37,11 +25,3 @@ ENDIF()
IF
(
WITH_MLU
)
add_subdirectory
(
mlu
)
ENDIF
()
# CUSTOM
IF
(
WITH_CUSTOM_DEVICE
)
add_subdirectory
(
custom
)
cc_library
(
device_manager SRCS device_manager.cc DEPS custom_device
)
set
(
GLOB_DEV_LIB device_manager custom_device CACHE INTERNAL
"Global DEV library"
)
ENDIF
()
paddle/fluid/platform/device/custom/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
756af9ff
IF
(
WITH_CUSTOM_DEVICE
)
cc_library
(
custom_device SRCS custom_device.cc DEPS device_base device_context
)
cc_test
(
custom_device_test SRCS custom_device_test.cc DEPS device_manager device_context
)
ENDIF
()
paddle/fluid/platform/device/custom/enforce_custom.h
浏览文件 @
b4665d23
...
...
@@ -14,7 +14,10 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/device_ext.h"
#include <string>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_ext.h"
namespace
paddle
{
namespace
platform
{
...
...
paddle/fluid/platform/device/device_wrapper.h
浏览文件 @
b4665d23
...
...
@@ -40,10 +40,10 @@ limitations under the License. */
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
#endif
paddle/fluid/platform/device_context.cc
浏览文件 @
b4665d23
...
...
@@ -903,7 +903,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
CustomDeviceContext
::
CustomDeviceContext
(
CustomPlace
place
)
:
phi
::
CustomContext
(
place
)
{
Init
();
stream_
.
reset
(
new
p
latform
::
stream
::
Stream
(
place
,
stream
()));
stream_
.
reset
(
new
p
hi
::
stream
::
Stream
(
place
,
stream
()));
}
CustomDeviceContext
::~
CustomDeviceContext
()
{}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
b4665d23
...
...
@@ -72,8 +72,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#endif
#include "paddle/
fluid/platform/device
/device_ext.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/device_ext.h"
#include "paddle/
phi/backends
/stream.h"
#if !defined(PADDLE_WITH_XPU_KP) || defined(__xpu_on_host__)
#include "unsupported/Eigen/CXX11/Tensor"
...
...
@@ -838,7 +838,7 @@ class CustomDeviceContext : public phi::CustomContext {
void
WaitStreamCallback
()
const
{
return
stream_
->
WaitCallback
();
}
private:
std
::
shared_ptr
<
p
latform
::
stream
::
Stream
>
stream_
;
std
::
shared_ptr
<
p
hi
::
stream
::
Stream
>
stream_
;
};
template
<
>
struct
DefaultDeviceContextType
<
platform
::
CustomPlace
>
{
...
...
paddle/fluid/platform/init.cc
浏览文件 @
b4665d23
...
...
@@ -55,7 +55,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/ipu/ipu_info.h"
#endif
#include "paddle/
fluid/framework
/custom_kernel.h"
#include "paddle/
phi/core
/custom_kernel.h"
DECLARE_int32
(
paddle_num_threads
);
PADDLE_DEFINE_EXPORTED_int32
(
...
...
@@ -145,7 +145,7 @@ void InitCupti() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
void
LoadCustomDevice
(
const
std
::
string
&
library_dir
)
{
LOG
(
INFO
)
<<
"Try loading custom device libs from: ["
<<
library_dir
<<
"]"
;
std
::
vector
<
std
::
string
>
libs
=
p
latform
::
ListAllLibraries
(
library_dir
);
std
::
vector
<
std
::
string
>
libs
=
p
hi
::
ListAllLibraries
(
library_dir
);
for
(
const
auto
&
lib_path
:
libs
)
{
auto
dso_handle
=
dlopen
(
lib_path
.
c_str
(),
RTLD_NOW
);
PADDLE_ENFORCE_NOT_NULL
(
...
...
@@ -153,8 +153,8 @@ void LoadCustomDevice(const std::string &library_dir) {
platform
::
errors
::
InvalidArgument
(
"Fail to open library: %s with error: %s"
,
lib_path
,
dlerror
()));
p
latform
::
LoadCustomRuntimeLib
(
lib_path
,
dso_handle
);
framework
::
LoadCustomKernelLib
(
lib_path
,
dso_handle
);
p
hi
::
LoadCustomRuntimeLib
(
lib_path
,
dso_handle
);
phi
::
LoadCustomKernelLib
(
lib_path
,
dso_handle
);
}
LOG
(
INFO
)
<<
"Finished in LoadCustomDevice with libs_path: ["
<<
library_dir
<<
"]"
;
...
...
@@ -259,9 +259,9 @@ void InitDevices(const std::vector<int> devices) {
LOG
(
INFO
)
<<
"ENV [CUSTOM_DEVICE_ROOT]="
<<
custom_kernel_root
;
LoadCustomDevice
(
custom_kernel_root
);
auto
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
for
(
auto
&
dev_type
:
device_types
)
{
auto
device_count
=
p
latform
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
auto
device_count
=
p
hi
::
DeviceManager
::
GetDeviceCount
(
dev_type
);
LOG
(
INFO
)
<<
"CustomDevice: "
<<
dev_type
<<
", visible devices count: "
<<
device_count
;
for
(
size_t
i
=
0
;
i
<
device_count
;
i
++
)
{
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
b4665d23
...
...
@@ -1668,7 +1668,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_all_device_type"
,
[]()
{
std
::
vector
<
std
::
string
>
device_types
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
device_types
=
p
latform
::
DeviceManager
::
GetAllDeviceTypes
();
device_types
=
p
hi
::
DeviceManager
::
GetAllDeviceTypes
();
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_all_device_type because you have installed"
...
...
@@ -1682,7 +1682,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_all_custom_device_type"
,
[]()
{
std
::
vector
<
std
::
string
>
device_types
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
device_types
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceTypes
();
device_types
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceTypes
();
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_all_custom_device_type because you have installed"
...
...
@@ -1696,7 +1696,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_available_device"
,
[]
{
std
::
vector
<
std
::
string
>
devices
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
devices
=
p
latform
::
DeviceManager
::
GetAllDeviceList
();
devices
=
p
hi
::
DeviceManager
::
GetAllDeviceList
();
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_available_device because you have installed"
...
...
@@ -1710,7 +1710,7 @@ All parameter, weight, gradient are variables in Paddle.
m
.
def
(
"get_available_custom_device"
,
[]
{
std
::
vector
<
std
::
string
>
devices
;
#ifdef PADDLE_WITH_CUSTOM_DEVICE
devices
=
p
latform
::
DeviceManager
::
GetAllCustomDeviceList
();
devices
=
p
hi
::
DeviceManager
::
GetAllCustomDeviceList
();
#else
LOG
(
WARNING
)
<<
string
::
Sprintf
(
"Cannot use get_available_custom_device because you have "
...
...
@@ -1747,10 +1747,10 @@ All parameter, weight, gradient are variables in Paddle.
std
::
exit
(
-
1
);
}
if
(
LIKELY
(
p
latform
::
DeviceManager
::
HasDeviceType
(
device_type
)
&&
p
latform
::
DeviceManager
::
IsCustom
(
device_type
)))
{
if
(
LIKELY
(
p
hi
::
DeviceManager
::
HasDeviceType
(
device_type
)
&&
p
hi
::
DeviceManager
::
IsCustom
(
device_type
)))
{
int
dev_count
=
static_cast
<
int
>
(
p
latform
::
DeviceManager
::
GetDeviceCount
(
device_type
));
p
hi
::
DeviceManager
::
GetDeviceCount
(
device_type
));
if
(
UNLIKELY
(
dev_id
>=
dev_count
))
{
if
(
dev_count
==
0
)
{
LOG
(
ERROR
)
<<
"Cannot use "
<<
device_type
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
b4665d23
...
...
@@ -393,10 +393,10 @@ void SetTensorFromPyArrayT(
}
else
if
(
paddle
::
platform
::
is_custom_place
(
place
))
{
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform
::
Place
tmp_place
=
place
;
p
latform
::
DeviceGuard
guard
(
tmp_place
);
p
hi
::
DeviceGuard
guard
(
tmp_place
);
auto
dst
=
self
->
mutable_data
<
T
>
(
place
);
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
tmp_place
)
->
MemoryCopyH2D
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
tmp_place
)
->
MemoryCopyH2D
(
reinterpret_cast
<
void
*>
(
dst
),
const_cast
<
void
*>
(
reinterpret_cast
<
const
void
*>
(
array
.
data
())),
array
.
nbytes
());
...
...
paddle/phi/backends/CMakeLists.txt
浏览文件 @
b4665d23
...
...
@@ -24,4 +24,11 @@ endif()
if
(
WITH_CUSTOM_DEVICE
)
add_dependencies
(
phi_context custom_context
)
cc_library
(
callback_manager SRCS callback_manager.cc DEPS enforce place
)
cc_library
(
device_guard SRCS device_guard.cc DEPS enforce place
)
cc_library
(
stream SRCS stream.cc DEPS callback_manager
)
cc_library
(
event SRCS event.cc DEPS enforce place
)
cc_library
(
device_base SRCS device_base.cc DEPS stream event callback_manager device_guard device_context flags
)
cc_library
(
device_manager SRCS device_manager.cc DEPS custom_device
)
set
(
GLOB_DEV_LIB device_manager custom_device CACHE INTERNAL
"Global DEV library"
)
endif
()
paddle/
fluid/platform/device
/callback_manager.cc
→
paddle/
phi/backends
/callback_manager.cc
浏览文件 @
b4665d23
...
...
@@ -12,12 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/
fluid/platform/device
/callback_manager.h"
#include "paddle/
phi/backends
/callback_manager.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
CallbackManager
::
CallbackManager
(
stream
::
Stream
*
stream
)
:
stream_
(
stream
),
thread_pool_
(
1
)
{}
...
...
@@ -32,12 +31,12 @@ void CallbackManager::AddCallback(std::function<void()> callback) const {
});
});
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
->
AddCallback
(
stream_
,
func
);
}
void
CallbackManager
::
Wait
()
const
{
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
stream_
->
GetPlace
())
->
SynchronizeStream
(
stream_
);
{
...
...
@@ -48,5 +47,4 @@ void CallbackManager::Wait() const {
}
}
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/callback_manager.h
→
paddle/
phi/backends
/callback_manager.h
浏览文件 @
b4665d23
...
...
@@ -32,8 +32,7 @@
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
namespace
stream
{
class
Stream
;
...
...
@@ -58,5 +57,4 @@ class CallbackManager {
mutable
std
::
future
<
void
>
last_future_
;
};
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/phi/backends/custom/CMakeLists.txt
浏览文件 @
b4665d23
if
(
WITH_CUSTOM_DEVICE
)
cc_library
(
custom_context SRCS custom_context.cc DEPS phi_device_context device_manager
)
cc_library
(
custom_device SRCS custom_device.cc DEPS device_base device_context
)
cc_test
(
custom_device_test SRCS custom_device_test.cc DEPS device_manager device_context
)
endif
()
paddle/phi/backends/custom/custom_context.cc
浏览文件 @
b4665d23
...
...
@@ -14,8 +14,8 @@ limitations under the License. */
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/
fluid/platform/device
/device_guard.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/device_guard.h"
#include "paddle/
phi/backends
/stream.h"
namespace
phi
{
...
...
@@ -25,8 +25,8 @@ struct CustomContext::Impl {
~
Impl
()
{}
void
Init
()
{
p
addle
::
platform
::
DeviceGuard
guard
(
place_
);
stream_
.
reset
(
new
p
addle
::
platform
::
stream
::
Stream
());
p
hi
::
DeviceGuard
guard
(
place_
);
stream_
.
reset
(
new
p
hi
::
stream
::
Stream
());
stream_
->
Init
(
place_
);
}
...
...
@@ -40,7 +40,7 @@ struct CustomContext::Impl {
Place
place_
;
std
::
shared_ptr
<
p
addle
::
platform
::
stream
::
Stream
>
stream_
;
std
::
shared_ptr
<
p
hi
::
stream
::
Stream
>
stream_
;
};
void
CustomContext
::
Init
()
{
impl_
->
Init
();
}
...
...
paddle/
fluid/platform/device
/custom/custom_device.cc
→
paddle/
phi/backends
/custom/custom_device.cc
浏览文件 @
b4665d23
...
...
@@ -12,23 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/device_base.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/fluid/platform/device/custom/enforce_custom.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/callback_manager.h"
#include "paddle/phi/backends/device_base.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/backends/event.h"
#include "paddle/phi/backends/stream.h"
static
bool
operator
==
(
const
C_Device_st
&
d1
,
const
C_Device_st
&
d2
)
{
return
d1
.
id
==
d2
.
id
;
}
namespace
paddle
{
namespace
platform
{
namespace
phi
{
class
CustomDevice
:
public
DeviceInterface
{
public:
CustomDevice
(
const
std
::
string
&
type
,
int
priority
,
bool
is_custom
,
std
::
unique_ptr
<
C_DeviceInterface
>
pimpl
,
void
*
dso_handle
)
CustomDevice
(
const
std
::
string
&
type
,
int
priority
,
bool
is_custom
,
std
::
unique_ptr
<
C_DeviceInterface
>
pimpl
,
void
*
dso_handle
)
:
DeviceInterface
(
type
,
priority
,
is_custom
),
pimpl_
(
std
::
move
(
pimpl
)),
dso_handle_
(
dso_handle
)
{
...
...
@@ -122,14 +127,15 @@ class CustomDevice : public DeviceInterface {
return
device
.
id
;
}
void
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
)
override
{
if
(
priority
!=
stream
::
Stream
::
Priority
::
kNormal
||
flag
!=
stream
::
Stream
::
Flag
::
kDefaultFlag
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"priority != stream::Stream::Priority::kNormal || flag != "
"stream::Stream::Flag::kDefaultFlag is not allowed on "
"CustomDevice."
));
...
...
@@ -162,23 +168,28 @@ class CustomDevice : public DeviceInterface {
SynchronizeStream
(
dev_id
,
stream
);
return
true
;
}
if
(
pimpl_
->
query_stream
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()))
==
C_SUCCESS
)
{
if
(
pimpl_
->
query_stream
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()))
==
C_SUCCESS
)
{
return
true
;
}
return
false
;
}
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
stream
::
Stream
::
Callback
*
callback
)
override
{
if
(
!
pimpl_
->
stream_add_callback
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"AddCallback is not supported on %s."
,
Type
()));
}
else
{
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_add_callback
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
[](
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
[](
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
C_Status
*
status
)
{
std
::
unique_ptr
<
std
::
function
<
void
()
>>
func
(
reinterpret_cast
<
std
::
function
<
void
()
>*>
(
user_data
));
...
...
@@ -188,7 +199,8 @@ class CustomDevice : public DeviceInterface {
}
}
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
event
::
Event
::
Flag
flags
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
C_Event
c_event
;
...
...
@@ -205,13 +217,15 @@ class CustomDevice : public DeviceInterface {
device
,
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
const
stream
::
Stream
*
stream
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
record_event
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
record_event
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
void
SynchronizeEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
)
override
{
...
...
@@ -228,78 +242,93 @@ class CustomDevice : public DeviceInterface {
SynchronizeEvent
(
dev_id
,
event
);
return
true
;
}
if
(
pimpl_
->
query_event
(
device
,
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
()))
==
C_SUCCESS
)
{
if
(
pimpl_
->
query_event
(
device
,
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
()))
==
C_SUCCESS
)
{
return
true
;
}
return
false
;
}
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_wait_event
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
auto
place
=
platform
::
CustomPlace
(
Type
(),
dev_id
);
auto
place
=
CustomPlace
(
Type
(),
dev_id
);
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_h2d
)
{
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_h2d
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_h2d
(
device
,
dst
,
src
,
size
));
}
}
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
auto
place
=
platform
::
CustomPlace
(
Type
(),
dev_id
);
auto
place
=
CustomPlace
(
Type
(),
dev_id
);
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_d2h
)
{
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_d2h
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_d2h
(
device
,
dst
,
src
,
size
));
}
}
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
auto
place
=
platform
::
CustomPlace
(
Type
(),
dev_id
);
auto
place
=
CustomPlace
(
Type
(),
dev_id
);
if
(
stream
&&
stream
->
raw_stream
()
&&
pimpl_
->
async_memory_copy_d2d
)
{
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_d2d
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_d2d
(
device
,
dst
,
src
,
size
));
}
}
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_dev_id
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_dev_id
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
)
override
{
int
dst_dev_id
=
PlaceToId
(
dst_place
);
auto
dst_device
=
&
devices_pool
[
dst_dev_id
];
...
...
@@ -310,8 +339,12 @@ class CustomDevice : public DeviceInterface {
MemoryCopyP2P
(
dst_place
,
dst
,
src_dev_id
,
src
,
size
);
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_p2p
(
dst_device
,
src_device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
dst
,
src
,
size
));
dst_device
,
src_device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
dst
,
src
,
size
));
}
}
else
{
if
(
!
pimpl_
->
memory_copy_p2p
)
{
...
...
@@ -319,9 +352,9 @@ class CustomDevice : public DeviceInterface {
MemoryCopyD2H
(
src_dev_id
,
tmp
.
get
(),
src
,
size
);
MemoryCopyH2D
(
dst_dev_id
,
dst
,
tmp
.
get
(),
size
);
}
else
{
auto
src_place
=
platform
::
CustomPlace
(
Type
(),
src_dev_id
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
src_place
=
CustomPlace
(
Type
(),
src_dev_id
);
p
addle
::
p
latform
::
DeviceContextPool
&
pool
=
p
addle
::
p
latform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
src_place
)
->
Wait
();
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
memory_copy_p2p
(
dst_device
,
src_device
,
dst
,
src
,
size
));
...
...
@@ -350,8 +383,8 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
unified_memory_allocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
"MemoryAlloc
Kind::
Host is not supported on %s."
,
Type
()));
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"MemoryAlloc
ate
Host is not supported on %s."
,
Type
()));
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
host_memory_allocate
(
device
,
&
ptr
,
size
));
...
...
@@ -363,8 +396,8 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
host_memory_deallocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
"Memory
AllocKind::
Host is not supported on %s."
,
Type
()));
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"Memory
Deallocate
Host is not supported on %s."
,
Type
()));
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
host_memory_deallocate
(
device
,
ptr
,
size
));
...
...
@@ -376,8 +409,8 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
unified_memory_allocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
"MemoryAlloc
Kind::
Unified is not supported on %s."
,
Type
()));
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"MemoryAlloc
ate
Unified is not supported on %s."
,
Type
()));
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
unified_memory_allocate
(
device
,
&
ptr
,
size
));
...
...
@@ -389,15 +422,17 @@ class CustomDevice : public DeviceInterface {
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
unified_memory_deallocate
)
{
PADDLE_THROW
(
p
latform
::
errors
::
Unavailable
(
"Memory
AllocKind::Host
is not supported on %s."
,
Type
()));
PADDLE_THROW
(
p
hi
::
errors
::
Unavailable
(
"Memory
DeallocateUnified
is not supported on %s."
,
Type
()));
}
else
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
unified_memory_deallocate
(
device
,
ptr
,
size
));
}
}
void
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
void
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
size_t
size
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
...
...
@@ -532,10 +567,12 @@ class CustomDevice : public DeviceInterface {
inline
int
PlaceToId
(
const
Place
&
place
)
{
int
dev_id
=
PlaceToIdNoCheck
(
place
);
PADDLE_ENFORCE_NE
(
devices_pool
.
find
(
dev_id
),
devices_pool
.
end
(),
platform
::
errors
::
NotFound
(
PADDLE_ENFORCE_NE
(
devices_pool
.
find
(
dev_id
),
devices_pool
.
end
(),
phi
::
errors
::
NotFound
(
"Cannot found %s %d, please check visible devices"
,
Type
(),
dev_id
));
Type
(),
dev_id
));
return
dev_id
;
}
...
...
@@ -623,11 +660,14 @@ typedef bool (*RegisterDevicePluginFn)(CustomRuntimeParams* runtime_params);
void
LoadCustomRuntimeLib
(
const
CustomRuntimeParams
&
runtime_params
,
std
::
unique_ptr
<
C_DeviceInterface
>
device_interface
,
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
if
(
ValidCustomCustomRuntimeParams
(
&
runtime_params
))
{
auto
device
=
std
::
make_unique
<
CustomDevice
>
(
runtime_params
.
device_type
,
255
,
true
,
std
::
move
(
device_interface
),
dso_handle
);
auto
device
=
std
::
make_unique
<
CustomDevice
>
(
runtime_params
.
device_type
,
255
,
true
,
std
::
move
(
device_interface
),
dso_handle
);
if
(
false
==
DeviceManager
::
Register
(
std
::
move
(
device
)))
{
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
<<
"]. Register failed!!! there may be a "
...
...
@@ -665,10 +705,9 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) {
"compatibility between PaddlePaddle and Custom Runtime."
;
return
;
}
LoadCustomRuntimeLib
(
runtime_params
,
std
::
move
(
device_interface
),
dso_lib_path
,
dso_handle
);
LoadCustomRuntimeLib
(
runtime_params
,
std
::
move
(
device_interface
),
dso_lib_path
,
dso_handle
);
LOG
(
INFO
)
<<
"Successed in loading custom runtime in lib: "
<<
dso_lib_path
;
}
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/custom/custom_device_test.cc
→
paddle/
phi/backends
/custom/custom_device_test.cc
浏览文件 @
b4665d23
...
...
@@ -17,9 +17,9 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/custom/fake_cpu_device.h"
#include "paddle/fluid/platform/device/device_manager.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/custom/fake_cpu_device.h"
#include "paddle/phi/backends/device_manager.h"
void
RegisterDevice
()
{
CustomRuntimeParams
runtime_params
;
...
...
@@ -30,23 +30,22 @@ void RegisterDevice() {
runtime_params
.
interface
->
size
=
sizeof
(
C_DeviceInterface
);
InitFakeCPUDevice
(
&
runtime_params
);
p
addle
::
platform
::
LoadCustomRuntimeLib
(
p
hi
::
LoadCustomRuntimeLib
(
runtime_params
,
std
::
move
(
device_interface
),
""
,
nullptr
);
}
void
InitDevice
()
{
RegisterDevice
();
EXPECT_GT
(
static_cast
<
int
>
(
paddle
::
platform
::
DeviceManager
::
GetAllDeviceTypes
().
size
()),
EXPECT_GT
(
static_cast
<
int
>
(
phi
::
DeviceManager
::
GetAllDeviceTypes
().
size
()),
0
);
auto
place
=
paddle
::
platform
::
CustomPlace
(
DEVICE_TYPE
,
0
);
auto
device
=
p
addle
::
platform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
EXPECT_NE
(
device
,
nullptr
);
std
::
vector
<
paddle
::
platform
::
Place
>
places
;
auto
device_types
=
p
addle
::
platform
::
DeviceManager
::
GetAllDeviceTypes
();
auto
device_types
=
p
hi
::
DeviceManager
::
GetAllDeviceTypes
();
for
(
auto
dev_type
:
device_types
)
{
auto
devices
=
p
addle
::
platform
::
DeviceManager
::
GetDeviceList
(
dev_type
);
auto
devices
=
p
hi
::
DeviceManager
::
GetDeviceList
(
dev_type
);
for
(
auto
dev_id
:
devices
)
{
places
.
push_back
(
paddle
::
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
,
dev_id
));
...
...
@@ -60,14 +59,14 @@ void InitDevice() {
void
TestDeviceInterface
(
const
paddle
::
platform
::
Place
&
place
)
{
std
::
cout
<<
"TestDeviceInterface on "
<<
place
<<
std
::
endl
;
if
(
paddle
::
platform
::
is_custom_place
(
place
))
{
auto
device
=
p
addle
::
platform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
device
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
auto
dev_type
=
paddle
::
platform
::
PlaceHelper
::
GetDeviceType
(
place
);
auto
p1
=
device
->
MemoryAllocate
(
paddle
::
platform
::
DeviceManager
::
GetMinChunkSize
(
place
));
auto
p1
=
device
->
MemoryAllocate
(
phi
::
DeviceManager
::
GetMinChunkSize
(
place
));
EXPECT_NE
(
p1
,
nullptr
);
p
addle
::
platform
::
DeviceManager
::
SetDevice
(
place
);
auto
dev_id
=
p
addle
::
platform
::
DeviceManager
::
GetDevice
(
dev_type
);
p
hi
::
DeviceManager
::
SetDevice
(
place
);
auto
dev_id
=
p
hi
::
DeviceManager
::
GetDevice
(
dev_type
);
EXPECT_EQ
(
dev_id
,
place
.
GetDeviceId
());
}
}
...
...
@@ -168,11 +167,10 @@ void TestTensorUtils(const paddle::platform::Place& place) {
TEST
(
CustomDevice
,
Tensor
)
{
InitDevice
();
auto
dev_types
=
p
addle
::
platform
::
DeviceManager
::
GetAllDeviceTypes
();
auto
dev_types
=
p
hi
::
DeviceManager
::
GetAllDeviceTypes
();
for
(
const
auto
&
dev_type
:
dev_types
)
{
std
::
cout
<<
"Test on "
<<
dev_type
<<
std
::
endl
;
EXPECT_GT
(
static_cast
<
int
>
(
paddle
::
platform
::
DeviceManager
::
GetDeviceCount
(
dev_type
)),
EXPECT_GT
(
static_cast
<
int
>
(
phi
::
DeviceManager
::
GetDeviceCount
(
dev_type
)),
0
);
auto
place
=
paddle
::
platform
::
PlaceHelper
::
CreatePlace
(
dev_type
);
...
...
paddle/
fluid/platform/device
/custom/fake_cpu_device.h
→
paddle/
phi/backends
/custom/fake_cpu_device.h
浏览文件 @
b4665d23
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/
fluid/platform/device
/device_ext.h"
#include "paddle/
phi/backends
/device_ext.h"
constexpr
size_t
global_total_memory
=
1024
*
1024UL
;
static
size_t
global_free_memory
=
global_total_memory
;
...
...
@@ -43,14 +43,19 @@ C_Status GetDevicesList(size_t *device) {
return
C_SUCCESS
;
}
C_Status
MemCpy
(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
MemCpy
(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
)
{
memcpy
(
dst
,
src
,
size
);
return
C_SUCCESS
;
}
C_Status
AsyncMemCpy
(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
)
{
C_Status
AsyncMemCpy
(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
)
{
memcpy
(
dst
,
src
,
size
);
return
C_SUCCESS
;
}
...
...
@@ -100,14 +105,16 @@ C_Status SyncStream(const C_Device device, C_Stream stream) {
C_Status
SyncEvent
(
const
C_Device
device
,
C_Event
event
)
{
return
C_SUCCESS
;
}
C_Status
StreamWaitEvent
(
const
C_Device
device
,
C_Stream
stream
,
C_Status
StreamWaitEvent
(
const
C_Device
device
,
C_Stream
stream
,
C_Event
event
)
{
return
C_SUCCESS
;
}
C_Status
VisibleDevices
(
size_t
*
devices
)
{
return
C_SUCCESS
;
}
C_Status
DeviceMemStats
(
const
C_Device
device
,
size_t
*
total_memory
,
C_Status
DeviceMemStats
(
const
C_Device
device
,
size_t
*
total_memory
,
size_t
*
free_memory
)
{
*
total_memory
=
global_total_memory
;
*
free_memory
=
global_free_memory
;
...
...
@@ -139,7 +146,8 @@ void InitFakeCPUDevice(CustomRuntimeParams *params) {
params
->
version
.
minor
=
PADDLE_CUSTOM_RUNTIME_MINOR_VERSION
;
params
->
version
.
patch
=
PADDLE_CUSTOM_RUNTIME_PATCH_VERSION
;
memset
(
reinterpret_cast
<
void
*>
(
params
->
interface
),
0
,
memset
(
reinterpret_cast
<
void
*>
(
params
->
interface
),
0
,
sizeof
(
C_DeviceInterface
));
params
->
interface
->
initialize
=
Init
;
...
...
paddle/
fluid/platform/device
/device_base.cc
→
paddle/
phi/backends
/device_base.cc
浏览文件 @
b4665d23
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/
fluid/platform/device
/device_base.h"
#include "paddle/
phi/backends
/device_base.h"
#include "gflags/gflags.h"
DECLARE_double
(
fraction_of_gpu_memory_to_use
);
...
...
@@ -21,26 +21,25 @@ DECLARE_uint64(reallocate_gpu_memory_in_mb);
constexpr
static
float
fraction_reserve_gpu_memory
=
0.05
f
;
namespace
paddle
{
namespace
platform
{
namespace
phi
{
#define INTERFACE_UNIMPLEMENT
\
PADDLE_THROW(p
latform
::errors::Unimplemented( \
#define INTERFACE_UNIMPLEMENT \
PADDLE_THROW(p
hi
::errors::Unimplemented( \
"%s is not implemented on %s device.", __func__, Type()));
// info
size_t
DeviceInterface
::
GetComputeCapability
()
{
VLOG
(
10
)
<<
Type
()
+
" get compute capability "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" get compute capability "
<<
0
;
return
0
;
}
size_t
DeviceInterface
::
GetRuntimeVersion
()
{
VLOG
(
10
)
<<
Type
()
+
" get runtime version "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" get runtime version "
<<
0
;
return
0
;
}
size_t
DeviceInterface
::
GetDriverVersion
()
{
VLOG
(
10
)
<<
Type
()
+
" get driver version "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" get driver version "
<<
0
;
return
0
;
}
...
...
@@ -62,7 +61,8 @@ void DeviceInterface::SetDevice(size_t dev_id) { INTERFACE_UNIMPLEMENT; }
int
DeviceInterface
::
GetDevice
()
{
INTERFACE_UNIMPLEMENT
;
}
// stream manage
void
DeviceInterface
::
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
DeviceInterface
::
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
,
const
stream
::
Stream
::
Flag
&
flag
)
{
INTERFACE_UNIMPLEMENT
;
...
...
@@ -82,7 +82,8 @@ bool DeviceInterface::QueryStream(size_t dev_id, const stream::Stream* stream) {
return
true
;
}
void
DeviceInterface
::
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
void
DeviceInterface
::
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
stream
::
Stream
::
Callback
*
callback
)
{
INTERFACE_UNIMPLEMENT
;
}
...
...
@@ -94,7 +95,8 @@ void DeviceInterface::StreamWaitEvent(size_t dev_id,
}
// event manage
void
DeviceInterface
::
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
void
DeviceInterface
::
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
event
::
Event
::
Flag
flags
)
{
INTERFACE_UNIMPLEMENT
;
}
...
...
@@ -103,7 +105,8 @@ void DeviceInterface::DestroyEvent(size_t dev_id, event::Event* event) {
INTERFACE_UNIMPLEMENT
;
}
void
DeviceInterface
::
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
void
DeviceInterface
::
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
}
...
...
@@ -119,23 +122,35 @@ bool DeviceInterface::QueryEvent(size_t dev_id, const event::Event* event) {
}
// memery manage
void
DeviceInterface
::
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
DeviceInterface
::
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
}
void
DeviceInterface
::
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
DeviceInterface
::
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
}
void
DeviceInterface
::
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
DeviceInterface
::
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
}
void
DeviceInterface
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
void
DeviceInterface
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
INTERFACE_UNIMPLEMENT
;
}
...
...
@@ -154,7 +169,8 @@ void* DeviceInterface::MemoryAllocateHost(size_t dev_id, size_t size) {
return
nullptr
;
}
void
DeviceInterface
::
MemoryDeallocateHost
(
size_t
dev_id
,
void
*
ptr
,
void
DeviceInterface
::
MemoryDeallocateHost
(
size_t
dev_id
,
void
*
ptr
,
size_t
size
)
{
INTERFACE_UNIMPLEMENT
;
}
...
...
@@ -164,12 +180,15 @@ void* DeviceInterface::MemoryAllocateUnified(size_t dev_id, size_t size) {
return
nullptr
;
}
void
DeviceInterface
::
MemoryDeallocateUnified
(
size_t
dev_id
,
void
*
ptr
,
void
DeviceInterface
::
MemoryDeallocateUnified
(
size_t
dev_id
,
void
*
ptr
,
size_t
size
)
{
INTERFACE_UNIMPLEMENT
;
}
void
DeviceInterface
::
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
void
DeviceInterface
::
MemorySet
(
size_t
dev_id
,
void
*
ptr
,
uint8_t
value
,
size_t
size
)
{
INTERFACE_UNIMPLEMENT
;
}
...
...
@@ -184,8 +203,9 @@ size_t DeviceInterface::GetMinChunkSize(size_t dev_id) {
size_t
DeviceInterface
::
AllocSize
(
size_t
dev_id
,
bool
realloc
)
{
size_t
available_to_alloc
=
AvailableAllocSize
(
dev_id
);
PADDLE_ENFORCE_GT
(
available_to_alloc
,
0
,
platform
::
errors
::
ResourceExhausted
(
PADDLE_ENFORCE_GT
(
available_to_alloc
,
0
,
phi
::
errors
::
ResourceExhausted
(
"Not enough available %s memory."
,
Type
()));
// If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
// allocated by fraction
...
...
@@ -194,8 +214,9 @@ size_t DeviceInterface::AllocSize(size_t dev_id, bool realloc) {
size_t
alloc_bytes
=
(
flag_mb
>
0ul
?
flag_mb
<<
20
:
available_to_alloc
*
FLAGS_fraction_of_gpu_memory_to_use
);
PADDLE_ENFORCE_GE
(
available_to_alloc
,
alloc_bytes
,
platform
::
errors
::
ResourceExhausted
(
PADDLE_ENFORCE_GE
(
available_to_alloc
,
alloc_bytes
,
phi
::
errors
::
ResourceExhausted
(
"Not enough available %s memory."
,
Type
()));
return
alloc_bytes
;
}
...
...
@@ -217,33 +238,32 @@ size_t DeviceInterface::AvailableAllocSize(size_t dev_id) {
size_t
DeviceInterface
::
GetInitAllocSize
(
size_t
dev_id
)
{
size_t
init_alloc_size
=
AllocSize
(
dev_id
,
false
);
VLOG
(
10
)
<<
Type
()
+
" init alloc size "
<<
(
init_alloc_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" init alloc size "
<<
(
init_alloc_size
>>
20
)
<<
"M"
;
return
init_alloc_size
;
}
size_t
DeviceInterface
::
GetReallocSize
(
size_t
dev_id
)
{
size_t
realloc_size
=
AllocSize
(
dev_id
,
true
);
VLOG
(
10
)
<<
Type
()
+
" realloc size "
<<
(
realloc_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" realloc size "
<<
(
realloc_size
>>
20
)
<<
"M"
;
return
realloc_size
;
}
size_t
DeviceInterface
::
GetMaxAllocSize
(
size_t
dev_id
)
{
size_t
max_alloc_size
=
std
::
max
(
GetInitAllocSize
(
dev_id
),
GetReallocSize
(
dev_id
));
VLOG
(
10
)
<<
Type
()
+
" max alloc size "
<<
(
max_alloc_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" max alloc size "
<<
(
max_alloc_size
>>
20
)
<<
"M"
;
return
max_alloc_size
;
}
size_t
DeviceInterface
::
GetMaxChunkSize
(
size_t
dev_id
)
{
size_t
max_chunk_size
=
GetMaxAllocSize
(
dev_id
);
VLOG
(
10
)
<<
Type
()
+
" max chunk size "
<<
(
max_chunk_size
>>
20
)
<<
"M"
;
VLOG
(
10
)
<<
Type
()
<<
" max chunk size "
<<
(
max_chunk_size
>>
20
)
<<
"M"
;
return
max_chunk_size
;
}
size_t
DeviceInterface
::
GetExtraPaddingSize
(
size_t
dev_id
)
{
VLOG
(
10
)
<<
Type
()
+
" extra padding size "
<<
0
;
VLOG
(
10
)
<<
Type
()
<<
" extra padding size "
<<
0
;
return
0
;
}
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/device_base.h
→
paddle/
phi/backends
/device_base.h
浏览文件 @
b4665d23
...
...
@@ -14,11 +14,10 @@
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/event.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
phi/backends
/event.h"
#include "paddle/
phi/backends
/stream.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
class
DeviceInterface
{
// Driver / Runtime
public:
...
...
@@ -66,7 +65,8 @@ class DeviceInterface { // Driver / Runtime
// Stream
// ! Create an asynchronous stream
virtual
void
CreateStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
size_t
dev_id
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
);
...
...
@@ -81,19 +81,22 @@ class DeviceInterface { // Driver / Runtime
virtual
bool
QueryStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
);
// ! Add a callback to a compute stream.
virtual
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
virtual
void
AddCallback
(
size_t
dev_id
,
stream
::
Stream
*
stream
,
stream
::
Stream
::
Callback
*
callback
);
// Event
// ! Create an event.
virtual
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
virtual
void
CreateEvent
(
size_t
dev_id
,
event
::
Event
*
event
,
event
::
Event
::
Flag
flags
);
// ! Destroy an event.
virtual
void
DestroyEvent
(
size_t
dev_id
,
event
::
Event
*
event
);
// ! Records an event.
virtual
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
virtual
void
RecordEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
,
const
stream
::
Stream
*
stream
);
// ! Waits for event to complete.
...
...
@@ -102,24 +105,34 @@ class DeviceInterface { // Driver / Runtime
virtual
bool
QueryEvent
(
size_t
dev_id
,
const
event
::
Event
*
event
);
// ! Make a compute stream wait on an event
virtual
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
virtual
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
);
// Memory
virtual
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
virtual
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
virtual
void
MemoryCopyD2H
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
virtual
void
MemoryCopyD2D
(
size_t
dev_id
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
virtual
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
size_t
src_id
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
virtual
void
*
MemoryAllocate
(
size_t
dev_id
,
size_t
size
);
...
...
@@ -160,7 +173,6 @@ class DeviceInterface { // Driver / Runtime
size_t
AvailableAllocSize
(
size_t
dev_id
);
};
}
// namespace platform
}
// namespace paddle
}
// namespace phi
#endif
paddle/
fluid/platform/device
/device_ext.h
→
paddle/
phi/backends
/device_ext.h
浏览文件 @
b4665d23
...
...
@@ -40,7 +40,9 @@ typedef struct C_Stream_st* C_Stream;
typedef
struct
C_Event_st
*
C_Event
;
typedef
void
(
*
C_Callback
)(
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
typedef
void
(
*
C_Callback
)(
C_Device
device
,
C_Stream
stream
,
void
*
user_data
,
C_Status
*
status
);
struct
C_DeviceInterface
{
...
...
@@ -124,8 +126,10 @@ struct C_DeviceInterface {
* @param[C_Callback] callback
* @param[void*] user_data
*/
C_Status
(
*
stream_add_callback
)(
const
C_Device
device
,
C_Stream
stream
,
C_Callback
callback
,
void
*
user_data
);
C_Status
(
*
stream_add_callback
)(
const
C_Device
device
,
C_Stream
stream
,
C_Callback
callback
,
void
*
user_data
);
/**
* @brief Create an event
...
...
@@ -142,7 +146,8 @@ struct C_DeviceInterface {
* @param[C_Stream] stream
* @param[C_Event] event
*/
C_Status
(
*
record_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
record_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Event
event
);
/**
...
...
@@ -191,7 +196,8 @@ struct C_DeviceInterface {
* @param[C_Stream] stream
* @param[C_Event] event
*/
C_Status
(
*
stream_wait_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Status
(
*
stream_wait_event
)(
const
C_Device
device
,
C_Stream
stream
,
C_Event
event
);
void
*
reserved_dev_api
[
8
];
...
...
@@ -207,7 +213,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
*/
C_Status
(
*
device_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
C_Status
(
*
device_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
size_t
size
);
/**
...
...
@@ -217,7 +224,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[size_t] size
*/
C_Status
(
*
device_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
device_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
size_t
size
);
/**
...
...
@@ -228,8 +236,10 @@ struct C_DeviceInterface {
* @param[unsigned char] value
* @param[size_t] size
*/
C_Status
(
*
device_memory_set
)(
const
C_Device
device
,
void
*
ptr
,
unsigned
char
value
,
size_t
size
);
C_Status
(
*
device_memory_set
)(
const
C_Device
device
,
void
*
ptr
,
unsigned
char
value
,
size_t
size
);
/**
* @brief Host memory allocate
...
...
@@ -238,7 +248,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
*/
C_Status
(
*
host_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
C_Status
(
*
host_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
size_t
size
);
/**
...
...
@@ -248,7 +259,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[size_t] size
*/
C_Status
(
*
host_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
host_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
size_t
size
);
/**
...
...
@@ -258,7 +270,8 @@ struct C_DeviceInterface {
* @param[void**] ptr Plugin allocate an address and fill it
* @param[size_t] size
*/
C_Status
(
*
unified_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
C_Status
(
*
unified_memory_allocate
)(
const
C_Device
device
,
void
**
ptr
,
size_t
size
);
/**
...
...
@@ -268,7 +281,8 @@ struct C_DeviceInterface {
* @param[void*] ptr
* @param[size_t] size
*/
C_Status
(
*
unified_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
C_Status
(
*
unified_memory_deallocate
)(
const
C_Device
device
,
void
*
ptr
,
size_t
size
);
/**
...
...
@@ -279,7 +293,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status
(
*
memory_copy_h2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
(
*
memory_copy_h2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
...
...
@@ -290,7 +306,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status
(
*
memory_copy_d2h
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
(
*
memory_copy_d2h
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
...
...
@@ -301,7 +319,9 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status
(
*
memory_copy_d2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
C_Status
(
*
memory_copy_d2d
)(
const
C_Device
device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
...
...
@@ -314,8 +334,10 @@ struct C_DeviceInterface {
* @param[size_t] size
*/
C_Status
(
*
memory_copy_p2p
)(
const
C_Device
dst_device
,
const
C_Device
src_device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
const
C_Device
src_device
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
* @brief Asynchonrize memory copy from host to device
...
...
@@ -326,8 +348,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status
(
*
async_memory_copy_h2d
)(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Status
(
*
async_memory_copy_h2d
)(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
* @brief Asynchonrize memory copy from device to host
...
...
@@ -338,8 +363,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status
(
*
async_memory_copy_d2h
)(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Status
(
*
async_memory_copy_d2h
)(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
* @brief Asynchonrize memory copy from device to device
...
...
@@ -350,8 +378,11 @@ struct C_DeviceInterface {
* @param[void*] src
* @param[size_t] size
*/
C_Status
(
*
async_memory_copy_d2d
)(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
C_Status
(
*
async_memory_copy_d2d
)(
const
C_Device
device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
/**
* @brief Peer asynchonrize memory copy from host to device
...
...
@@ -363,8 +394,11 @@ struct C_DeviceInterface {
* @param[size_t] size
*/
C_Status
(
*
async_memory_copy_p2p
)(
const
C_Device
dst_device
,
const
C_Device
src_device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
const
C_Device
src_device
,
C_Stream
stream
,
void
*
dst
,
const
void
*
src
,
size_t
size
);
void
*
reserved_mem_api
[
8
];
...
...
@@ -394,7 +428,8 @@ struct C_DeviceInterface {
* @param[size_t*] free_memory
* @param[size_t*] used_memory
*/
C_Status
(
*
device_memory_stats
)(
const
C_Device
device
,
size_t
*
total_memory
,
C_Status
(
*
device_memory_stats
)(
const
C_Device
device
,
size_t
*
total_memory
,
size_t
*
free_memory
);
/**
...
...
paddle/
fluid/platform/device
/device_guard.cc
→
paddle/
phi/backends
/device_guard.cc
浏览文件 @
b4665d23
...
...
@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/
fluid/platform/device
/device_guard.h"
#include "paddle/
phi/backends
/device_guard.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
// Even this source file does not contains any code, it is better to keep this
// source file for cmake dependency.
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/device_guard.h
→
paddle/
phi/backends
/device_guard.h
浏览文件 @
b4665d23
...
...
@@ -13,17 +13,16 @@
// limitations under the License.
#pragma once
#include "paddle/
fluid/platform/device
/device_manager.h"
#include "paddle/
phi/backends
/device_manager.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
class
DeviceGuard
{
public:
explicit
inline
DeviceGuard
(
const
Place
&
place
)
:
dev_type_
(
PlaceHelper
::
GetDeviceType
(
place
))
{
:
dev_type_
(
place
.
GetDeviceType
(
))
{
prev_id
=
DeviceManager
::
GetDevice
(
dev_type_
);
cur_id
=
PlaceHelper
::
GetDeviceId
(
place
);
cur_id
=
place
.
GetDeviceId
(
);
if
(
cur_id
!=
prev_id
)
{
DeviceManager
::
SetDevice
(
dev_type_
,
cur_id
);
...
...
@@ -44,5 +43,4 @@ class DeviceGuard {
std
::
string
dev_type_
;
};
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/device_manager.cc
→
paddle/
phi/backends
/device_manager.cc
浏览文件 @
b4665d23
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/device_manager.h"
#include "paddle/
phi/backends
/device_manager.h"
#if !defined(_WIN32)
#include <dirent.h>
...
...
@@ -24,8 +24,7 @@
#include <functional>
#include <regex>
namespace
paddle
{
namespace
platform
{
namespace
phi
{
void
Device
::
CreateStream
(
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
,
...
...
@@ -76,23 +75,32 @@ void Device::StreamWaitEvent(const stream::Stream* stream,
impl_
->
StreamWaitEvent
(
dev_id_
,
stream
,
event
);
}
void
Device
::
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
Device
::
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyH2D
(
dev_id_
,
dst
,
src
,
size
,
stream
);
}
void
Device
::
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
Device
::
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyD2H
(
dev_id_
,
dst
,
src
,
size
,
stream
);
}
void
Device
::
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
Device
::
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyD2D
(
dev_id_
,
dst
,
src
,
size
,
stream
);
}
void
Device
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
void
Device
::
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
)
{
impl_
->
MemoryCopyP2P
(
dst_place
,
dst
,
dev_id_
,
src
,
size
,
stream
);
}
...
...
@@ -173,7 +181,7 @@ DeviceInterface* DeviceManager::GetDeviceInterfaceWithType(
}
else
{
LOG
(
ERROR
)
<<
"GetDeviceInterfaceWithType - "
<<
device_type
<<
" Failed
\n
"
;
PADDLE_THROW
(
p
latform
::
errors
::
Fatal
(
"Unregistered device type %s."
,
device_type
));
p
hi
::
errors
::
Fatal
(
"Unregistered device type %s."
,
device_type
));
return
nullptr
;
}
}
...
...
@@ -182,17 +190,21 @@ Device* DeviceManager::GetDeviceWithPlace(const Place& place) {
phi
::
AutoRDLock
lock
(
&
_global_device_manager_rw_lock
);
auto
&
dev_map
=
Instance
().
device_map_
;
auto
dev_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
dev_id
=
PlaceHelper
::
GetDeviceId
(
place
);
PADDLE_ENFORCE_NE
(
dev_map
.
find
(
dev_type
),
dev_map
.
end
(),
platform
::
errors
::
NotFound
(
"Unable to find Device with type %s."
,
dev_type
));
auto
dev_type
=
place
.
GetDeviceType
();
auto
dev_id
=
place
.
GetDeviceId
();
PADDLE_ENFORCE_NE
(
dev_map
.
find
(
dev_type
),
dev_map
.
end
(),
phi
::
errors
::
NotFound
(
"Unable to find Device with type %s."
,
dev_type
));
auto
&
dev_vec
=
dev_map
[
dev_type
];
PADDLE_ENFORCE_LT
(
dev_id
,
dev_vec
.
size
(),
platform
::
errors
::
OutOfRange
(
dev_id
,
dev_vec
.
size
(),
phi
::
errors
::
OutOfRange
(
"The visible devices count of type %s is %d, but dev_id is %d."
,
dev_type
,
dev_vec
.
size
(),
dev_id
));
dev_type
,
dev_vec
.
size
(),
dev_id
));
return
dev_vec
[
dev_id
].
get
();
}
...
...
@@ -277,22 +289,22 @@ void DeviceManager::Finalize(const std::string& device_type) {
}
void
DeviceManager
::
SynchronizeDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
SynchronizeDevice
(
device_id
);
}
void
DeviceManager
::
InitDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
InitDevice
(
device_id
);
}
void
DeviceManager
::
DeInitDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
DeInitDevice
(
device_id
);
}
...
...
@@ -304,8 +316,8 @@ void DeviceManager::SetDevice(const std::string& device_type,
}
void
DeviceManager
::
SetDevice
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
DeviceManager
::
SetDevice
(
device_type
,
device_id
);
}
...
...
@@ -315,51 +327,52 @@ int DeviceManager::GetDevice(const std::string& device_type) {
}
size_t
DeviceManager
::
GetMinChunkSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetMinChunkSize
(
device_id
);
}
size_t
DeviceManager
::
GetMaxChunkSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetMaxChunkSize
(
device_id
);
}
size_t
DeviceManager
::
GetMaxAllocSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetMaxAllocSize
(
device_id
);
}
size_t
DeviceManager
::
GetInitAllocSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetInitAllocSize
(
device_id
);
}
size_t
DeviceManager
::
GetReallocSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetReallocSize
(
device_id
);
}
size_t
DeviceManager
::
GetExtraPaddingSize
(
const
Place
&
place
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
return
dev_impl
->
GetExtraPaddingSize
(
device_id
);
}
void
DeviceManager
::
MemoryStats
(
const
Place
&
place
,
size_t
*
total
,
void
DeviceManager
::
MemoryStats
(
const
Place
&
place
,
size_t
*
total
,
size_t
*
free
)
{
auto
device_type
=
PlaceHelper
::
GetDeviceType
(
place
);
auto
device_id
=
PlaceHelper
::
GetDeviceId
(
place
);
auto
device_type
=
place
.
GetDeviceType
(
);
auto
device_id
=
place
.
GetDeviceId
(
);
auto
dev_impl
=
GetDeviceInterfaceWithType
(
device_type
);
dev_impl
->
MemoryStats
(
device_id
,
total
,
free
);
}
...
...
@@ -393,8 +406,8 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
}
else
{
while
((
ptr
=
readdir
(
dir
))
!=
nullptr
)
{
std
::
string
filename
(
ptr
->
d_name
);
if
(
std
::
regex_match
(
filename
.
begin
(),
filename
.
end
(),
results
,
express
))
{
if
(
std
::
regex_match
(
filename
.
begin
(),
filename
.
end
(),
results
,
express
))
{
libraries
.
push_back
(
library_dir
+
'/'
+
filename
);
VLOG
(
4
)
<<
"Found lib: "
<<
libraries
.
back
();
}
...
...
@@ -405,6 +418,5 @@ std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
return
libraries
;
}
}
// namespace platform
}
// namespace paddle
}
// namespace phi
#endif
paddle/
fluid/platform/device
/device_manager.h
→
paddle/
phi/backends
/device_manager.h
浏览文件 @
b4665d23
...
...
@@ -15,17 +15,16 @@
#pragma once
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/device_base.h"
#include "paddle/
fluid/platform/device
/device_ext.h"
#include "paddle/
fluid/platform/device
/event.h"
#include "paddle/
fluid/platform/device
/stream.h"
#include "paddle/
fluid/platform
/place.h"
#include "paddle/
phi/backends
/device_base.h"
#include "paddle/
phi/backends
/device_ext.h"
#include "paddle/
phi/backends
/event.h"
#include "paddle/
phi/backends
/stream.h"
#include "paddle/
phi/common
/place.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
class
Device
final
{
public:
Device
(
size_t
dev_id
,
DeviceInterface
*
impl
)
:
dev_id_
(
dev_id
),
impl_
(
impl
)
{}
...
...
@@ -33,8 +32,9 @@ class Device final {
// Stream
// ! Create an asynchronous stream
void
CreateStream
(
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
stream
::
Stream
*
stream
,
const
stream
::
Stream
::
Priority
&
priority
=
stream
::
Stream
::
Priority
::
kNormal
,
const
stream
::
Stream
::
Flag
&
flag
=
stream
::
Stream
::
Flag
::
kDefaultFlag
);
// ! Destroys an asynchronous stream.
...
...
@@ -69,17 +69,26 @@ class Device final {
void
StreamWaitEvent
(
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
);
// Memory
void
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyH2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2H
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
void
MemoryCopyD2D
(
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
MemoryCopyP2P
(
const
Place
&
dst_place
,
void
*
dst
,
const
void
*
src
,
size_t
size
,
const
stream
::
Stream
*
stream
=
nullptr
);
void
*
MemoryAllocate
(
size_t
size
);
...
...
@@ -168,7 +177,8 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle);
void
LoadCustomRuntimeLib
(
const
CustomRuntimeParams
&
runtime_params
,
std
::
unique_ptr
<
C_DeviceInterface
>
device_interface
,
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
class
Registrar
{
public:
...
...
@@ -180,7 +190,6 @@ class Registrar {
void
Touch
()
{}
};
}
// namespace platform
}
// namespace paddle
}
// namespace phi
#endif
paddle/
fluid/platform/device
/event.cc
→
paddle/
phi/backends
/event.cc
浏览文件 @
b4665d23
...
...
@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/event.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/phi/backends/event.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/stream.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
namespace
event
{
event_t
Event
::
raw_event
()
const
{
return
event_
;
}
...
...
@@ -27,7 +26,7 @@ void Event::set_event(event_t event) { event_ = event; }
Event
::
Event
(
const
Place
&
place
,
event_t
event
)
:
place_
(
place
),
device_
(
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
device_
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
event_
(
event
),
own_data_
(
false
)
{}
...
...
@@ -60,5 +59,4 @@ void Event::Synchonrize() const { device_->SynchronizeEvent(this); }
const
Place
&
Event
::
GetPlace
()
const
{
return
place_
;
}
}
// namespace event
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/event.h
→
paddle/
phi/backends
/event.h
浏览文件 @
b4665d23
...
...
@@ -15,8 +15,7 @@
#pragma once
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
class
Device
;
...
...
@@ -57,5 +56,4 @@ class Event {
};
}
// namespace event
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/stream.cc
→
paddle/
phi/backends
/stream.cc
浏览文件 @
b4665d23
...
...
@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/stream.h"
#include "paddle/fluid/platform/device/device_guard.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/event.h"
#include "paddle/phi/backends/device_guard.h"
#include "paddle/phi/backends/event.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
namespace
stream
{
Stream
::~
Stream
()
{
Destroy
();
}
...
...
@@ -30,15 +29,16 @@ void Stream::set_stream(stream_t stream) { stream_ = stream; }
// For compatiable
Stream
::
Stream
(
const
Place
&
place
,
stream_t
stream
)
:
place_
(
place
),
device_
(
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
device_
(
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
)),
stream_
(
stream
),
callback_manager_
(
new
CallbackManager
(
this
)),
own_data_
(
false
)
{}
bool
Stream
::
Init
(
const
Place
&
place
,
const
Priority
&
priority
,
bool
Stream
::
Init
(
const
Place
&
place
,
const
Priority
&
priority
,
const
Flag
&
flag
)
{
place_
=
place
;
device_
=
p
latform
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
device_
=
p
hi
::
DeviceManager
::
GetDeviceWithPlace
(
place
);
DeviceGuard
guard
(
place_
);
device_
->
CreateStream
(
this
,
priority
,
flag
);
...
...
@@ -92,5 +92,4 @@ void Stream::Synchronize() const { device_->SynchronizeStream(this); }
const
Place
&
Stream
::
GetPlace
()
const
{
return
place_
;
}
}
// namespace stream
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/
fluid/platform/device
/stream.h
→
paddle/
phi/backends
/stream.h
浏览文件 @
b4665d23
...
...
@@ -14,11 +14,10 @@
#pragma once
#include "paddle/fluid/platform/device/callback_manager.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/backends/callback_manager.h"
namespace
paddle
{
namespace
platform
{
namespace
phi
{
class
Device
;
...
...
@@ -49,7 +48,8 @@ class Stream {
~
Stream
();
const
stream_t
&
raw_stream
()
const
;
void
set_stream
(
stream_t
stream
);
bool
Init
(
const
Place
&
place
,
const
Priority
&
priority
=
Priority
::
kNormal
,
bool
Init
(
const
Place
&
place
,
const
Priority
&
priority
=
Priority
::
kNormal
,
const
Flag
&
flag
=
Flag
::
kDefaultFlag
);
template
<
typename
Callback
>
void
AddCallback
(
Callback
&&
callback
)
const
{
...
...
@@ -75,5 +75,4 @@ class Stream {
};
}
// namespace stream
}
// namespace platform
}
// namespace paddle
}
// namespace phi
paddle/phi/core/CMakeLists.txt
浏览文件 @
b4665d23
...
...
@@ -25,7 +25,7 @@ cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library
(
selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor phi_enforce ddim memcpy
)
cc_library
(
phi_device_context SRCS device_context.cc DEPS dense_tensor selected_rows
)
cc_library
(
phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils
)
cc_library
(
phi_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils
op_registry phi_tensor_raw
)
# Will remove once we implemented MKLDNN_Tensor
if
(
WITH_MKLDNN
)
...
...
paddle/phi/core/compat/convert_utils.cc
浏览文件 @
b4665d23
...
...
@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/
fluid/platform/device
/device_manager.h"
#include "paddle/
phi/backends
/device_manager.h"
#endif
namespace
phi
{
...
...
@@ -83,9 +83,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
if
(
!
device_type
.
empty
())
{
return
phi
::
CustomPlace
(
device_type
,
set_device_id
?
paddle
::
platform
::
DeviceManager
::
GetDevice
(
device_type
)
:
0
);
set_device_id
?
phi
::
DeviceManager
::
GetDevice
(
device_type
)
:
0
);
}
#endif
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
...
...
paddle/phi/core/custom_kernel.cc
浏览文件 @
b4665d23
...
...
@@ -12,6 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/phi/core/custom_kernel.h"
namespace
phi
{
...
...
@@ -50,6 +55,25 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
}
}
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
)
{
#ifdef _LINUX
typedef
phi
::
CustomKernelMap
&
get_custom_kernel_map_t
();
auto
*
func
=
reinterpret_cast
<
get_custom_kernel_map_t
*>
(
dlsym
(
dso_handle
,
"PD_GetCustomKernelMap"
));
if
(
func
==
nullptr
)
{
LOG
(
WARNING
)
<<
"Skipped lib ["
<<
dso_lib_path
<<
"]: fail to find "
<<
"PD_GetCustomKernelMap symbol in this lib."
;
return
;
}
auto
&
custom_kernel_map
=
func
();
phi
::
RegisterCustomKernels
(
custom_kernel_map
);
LOG
(
INFO
)
<<
"Successed in loading custom kernels in lib: "
<<
dso_lib_path
;
#else
VLOG
(
3
)
<<
"Unsupported: Custom kernel is only implemented on Linux."
;
#endif
return
;
}
}
// namespace phi
#ifdef __cplusplus
...
...
paddle/phi/core/custom_kernel.h
浏览文件 @
b4665d23
...
...
@@ -46,4 +46,6 @@ class CustomKernelMap {
*/
void
RegisterCustomKernels
(
const
CustomKernelMap
&
custom_kernel_map
);
// Load custom kernel lib and register
void
LoadCustomKernelLib
(
const
std
::
string
&
dso_lib_path
,
void
*
dso_handle
);
}
// namespace phi
python/setup.py.in
浏览文件 @
b4665d23
...
...
@@ -579,8 +579,7 @@ headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # phi core headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/backends', recursive=True)) + # phi backends headers
# utila api headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True)) + # paddle utils headers
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/device/device_ext.h'])
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True))) # paddle utils headers
if '${WITH_MKLDNN}' == 'ON':
headers += list(find_files('*', '${MKLDNN_INSTALL_DIR}/include')) # mkldnn
...
...
@@ -625,8 +624,6 @@ class InstallHeaders(Command):
elif 'third_party' not in header:
# paddle headers
install_dir = re.sub('@PADDLE_SOURCE_DIR@/', '', header)
if 'device_ext.h' in header:
install_dir = "paddle/"
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录