Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dd480273
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dd480273
编写于
2月 02, 2023
作者:
R
ronnywang
提交者:
GitHub
2月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomDevice] refine custom device api (#50152)
上级
d8643cb6
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
61 addition
and
47 deletion
+61
-47
paddle/phi/backends/custom/custom_device.cc
paddle/phi/backends/custom/custom_device.cc
+61
-47
未找到文件。
paddle/phi/backends/custom/custom_device.cc
浏览文件 @
dd480273
...
@@ -148,38 +148,41 @@ class CustomDevice : public DeviceInterface {
...
@@ -148,38 +148,41 @@ class CustomDevice : public DeviceInterface {
stream
::
Stream
::
Flag
::
kDefaultFlag
)
override
{
stream
::
Stream
::
Flag
::
kDefaultFlag
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
C_Stream
c_stream
;
C_Stream
c_stream
;
if
(
pimpl_
->
create_stream
)
{
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
create_stream
(
device
,
&
c_stream
));
pimpl_
->
create_stream
(
device
,
&
c_stream
));
}
else
{
c_stream
=
nullptr
;
}
stream
->
set_stream
(
c_stream
);
stream
->
set_stream
(
c_stream
);
}
}
void
DestroyStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
)
override
{
void
DestroyStream
(
size_t
dev_id
,
stream
::
Stream
*
stream
)
override
{
if
(
pimpl_
->
destroy_stream
)
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
destroy_stream
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
destroy_stream
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
())));
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
())));
}
}
}
void
SynchronizeStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
)
override
{
void
SynchronizeStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
)
override
{
if
(
pimpl_
->
synchronize_stream
)
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
synchronize_stream
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
synchronize_stream
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
())));
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
())));
}
}
}
bool
QueryStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
)
override
{
bool
QueryStream
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
)
override
{
const
auto
device
=
&
devices_pool
[
dev_id
];
if
(
!
pimpl_
->
query_stream
)
{
if
(
!
pimpl_
->
query_stream
)
{
SynchronizeStream
(
dev_id
,
stream
);
SynchronizeStream
(
dev_id
,
stream
);
return
true
;
return
true
;
}
}
else
{
if
(
pimpl_
->
query_stream
(
const
auto
device
=
&
devices_pool
[
dev_id
];
return
pimpl_
->
query_stream
(
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()))
==
device
,
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()))
==
C_SUCCESS
)
{
C_SUCCESS
;
return
true
;
}
}
return
false
;
}
}
void
AddCallback
(
size_t
dev_id
,
void
AddCallback
(
size_t
dev_id
,
...
@@ -259,6 +262,7 @@ class CustomDevice : public DeviceInterface {
...
@@ -259,6 +262,7 @@ class CustomDevice : public DeviceInterface {
void
StreamWaitEvent
(
size_t
dev_id
,
void
StreamWaitEvent
(
size_t
dev_id
,
const
stream
::
Stream
*
stream
,
const
stream
::
Stream
*
stream
,
const
event
::
Event
*
event
)
override
{
const
event
::
Event
*
event
)
override
{
if
(
pimpl_
->
stream_wait_event
)
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_wait_event
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
stream_wait_event
(
...
@@ -266,6 +270,7 @@ class CustomDevice : public DeviceInterface {
...
@@ -266,6 +270,7 @@ class CustomDevice : public DeviceInterface {
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
()),
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
reinterpret_cast
<
C_Event
>
(
event
->
raw_event
())));
}
}
}
void
MemoryCopyH2D
(
size_t
dev_id
,
void
MemoryCopyH2D
(
size_t
dev_id
,
void
*
dst
,
void
*
dst
,
...
@@ -279,7 +284,7 @@ class CustomDevice : public DeviceInterface {
...
@@ -279,7 +284,7 @@ class CustomDevice : public DeviceInterface {
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_h2d
(
device
,
c_stream
,
dst
,
src
,
size
));
pimpl_
->
async_memory_copy_h2d
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
}
else
if
(
pimpl_
->
memory_copy_h2d
)
{
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
paddle
::
platform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
...
@@ -300,7 +305,7 @@ class CustomDevice : public DeviceInterface {
...
@@ -300,7 +305,7 @@ class CustomDevice : public DeviceInterface {
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_d2h
(
device
,
c_stream
,
dst
,
src
,
size
));
pimpl_
->
async_memory_copy_d2h
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
}
else
if
(
pimpl_
->
memory_copy_d2h
)
{
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
paddle
::
platform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
...
@@ -321,7 +326,7 @@ class CustomDevice : public DeviceInterface {
...
@@ -321,7 +326,7 @@ class CustomDevice : public DeviceInterface {
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
C_Stream
c_stream
=
reinterpret_cast
<
C_Stream
>
(
stream
->
raw_stream
());
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
pimpl_
->
async_memory_copy_d2d
(
device
,
c_stream
,
dst
,
src
,
size
));
pimpl_
->
async_memory_copy_d2d
(
device
,
c_stream
,
dst
,
src
,
size
));
}
else
{
}
else
if
(
pimpl_
->
memory_copy_d2d
)
{
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
paddle
::
platform
::
DeviceContextPool
::
Instance
();
pool
.
Get
(
place
)
->
Wait
();
pool
.
Get
(
place
)
->
Wait
();
...
@@ -455,6 +460,7 @@ class CustomDevice : public DeviceInterface {
...
@@ -455,6 +460,7 @@ class CustomDevice : public DeviceInterface {
}
}
void
MemoryStats
(
size_t
dev_id
,
size_t
*
total
,
size_t
*
free
)
override
{
void
MemoryStats
(
size_t
dev_id
,
size_t
*
total
,
size_t
*
free
)
override
{
if
(
pimpl_
->
device_memory_stats
)
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS
(
...
@@ -464,15 +470,23 @@ class CustomDevice : public DeviceInterface {
...
@@ -464,15 +470,23 @@ class CustomDevice : public DeviceInterface {
VLOG
(
10
)
<<
Type
()
<<
" memory usage "
<<
(
used
>>
20
)
<<
"M/"
VLOG
(
10
)
<<
Type
()
<<
" memory usage "
<<
(
used
>>
20
)
<<
"M/"
<<
(
*
total
>>
20
)
<<
"M, "
<<
(
*
free
>>
20
)
<<
(
*
total
>>
20
)
<<
"M, "
<<
(
*
free
>>
20
)
<<
"M available to allocate"
;
<<
"M available to allocate"
;
}
else
{
*
total
=
0
;
*
free
=
0
;
}
}
}
size_t
GetMinChunkSize
(
size_t
dev_id
)
override
{
size_t
GetMinChunkSize
(
size_t
dev_id
)
override
{
if
(
pimpl_
->
device_min_chunk_size
)
{
const
auto
device
=
&
devices_pool
[
dev_id
];
const
auto
device
=
&
devices_pool
[
dev_id
];
size_t
size
=
0
;
size_t
size
=
0
;
pimpl_
->
device_min_chunk_size
(
device
,
&
size
);
pimpl_
->
device_min_chunk_size
(
device
,
&
size
);
VLOG
(
10
)
<<
Type
()
<<
" min chunk size "
<<
size
<<
"B"
;
VLOG
(
10
)
<<
Type
()
<<
" min chunk size "
<<
size
<<
"B"
;
return
size
;
return
size
;
}
else
{
return
1
;
}
}
}
size_t
GetMaxChunkSize
(
size_t
dev_id
)
override
{
size_t
GetMaxChunkSize
(
size_t
dev_id
)
override
{
...
@@ -911,8 +925,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
...
@@ -911,8 +925,8 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE
(
get_device
,
true
);
CHECK_INTERFACE
(
get_device
,
true
);
CHECK_INTERFACE
(
deinit_device
,
false
);
CHECK_INTERFACE
(
deinit_device
,
false
);
CHECK_INTERFACE
(
create_stream
,
tru
e
);
CHECK_INTERFACE
(
create_stream
,
fals
e
);
CHECK_INTERFACE
(
destroy_stream
,
tru
e
);
CHECK_INTERFACE
(
destroy_stream
,
fals
e
);
CHECK_INTERFACE
(
query_stream
,
false
);
CHECK_INTERFACE
(
query_stream
,
false
);
CHECK_INTERFACE
(
stream_add_callback
,
false
);
CHECK_INTERFACE
(
stream_add_callback
,
false
);
...
@@ -922,9 +936,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
...
@@ -922,9 +936,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE
(
query_event
,
false
);
CHECK_INTERFACE
(
query_event
,
false
);
CHECK_INTERFACE
(
synchronize_device
,
false
);
CHECK_INTERFACE
(
synchronize_device
,
false
);
CHECK_INTERFACE
(
synchronize_stream
,
tru
e
);
CHECK_INTERFACE
(
synchronize_stream
,
fals
e
);
CHECK_INTERFACE
(
synchronize_event
,
true
);
CHECK_INTERFACE
(
synchronize_event
,
true
);
CHECK_INTERFACE
(
stream_wait_event
,
tru
e
);
CHECK_INTERFACE
(
stream_wait_event
,
fals
e
);
CHECK_INTERFACE
(
device_memory_allocate
,
true
);
CHECK_INTERFACE
(
device_memory_allocate
,
true
);
CHECK_INTERFACE
(
device_memory_deallocate
,
true
);
CHECK_INTERFACE
(
device_memory_deallocate
,
true
);
...
@@ -932,9 +946,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
...
@@ -932,9 +946,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE
(
host_memory_deallocate
,
false
);
CHECK_INTERFACE
(
host_memory_deallocate
,
false
);
CHECK_INTERFACE
(
unified_memory_allocate
,
false
);
CHECK_INTERFACE
(
unified_memory_allocate
,
false
);
CHECK_INTERFACE
(
unified_memory_deallocate
,
false
);
CHECK_INTERFACE
(
unified_memory_deallocate
,
false
);
CHECK_INTERFACE
(
memory_copy_h2d
,
tru
e
);
CHECK_INTERFACE
(
memory_copy_h2d
,
fals
e
);
CHECK_INTERFACE
(
memory_copy_d2h
,
tru
e
);
CHECK_INTERFACE
(
memory_copy_d2h
,
fals
e
);
CHECK_INTERFACE
(
memory_copy_d2d
,
tru
e
);
CHECK_INTERFACE
(
memory_copy_d2d
,
fals
e
);
CHECK_INTERFACE
(
memory_copy_p2p
,
false
);
CHECK_INTERFACE
(
memory_copy_p2p
,
false
);
CHECK_INTERFACE
(
async_memory_copy_h2d
,
false
);
CHECK_INTERFACE
(
async_memory_copy_h2d
,
false
);
CHECK_INTERFACE
(
async_memory_copy_d2h
,
false
);
CHECK_INTERFACE
(
async_memory_copy_d2h
,
false
);
...
@@ -943,9 +957,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
...
@@ -943,9 +957,9 @@ bool ValidCustomCustomRuntimeParams(const CustomRuntimeParams* params) {
CHECK_INTERFACE
(
get_device_count
,
true
);
CHECK_INTERFACE
(
get_device_count
,
true
);
CHECK_INTERFACE
(
get_device_list
,
true
);
CHECK_INTERFACE
(
get_device_list
,
true
);
CHECK_INTERFACE
(
device_memory_stats
,
tru
e
);
CHECK_INTERFACE
(
device_memory_stats
,
fals
e
);
CHECK_INTERFACE
(
device_min_chunk_size
,
tru
e
);
CHECK_INTERFACE
(
device_min_chunk_size
,
fals
e
);
CHECK_INTERFACE
(
device_max_chunk_size
,
false
);
CHECK_INTERFACE
(
device_max_chunk_size
,
false
);
CHECK_INTERFACE
(
device_max_alloc_size
,
false
);
CHECK_INTERFACE
(
device_max_alloc_size
,
false
);
CHECK_INTERFACE
(
device_extra_padding_size
,
false
);
CHECK_INTERFACE
(
device_extra_padding_size
,
false
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录