Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
60af2852
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60af2852
编写于
4月 01, 2021
作者:
L
liym27
提交者:
GitHub
4月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NPU] Support dataloader on npu place. (#31867)
上级
2a672f68
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
312 addition
and
9 deletion
+312
-9
paddle/fluid/operators/reader/buffered_reader.cc
paddle/fluid/operators/reader/buffered_reader.cc
+76
-3
paddle/fluid/operators/reader/buffered_reader.h
paddle/fluid/operators/reader/buffered_reader.h
+12
-1
paddle/fluid/platform/CMakeLists.txt
paddle/fluid/platform/CMakeLists.txt
+5
-0
paddle/fluid/platform/npu_resource_pool.cc
paddle/fluid/platform/npu_resource_pool.cc
+101
-0
paddle/fluid/platform/npu_resource_pool.h
paddle/fluid/platform/npu_resource_pool.h
+64
-0
python/paddle/fluid/tests/unittests/test_dataloader_npu.py
python/paddle/fluid/tests/unittests/test_dataloader_npu.py
+46
-0
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
...id/tests/unittests/test_multiprocess_dataloader_static.py
+8
-5
未找到文件。
paddle/fluid/operators/reader/buffered_reader.cc
浏览文件 @
60af2852
...
...
@@ -43,6 +43,7 @@ BufferedReader::BufferedReader(
buffer_size_
(
buffer_size
),
pin_memory_
(
pin_memory
)
{
VLOG
(
1
)
<<
"BufferedReader"
;
#ifdef PADDLE_WITH_CUDA
if
(
platform
::
is_gpu_place
(
place_
)
&&
!
pin_memory
)
{
int
dev_idx
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place_
).
device
;
...
...
@@ -57,9 +58,25 @@ BufferedReader::BufferedReader(
stream_
=
platform
::
CudaStreamResourcePool
::
Instance
().
New
(
dev_idx
);
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if
(
platform
::
is_npu_place
(
place_
))
{
int
dev_idx
=
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place_
).
device
;
compute_stream_
=
((
platform
::
NPUDeviceContext
*
)(
platform
::
DeviceContextPool
::
Instance
()
.
Get
(
place_
)))
->
stream
();
events_
.
resize
(
buffer_size
);
for
(
auto
&
event
:
events_
)
{
event
=
platform
::
NpuEventResourcePool
::
Instance
().
New
(
dev_idx
);
}
stream_
=
platform
::
NpuStreamResourcePool
::
Instance
().
New
(
dev_idx
);
}
#endif
is_same_place_
=
false
;
cpu_buffer_
.
resize
(
buffer_size
);
cuda_buffer_
.
resize
(
buffer_size
);
npu_buffer_
.
resize
(
buffer_size
);
ReadTillBufferFullAsync
();
}
...
...
@@ -186,6 +203,58 @@ void BufferedReader::ReadAsync(size_t i) {
}
}
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if
(
platform
::
is_npu_place
(
place_
))
{
TensorVec
&
npu
=
npu_buffer_
[
i
];
if
(
npu
.
empty
())
{
npu
.
resize
(
cpu
.
size
());
}
else
{
PADDLE_ENFORCE_EQ
(
npu
.
size
(),
cpu
.
size
(),
platform
::
errors
::
InvalidArgument
(
"Input tensor number on NPU and CPU devices are not matched. "
"The number on NPU is %d, on CPU is %d"
,
npu
.
size
(),
cpu
.
size
()));
}
std
::
vector
<
void
*>
npu_ptrs
;
npu_ptrs
.
reserve
(
cpu
.
size
());
for
(
size_t
i
=
0
;
i
<
cpu
.
size
();
++
i
)
{
npu
[
i
].
Resize
(
cpu
[
i
].
dims
());
npu
[
i
].
set_layout
(
cpu
[
i
].
layout
());
npu_ptrs
.
emplace_back
(
npu
[
i
].
mutable_data
(
place_
,
cpu
[
i
].
type
()));
}
platform
::
SetNPUDeviceId
(
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place_
).
device
);
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtRecordEvent
(
events_
[
i
].
get
(),
compute_stream_
));
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtStreamWaitEvent
(
stream_
.
get
(),
events_
[
i
].
get
()));
platform
::
RecordEvent
record_event
(
"BufferedReader:MemoryCopy"
);
for
(
size_t
i
=
0
;
i
<
cpu
.
size
();
++
i
)
{
auto
cpu_place
=
cpu
[
i
].
place
();
auto
cpu_ptr
=
cpu
[
i
].
data
<
void
>
();
auto
npu_ptr
=
npu_ptrs
[
i
];
auto
size
=
cpu
[
i
].
numel
()
*
paddle
::
framework
::
SizeOfType
(
cpu
[
i
].
type
());
if
((
platform
::
is_npu_place
(
cpu_place
)))
{
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place_
),
npu_ptr
,
BOOST_GET_CONST
(
platform
::
NPUPlace
,
cpu_place
),
cpu_ptr
,
size
,
stream_
.
get
());
}
else
{
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place_
),
npu_ptr
,
BOOST_GET_CONST
(
platform
::
CPUPlace
,
cpu_place
),
cpu_ptr
,
size
,
stream_
.
get
());
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtSynchronizeStream
(
stream_
.
get
()));
}
npu
[
i
].
set_lod
(
cpu
[
i
].
lod
());
}
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtSynchronizeStream
(
stream_
.
get
()));
}
#endif
return
i
;
}));
}
...
...
@@ -217,9 +286,13 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
return
;
}
*
out
=
std
::
move
((
platform
::
is_gpu_place
(
place_
)
&&
!
is_same_place_
)
?
cuda_buffer_
[
i
]
:
cpu_buffer_
[
i
]);
if
(
platform
::
is_gpu_place
(
place_
)
&&
!
is_same_place_
)
{
*
out
=
std
::
move
(
cuda_buffer_
[
i
]);
}
else
if
(
platform
::
is_npu_place
(
place_
)
&&
!
is_same_place_
)
{
*
out
=
std
::
move
(
npu_buffer_
[
i
]);
}
else
{
*
out
=
std
::
move
(
cpu_buffer_
[
i
]);
}
// Do not push current position into ReadAsync. Push the previous position
// Since all computation in fluid are async, change the data of
...
...
paddle/fluid/operators/reader/buffered_reader.h
浏览文件 @
60af2852
...
...
@@ -25,7 +25,10 @@
#include "paddle/fluid/platform/cuda_resource_pool.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/npu_info.h"
#include "paddle/fluid/platform/npu_resource_pool.h"
#endif
namespace
paddle
{
namespace
operators
{
namespace
reader
{
...
...
@@ -67,12 +70,20 @@ class BufferedReader : public framework::DecoratedReader {
bool
is_same_place_
;
std
::
vector
<
TensorVec
>
cpu_buffer_
;
std
::
vector
<
TensorVec
>
cuda_buffer_
;
std
::
vector
<
TensorVec
>
npu_buffer_
;
size_t
prev_pos_
{
-
1UL
};
#ifdef PADDLE_WITH_CUDA
cudaStream_t
compute_stream_
;
std
::
shared_ptr
<
platform
::
CudaStreamObject
>
stream_
;
std
::
vector
<
std
::
shared_ptr
<
platform
::
CudaEventObject
>>
events_
;
#endif
#ifdef PADDLE_WITH_ASCEND_CL
aclrtStream
compute_stream_
;
std
::
shared_ptr
<
platform
::
NpuStreamObject
>
stream_
;
std
::
vector
<
std
::
shared_ptr
<
platform
::
NpuEventObject
>>
events_
;
#endif
};
}
// namespace reader
...
...
paddle/fluid/platform/CMakeLists.txt
浏览文件 @
60af2852
...
...
@@ -135,6 +135,11 @@ if(WITH_GPU)
target_link_libraries
(
device_context cuda_resource_pool
)
endif
()
if
(
WITH_ASCEND_CL
)
cc_library
(
npu_resource_pool SRCS npu_resource_pool.cc DEPS npu_info
)
target_link_libraries
(
device_context npu_resource_pool
)
endif
()
nv_test
(
device_context_test SRCS device_context_test.cu DEPS device_context gpu_info
)
cc_test
(
init_test SRCS init_test.cc DEPS device_context
)
...
...
paddle/fluid/platform/npu_resource_pool.cc
0 → 100644
浏览文件 @
60af2852
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/npu_resource_pool.h"
#include "paddle/fluid/platform/npu_info.h"
namespace
paddle
{
namespace
platform
{
NpuStreamResourcePool
::
NpuStreamResourcePool
()
{
int
dev_cnt
=
platform
::
GetNPUDeviceCount
();
pool_
.
reserve
(
dev_cnt
);
for
(
int
dev_idx
=
0
;
dev_idx
<
dev_cnt
;
++
dev_idx
)
{
auto
creator
=
[
dev_idx
]
{
platform
::
SetNPUDeviceId
(
dev_idx
);
aclrtStream
stream
;
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtCreateStream
(
&
stream
));
return
stream
;
};
auto
deleter
=
[
dev_idx
](
aclrtStream
stream
)
{
platform
::
SetNPUDeviceId
(
dev_idx
);
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtDestroyStream
(
stream
));
};
pool_
.
emplace_back
(
ResourcePool
<
NpuStreamObject
>::
Create
(
creator
,
deleter
));
}
}
NpuStreamResourcePool
&
NpuStreamResourcePool
::
Instance
()
{
static
NpuStreamResourcePool
pool
;
return
pool
;
}
std
::
shared_ptr
<
NpuStreamObject
>
NpuStreamResourcePool
::
New
(
int
dev_idx
)
{
PADDLE_ENFORCE_GE
(
dev_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The dev_idx should be not less than 0, but got %d."
,
dev_idx
));
PADDLE_ENFORCE_LT
(
dev_idx
,
pool_
.
size
(),
platform
::
errors
::
OutOfRange
(
"The dev_idx should be less than device count %d, but got %d."
,
pool_
.
size
(),
dev_idx
));
return
pool_
[
dev_idx
]
->
New
();
}
NpuEventResourcePool
::
NpuEventResourcePool
()
{
int
dev_cnt
=
platform
::
GetNPUDeviceCount
();
pool_
.
reserve
(
dev_cnt
);
for
(
int
dev_idx
=
0
;
dev_idx
<
dev_cnt
;
++
dev_idx
)
{
auto
creator
=
[
dev_idx
]
{
platform
::
SetNPUDeviceId
(
dev_idx
);
aclrtEvent
event
;
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtCreateEvent
(
&
event
));
return
event
;
};
auto
deleter
=
[
dev_idx
](
aclrtEvent
event
)
{
platform
::
SetNPUDeviceId
(
dev_idx
);
PADDLE_ENFORCE_NPU_SUCCESS
(
aclrtDestroyEvent
(
event
));
};
pool_
.
emplace_back
(
ResourcePool
<
NpuEventObject
>::
Create
(
creator
,
deleter
));
}
}
NpuEventResourcePool
&
NpuEventResourcePool
::
Instance
()
{
static
NpuEventResourcePool
pool
;
return
pool
;
}
std
::
shared_ptr
<
NpuEventObject
>
NpuEventResourcePool
::
New
(
int
dev_idx
)
{
PADDLE_ENFORCE_GE
(
dev_idx
,
0
,
platform
::
errors
::
InvalidArgument
(
"The dev_idx should be not less than 0, but got %d."
,
dev_idx
));
PADDLE_ENFORCE_LT
(
dev_idx
,
pool_
.
size
(),
platform
::
errors
::
OutOfRange
(
"The dev_idx should be less than device count %d, but got %d."
,
pool_
.
size
(),
dev_idx
));
return
pool_
[
dev_idx
]
->
New
();
}
}
// namespace platform
}
// namespace paddle
#endif
paddle/fluid/platform/npu_resource_pool.h
0 → 100644
浏览文件 @
60af2852
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <type_traits>
#include <vector>
#include "acl/acl.h"
#include "paddle/fluid/platform/resource_pool.h"
namespace
paddle
{
namespace
platform
{
using
NpuStreamObject
=
std
::
remove_pointer
<
aclrtStream
>::
type
;
using
NpuEventObject
=
std
::
remove_pointer
<
aclrtEvent
>::
type
;
class
NpuStreamResourcePool
{
public:
std
::
shared_ptr
<
NpuStreamObject
>
New
(
int
dev_idx
);
static
NpuStreamResourcePool
&
Instance
();
private:
NpuStreamResourcePool
();
DISABLE_COPY_AND_ASSIGN
(
NpuStreamResourcePool
);
private:
std
::
vector
<
std
::
shared_ptr
<
ResourcePool
<
NpuStreamObject
>>>
pool_
;
};
class
NpuEventResourcePool
{
public:
std
::
shared_ptr
<
NpuEventObject
>
New
(
int
dev_idx
);
static
NpuEventResourcePool
&
Instance
();
private:
NpuEventResourcePool
();
DISABLE_COPY_AND_ASSIGN
(
NpuEventResourcePool
);
private:
std
::
vector
<
std
::
shared_ptr
<
ResourcePool
<
NpuEventObject
>>>
pool_
;
};
}
// namespace platform
}
// namespace paddle
#endif
python/paddle/fluid/tests/unittests/test_dataloader_npu.py
0 → 100644
浏览文件 @
60af2852
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
sys
import
unittest
import
numpy
as
np
import
paddle
from
..unittests.test_multiprocess_dataloader_static
import
TestStaticDataLoader
paddle
.
enable_static
()
class
TestStaticDataLoader
(
TestStaticDataLoader
):
def
test_main
(
self
):
results
=
[]
places
=
[
paddle
.
NPUPlace
(
0
)]
for
num_workers
in
[
0
,
2
]:
print
(
self
.
__class__
.
__name__
,
places
,
num_workers
)
sys
.
stdout
.
flush
()
ret
=
self
.
_run_main
(
num_workers
=
num_workers
,
places
=
places
,
use_pe
=
False
)
results
.
append
(
ret
)
diff
=
np
.
max
(
np
.
abs
(
results
[
0
][
'loss'
]
-
results
[
1
][
'loss'
])
/
np
.
abs
(
results
[
0
][
'loss'
]))
self
.
assertLess
(
diff
,
1e-2
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py
浏览文件 @
60af2852
...
...
@@ -101,7 +101,7 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True):
class
TestStaticDataLoader
(
unittest
.
TestCase
):
def
run_main
(
self
,
num_workers
,
places
):
def
run_main
(
self
,
num_workers
,
places
,
use_pe
=
True
):
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
startup_prog
,
main_prog
,
image
,
label
,
loss
=
simple_fc_net_static
()
...
...
@@ -120,10 +120,13 @@ class TestStaticDataLoader(unittest.TestCase):
exe
=
fluid
.
Executor
(
place
=
places
[
0
])
exe
.
run
(
startup_prog
)
prog
=
fluid
.
CompiledProgram
(
main_prog
)
if
len
(
places
)
>
1
:
prog
=
prog
.
with_data_parallel
(
loss_name
=
loss
.
name
,
places
=
places
)
if
use_pe
:
prog
=
fluid
.
CompiledProgram
(
main_prog
)
if
len
(
places
)
>
1
:
prog
=
prog
.
with_data_parallel
(
loss_name
=
loss
.
name
,
places
=
places
)
else
:
prog
=
main_prog
step_list
=
[]
loss_list
=
[]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录