Commit 33429630 (unverified)
Authored by Qi Li on Feb 22, 2021; committed via GitHub on Feb 22, 2021.
[ROCM] update fluid platform for rocm39 (part4), test=develop (#30936)
Parent: a5c56d83

17 changed files, with 854 additions and 34 deletions (+854 −34):
paddle/fluid/memory/allocation/CMakeLists.txt     +20  −3
paddle/fluid/memory/malloc_test.cu                +39  −3
paddle/fluid/memory/memcpy.cc                     +55  −1
paddle/fluid/memory/pinned_memory_test.cu         +55  −5
paddle/fluid/platform/device_memory_aligment.cc   +1   −1
paddle/fluid/platform/device_memory_aligment.h    +1   −1
paddle/fluid/platform/dynload/dynamic_loader.cc   +3   −2
paddle/fluid/platform/dynload/hiprtc.h            +1   −0
paddle/fluid/platform/dynload/miopen.h            +5   −0
paddle/fluid/platform/dynload/rocm_driver.h       +1   −0
paddle/fluid/platform/event.h                     +7   −4
paddle/fluid/platform/flags.cc                    +5   −5
paddle/fluid/platform/for_range.h                 +1   −1
paddle/fluid/platform/init.cc                     +12  −6
paddle/fluid/platform/init_test.cc                +3   −2
paddle/fluid/platform/miopen_helper.h             +552 −0  (new file)
paddle/fluid/platform/miopen_helper_test.cc       +93  −0  (new file)
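All 17 files follow the same porting pattern: #ifdef PADDLE_WITH_CUDA guards are widened to #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP), HIP equivalents of the CUDA runtime calls (hipMemcpy, hipEventRecord, hipLaunchKernelGGL, and so on) are added under #ifdef PADDLE_WITH_HIP branches, and cudaEvent_t/cudaStream_t declarations give way to the portable gpuEvent_t/gpuStream_t aliases. Two new files add a MIOpen counterpart to the existing cuDNN helper header, plus its unit test.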
paddle/fluid/memory/allocation/CMakeLists.txt (+20 −3)

@@ -16,13 +16,20 @@ endif()
 if (WITH_GPU)
     nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
     nv_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
+    nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
+    cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
+endif()
+
+if (WITH_ROCM)
+    hip_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
+    hip_library(thread_local_allocator SRCS thread_local_allocator.cc DEPS allocator)
+    hip_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
     cc_test(thread_local_allocator_test SRCS thread_local_allocator_test.cc DEPS thread_local_allocator)
 endif()

 cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
-nv_library(pinned_allocator SRCS pinned_allocator.cc DEPS allocator)
-if (WITH_GPU)
+
+if (WITH_GPU OR WITH_ROCM)
     set(AllocatorFacadeDeps gpu_info cuda_allocator pinned_allocator cuda_device_guard thread_local_allocator)
 elseif(WITH_XPU)
     set(AllocatorFacadeDeps xpu_info)

@@ -40,6 +47,16 @@ if (WITH_GPU)
                 cuda_allocator
                 device_context
                 memcpy)
+elseif (WITH_ROCM)
+    hip_test(best_fit_allocator_test
+             SRCS best_fit_allocator_test.cc
+                  best_fit_allocator_test.cu
+             DEPS best_fit_allocator
+                  locked_allocator
+                  cpu_allocator
+                  cuda_allocator
+                  device_context
+                  memcpy)
 else()
     cc_test(best_fit_allocator_test
             SRCS best_fit_allocator_test.cc

@@ -57,7 +74,7 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS allocator_strategy)
 cc_test(retry_allocator_test SRCS retry_allocator_test.cc DEPS retry_allocator locked_allocator cpu_allocator)
 if (WITH_TESTING)
-    if (WITH_GPU AND TARGET retry_allocator_test)
+    if ((WITH_GPU OR WITH_ROCM) AND TARGET retry_allocator_test)
         target_link_libraries(retry_allocator_test cuda_allocator)
     endif()
 endif()
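(hip_library and hip_test are Paddle's CMake wrappers that build the same sources with the HIP toolchain, paralleling nv_library and nv_test for nvcc. Note that the ROCm branch keeps the cuda_* target names, so downstream dependency lists stay unchanged.)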
paddle/fluid/memory/malloc_test.cu (+39 −3)

@@ -12,8 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <cuda_runtime.h>
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#endif
+
 #include <thread>  // NOLINT
 #include <vector>

@@ -40,8 +47,13 @@ __global__ void kernel(float *x, int n) {
 void CheckKernelOutput(float *x, int n) {
   auto host_x = std::unique_ptr<float[]>(new float[n]);
   for (int i = 0; i < n; ++i) {
+#ifdef PADDLE_WITH_HIP
+    EXPECT_TRUE(hipSuccess == hipMemcpy(host_x.get(), x, n * sizeof(float),
+                                        hipMemcpyDeviceToHost));
+#else
     EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float),
                                           cudaMemcpyDeviceToHost));
+#endif
     EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
     EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);

@@ -53,13 +65,22 @@ void MultiStreamCompute(float **data, float **second_data,
   AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
   EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
   *data = reinterpret_cast<float *>(allocation_ptr->ptr());
+#ifdef PADDLE_WITH_HIP
+  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(), *data, N);
+#else
   kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
+#endif

   // allocate and compute on same stream again
   allocation_ptr = Alloc(ctx, N * sizeof(float));
   EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
   *second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
+#ifdef PADDLE_WITH_HIP
+  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, ctx.stream(),
+                     *second_data, N);
+#else
   kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
+#endif
 }

 TEST(Malloc, CUDADeviceContextMultiStream) {

@@ -75,8 +96,12 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
   float *second_data[NUM_STREAMS];
   CudaDevCtxVec dev_ctx;

   // default stream
+#ifdef PADDLE_WITH_HIP
+  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
+#else
   kernel<<<1, 64>>>(main_stream_data, N);
+#endif
   main_stream_alloc_ptr.reset();

   for (int i = 0; i < NUM_STREAMS; ++i) {

@@ -85,7 +110,11 @@ TEST(Malloc, CUDADeviceContextMultiStream) {
     MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
   }

+#ifdef PADDLE_WITH_HIP
+  EXPECT_TRUE(hipSuccess == hipDeviceSynchronize());
+#else
   EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
+#endif
   for (int i = 0; i < NUM_STREAMS; ++i) {
     CheckKernelOutput(data[i], N);
     CheckKernelOutput(second_data[i], N);

@@ -106,8 +135,12 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
   CudaDevCtxVec dev_ctx;
   std::vector<std::thread> threads;

   // default stream
+#ifdef PADDLE_WITH_HIP
+  hipLaunchKernelGGL((kernel), dim3(1), dim3(64), 0, 0, main_stream_data, N);
+#else
   kernel<<<1, 64>>>(main_stream_data, N);
+#endif
   main_stream_alloc_ptr.reset();

   for (int i = 0; i < NUM_STREAMS; ++i) {

@@ -120,8 +153,11 @@ TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
   for (int i = 0; i < NUM_STREAMS; ++i) {
     threads[i].join();
   }
+#ifdef PADDLE_WITH_HIP
+  EXPECT_TRUE(hipSuccess == hipDeviceSynchronize());
+#else
   EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
+#endif
   for (int i = 0; i < NUM_STREAMS; ++i) {
     CheckKernelOutput(data[i], N);
     CheckKernelOutput(second_data[i], N);
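The recurring launch idiom above is worth isolating: portable HIP code traditionally launches kernels through hipLaunchKernelGGL with explicit dim3 grid/block arguments rather than the CUDA <<<...>>> syntax. A minimal self-contained sketch of the pattern (the kernel name scale, the helper LaunchScale, and the gpuStream_t typedef are illustrative assumptions, not part of this patch):

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
typedef hipStream_t gpuStream_t;   // assumed alias, mirroring the patch
#else
#include <cuda_runtime.h>
typedef cudaStream_t gpuStream_t;
#endif

// Doubles every element; the same source compiles under nvcc and hipcc.
__global__ void scale(float *x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

void LaunchScale(float *x, int n, gpuStream_t stream) {
  int blocks = (n + 63) / 64;
#ifdef PADDLE_WITH_HIP
  // HIP spelling: kernel, grid, block, shared-memory bytes, stream, then args.
  hipLaunchKernelGGL((scale), dim3(blocks), dim3(64), 0, stream, x, n);
#else
  scale<<<blocks, 64, 0, stream>>>(x, n);
#endif
}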
paddle/fluid/memory/memcpy.cc (+55 −1)

@@ -196,9 +196,22 @@ void Copy<platform::XPUPlace, platform::XPUPlace>(platform::XPUPlace dst_place,
 }
 #endif

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024;  // 64K

+#ifdef PADDLE_WITH_HIP
+inline void SyncCUDAStream() {
+#if !defined(_WIN32)
+  hipStreamSynchronize(0);
+#else
+  hipError_t e_sync = hipSuccess;
+  while (e_sync = hipStreamQuery(0)) {
+    if (e_sync == hipErrorNotReady) continue;
+    break;
+  }
+#endif
+}
+#else
 inline void SyncCUDAStream() {
 #if !defined(_WIN32)
   cudaStreamSynchronize(0);

@@ -210,6 +223,7 @@ inline void SyncCUDAStream() {
 }
 #endif
 }
+#endif

 // NOTE(zcd): Do not use GpuMemcpySync as much as possible.
 // because GpuMemcpySync issues the copying command to the default stream,

@@ -228,10 +242,18 @@ void Copy<platform::CPUPlace, platform::CUDAPlace>(
       << dst_place << " by thream(" << stream << ")";
   if (stream) {
     platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
+#else
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+#endif
   } else {
     platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
+#else
     platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
+#endif
     // FIXME(zjl): do we really need it?
     if (num <= kMaxGpuAsyncCopyBytes) {
       SyncCUDAStream();

@@ -250,10 +272,18 @@ void Copy<platform::CUDAPlace, platform::CPUPlace>(
       << dst_place << " by thream(" << stream << ")";
   if (stream) {
     platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
+#else
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+#endif
   } else {
     platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
+#else
     platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
+#endif
     // FIXME(zjl): do we really need it?
     if (num <= kMaxGpuAsyncCopyBytes) {
       SyncCUDAStream();

@@ -273,10 +303,18 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
     platform::SetDeviceId(src_place.device);
     if (stream) {
       platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU");
+#ifdef PADDLE_WITH_HIP
+      platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToDevice, stream);
+#else
       platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
+#endif
     } else {
       platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU");
+#ifdef PADDLE_WITH_HIP
+      platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToDevice);
+#else
       platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
+#endif
     }
   } else {
     if (stream) {

@@ -332,10 +370,18 @@ void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
       << dst_place << " by thream(" << stream << ")";
   if (stream) {
     platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyDeviceToHost, stream);
+#else
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
+#endif
   } else {
     platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpySync(dst, src, num, hipMemcpyDeviceToHost);
+#else
     platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost);
+#endif
   }
 }

@@ -351,10 +397,18 @@ void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
       << dst_place << " by thream(" << stream << ")";
   if (stream) {
     platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpyAsync(dst, src, num, hipMemcpyHostToDevice, stream);
+#else
     platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
+#endif
   } else {
     platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU");
+#ifdef PADDLE_WITH_HIP
+    platform::GpuMemcpySync(dst, src, num, hipMemcpyHostToDevice);
+#else
     platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
+#endif
   }
 }
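Each Copy specialization above repeats the same guard around a directional copy; only the direction enum differs between the two runtimes. A condensed sketch of that guard, with a hypothetical helper name (the patch itself keeps the branches inline at every call site):

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
#include <cstddef>

// Synchronous device-to-host copy on the default stream.
inline void PortableCopyD2H(void *dst, const void *src, size_t num) {
#ifdef PADDLE_WITH_HIP
  hipMemcpy(dst, src, num, hipMemcpyDeviceToHost);
#else
  cudaMemcpy(dst, src, num, cudaMemcpyDeviceToHost);
#endif
}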
paddle/fluid/memory/pinned_memory_test.cu (+55 −5)

@@ -41,27 +41,44 @@ float test_pinned_memory() {
   const int iteration = 10;

   // create event start and end
-  cudaEvent_t start_e, stop_e, copying_e;
+  gpuEvent_t start_e, stop_e, copying_e;
   float elapsedTime = 0;
+
+#ifdef PADDLE_WITH_HIP
+  hipEventCreate(&start_e);
+  hipEventCreate(&stop_e);
+  hipEventCreate(&copying_e);
+#else
   cudaEventCreate(&start_e);
   cudaEventCreate(&stop_e);
   cudaEventCreate(&copying_e);
+#endif

   // create computation stream, data copying stream
-  cudaStream_t computation_stream, copying_stream;
+  gpuStream_t computation_stream, copying_stream;
+#ifdef PADDLE_WITH_HIP
+  hipStreamCreate(&computation_stream);
+  hipStreamCreate(&copying_stream);
+#else
   cudaStreamCreate(&computation_stream);
   cudaStreamCreate(&copying_stream);
+#endif

   // create record event, pinned memory, gpu memory
-  std::vector<cudaEvent_t> record_event(iteration);
+  std::vector<gpuEvent_t> record_event(iteration);
   std::vector<float *> input_pinned_mem(iteration);
   std::vector<float *> gpu_mem(iteration);
   std::vector<float *> output_pinned_mem(iteration);

   // initial data
   for (int j = 0; j < iteration; ++j) {
+#ifdef PADDLE_WITH_HIP
+    hipEventCreateWithFlags(&record_event[j], hipEventDisableTiming);
+    hipEventCreate(&(record_event[j]));
+#else
     cudaEventCreateWithFlags(&record_event[j], cudaEventDisableTiming);
     cudaEventCreate(&(record_event[j]));
+#endif
     input_pinned_mem[j] = static_cast<float *>(
         paddle::memory::Alloc(cpu_place, data_size * sizeof(float)));
     output_pinned_mem[j] = static_cast<float *>(

@@ -74,7 +91,11 @@ float test_pinned_memory() {
     }
   }

+#ifdef PADDLE_WITH_HIP
+  hipEventRecord(start_e, computation_stream);
+#else
   cudaEventRecord(start_e, computation_stream);
+#endif

   // computation
   for (int m = 0; m < 30; ++m) {

@@ -88,13 +109,21 @@ float test_pinned_memory() {
       // call kernel on computation stream.
       Kernel<<<4, 1024, 0, computation_stream>>>(gpu_mem[i], data_size);

+#ifdef PADDLE_WITH_HIP
+      // record event_computation on computation stream
+      hipEventRecord(record_event[i], computation_stream);
+
+      // wait event_computation on copy stream.
+      // note: this operation is async.
+      hipStreamWaitEvent(copying_stream, record_event[i], 0);
+#else
       // record event_computation on computation stream
       cudaEventRecord(record_event[i], computation_stream);

       // wait event_computation on copy stream.
       // note: this operation is async.
       cudaStreamWaitEvent(copying_stream, record_event[i], 0);
+#endif

       // copy data GPU->CPU, on copy stream.
       // note: this operation is async for pinned memory.
       paddle::memory::Copy(cpu_place, output_pinned_mem[i], cuda_place,

@@ -103,6 +132,16 @@ float test_pinned_memory() {
     }
   }

+#ifdef PADDLE_WITH_HIP
+  hipEventRecord(copying_e, copying_stream);
+  hipStreamWaitEvent(computation_stream, copying_e, 0);
+
+  hipEventRecord(stop_e, computation_stream);
+
+  hipEventSynchronize(start_e);
+  hipEventSynchronize(stop_e);
+  hipEventElapsedTime(&elapsedTime, start_e, stop_e);
+#else
   cudaEventRecord(copying_e, copying_stream);
   cudaStreamWaitEvent(computation_stream, copying_e, 0);

@@ -111,6 +150,7 @@ float test_pinned_memory() {
   cudaEventSynchronize(start_e);
   cudaEventSynchronize(stop_e);
   cudaEventElapsedTime(&elapsedTime, start_e, stop_e);
+#endif

   // std::cout << cpu_place << " "
   //           << "time consume:" << elapsedTime / 30 << std::endl;

@@ -123,12 +163,22 @@ float test_pinned_memory() {
     }
   }

   // destroy resource
+#ifdef PADDLE_WITH_HIP
+  hipEventDestroy(copying_e);
+  hipEventDestroy(start_e);
+  hipEventDestroy(stop_e);
+#else
   cudaEventDestroy(copying_e);
   cudaEventDestroy(start_e);
   cudaEventDestroy(stop_e);
+#endif
   for (int j = 0; j < 10; ++j) {
+#ifdef PADDLE_WITH_HIP
+    hipEventDestroy((record_event[j]));
+#else
     cudaEventDestroy((record_event[j]));
+#endif
     paddle::memory::Free(cpu_place, input_pinned_mem[j]);
     paddle::memory::Free(cpu_place, output_pinned_mem[j]);
     paddle::memory::Free(cuda_place, gpu_mem[j]);
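The type switch here (cudaEvent_t to gpuEvent_t, cudaStream_t to gpuStream_t) is what lets one declaration serve both runtimes; the per-call #ifdef blocks then pick the matching hipEvent*/hipStream* or cudaEvent*/cudaStream* functions. The aliases are presumably plain typedefs over the native handle types, along the lines of the sketch given after the event.h diff below.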
paddle/fluid/platform/device_memory_aligment.cc (+1 −1)

@@ -21,7 +21,7 @@ size_t Alignment(size_t size, const platform::Place &place) {
   if (platform::is_cpu_place(place)) {
     alignment = CpuMinChunkSize();
   } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     alignment = GpuMinChunkSize();
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
paddle/fluid/platform/device_memory_aligment.h (+1 −1)

@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/cpu_info.h"
 #include "paddle/fluid/platform/place.h"
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/platform/gpu_info.h"
 #endif
paddle/fluid/platform/dynload/dynamic_loader.cc (+3 −2)

@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/dynload/dynamic_loader.h"
 #include <string>
+#include <vector>
 #include "gflags/gflags.h"
 #include "glog/logging.h"

@@ -337,7 +338,7 @@ void* GetNVRTCDsoHandle() {
 #if defined(__APPLE__) || defined(__OSX__)
   return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib", false);
 #elif defined(PADDLE_WITH_HIP)
-  return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprtc.so", false);
+  return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false);
 #else
   return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.so", false);
 #endif

@@ -347,7 +348,7 @@ void* GetCUDADsoHandle() {
 #if defined(__APPLE__) || defined(__OSX__)
   return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.dylib", false);
 #elif defined(PADDLE_WITH_HIP)
-  return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhip_hcc.so", false);
+  return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libamdhip64.so", false);
 #else
   return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcuda.so", false);
 #endif
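Both handles now resolve to libamdhip64.so on ROCm. Presumably this reflects ROCm's library consolidation: from around ROCm 3.5 the HIP runtime (formerly libhip_hcc.so) and the hiprtc entry points ship in the single merged libamdhip64.so, so the older per-library names no longer exist on a ROCm 3.9 install.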
paddle/fluid/platform/dynload/hiprtc.h (+1 −0)

@@ -45,6 +45,7 @@ extern bool HasNVRTC();
  * include all needed hiprtc functions
  **/
 #define HIPRTC_ROUTINE_EACH(__macro) \
+  __macro(hiprtcVersion);            \
   __macro(hiprtcGetErrorString);     \
   __macro(hiprtcCompileProgram);     \
   __macro(hiprtcCreateProgram);      \
paddle/fluid/platform/dynload/miopen.h (+5 −0)

@@ -16,10 +16,15 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <miopen/miopen.h>
+#include <miopen/version.h>
 #include <mutex>  // NOLINT
 #include "paddle/fluid/platform/dynload/dynamic_loader.h"
 #include "paddle/fluid/platform/port.h"

+#define MIOPEN_VERSION                                        \
+  (MIOPEN_VERSION_MAJOR * 1000 + MIOPEN_VERSION_MINOR * 100 + \
+   MIOPEN_VERSION_PATCH)  // NOLINT
+
 namespace paddle {
 namespace platform {
 namespace dynload {
paddle/fluid/platform/dynload/rocm_driver.h (+1 −0)

@@ -46,6 +46,7 @@ extern bool HasCUDADriver();
  * include all needed cuda driver functions
  **/
 #define ROCM_ROUTINE_EACH(__macro) \
+  __macro(hipDriverGetVersion);    \
   __macro(hipGetErrorString);      \
   __macro(hipModuleLoadData);      \
   __macro(hipModuleGetFunction);   \
paddle/fluid/platform/event.h (+7 −4)

@@ -18,6 +18,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include <cuda_runtime.h>
 #endif
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#endif
 #include "paddle/fluid/platform/place.h"

 namespace paddle {

@@ -48,9 +51,9 @@ class Event {
   void set_name(std::string name) { name_ = name; }
   void set_role(EventRole role) { role_ = role; }

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #ifndef PADDLE_WITH_CUPTI
-  cudaEvent_t event() const { return event_; }
+  gpuEvent_t event() const { return event_; }
   int device() const { return device_; }
 #endif
 #endif

@@ -66,7 +69,7 @@ class Event {
   EventRole role_{};
   int64_t cpu_ns_;
   bool visited_status_{false};
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #ifdef PADDLE_WITH_CUPTI
   int64_t gpu_ns_ = 0;

@@ -77,7 +80,7 @@ class Event {
  private:
 #else
-  cudaEvent_t event_ = nullptr;
+  gpuEvent_t event_ = nullptr;
   int device_ = -1;
 #endif
 #endif
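event.h now stores a gpuEvent_t rather than a cudaEvent_t. The alias itself is defined elsewhere in the platform layer; a plausible definition, shown here only as an assumption to make the diff self-contained:

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuEvent_t = hipEvent_t;    // assumed: one alias over both runtimes
using gpuStream_t = hipStream_t;
#else
#include <cuda_runtime.h>
using gpuEvent_t = cudaEvent_t;
using gpuStream_t = cudaStream_t;
#endif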
paddle/fluid/platform/flags.cc (+5 −5)

@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "gflags/gflags.h"
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
 #endif

@@ -45,7 +45,7 @@ DEFINE_bool(check_nan_inf, false,
             "Checking whether operator produce NAN/INF or not. It will be "
             "extremely slow so please use this flag wisely.");

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

 /**
  * CUDA related related FLAG

@@ -84,7 +84,7 @@ DEFINE_string(selected_gpus, "",
               "share-memory only.");
 #endif

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

 /**
  * CUDNN related FLAG

@@ -167,7 +167,7 @@ DEFINE_bool(cudnn_batchnorm_spatial_persistent, false,
             "batch_norm, default is False.");
 #endif

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

 /**
  * NCCL related FLAG

@@ -377,7 +377,7 @@ DEFINE_double(
     "Default use 50% of CPU memory as the pinned_memory for PaddlePaddle,"
     "reserve the rest for page tables, etc");

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

 /**
  * Memory related FLAG
paddle/fluid/platform/for_range.h (+1 −1)

@@ -40,7 +40,7 @@ struct ForRange<CPUDeviceContext> {
   size_t limit_;
 };

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 template <typename Function>
 __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
   size_t idx = static_cast<size_t>(threadIdx.x);
paddle/fluid/platform/init.cc (+12 −6)

@@ -16,8 +16,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/cpu_helper.h"
 #include "paddle/fluid/platform/cpu_info.h"

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/platform/cuda_device_guard.h"
+#endif
+#ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/dynload/cupti.h"
 #endif
 #include "paddle/fluid/platform/device_context.h"

@@ -92,6 +94,7 @@ bool InitGflags(std::vector<std::string> args) {
   return successed;
 }

+#ifdef PADDLE_WITH_CUDA
 void InitCupti() {
 #ifdef PADDLE_WITH_CUPTI
   if (FLAGS_multiple_of_cupti_buffer_size == 1) return;

@@ -117,14 +120,17 @@ void InitCupti() {
 #undef MULTIPLY_ATTR_VALUE
 #endif
 }
+#endif

 void InitDevices() {
   // CUPTI attribute should be set before any CUDA context is created (see CUPTI
   // documentation about CUpti_ActivityAttribute).
+#ifdef PADDLE_WITH_CUDA
   InitCupti();
+#endif
   /*Init all available devices by default */
   std::vector<int> devices;
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   try {
     // use user specified GPUs in single-node multi-process mode.
     devices = platform::GetSelectedDevices();

@@ -154,7 +160,7 @@ void InitDevices(const std::vector<int> devices) {
       continue;
     }
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     places.emplace_back(platform::CUDAPlace(devices[i]));
 #endif
 #ifdef PADDLE_WITH_XPU

@@ -162,7 +168,7 @@ void InitDevices(const std::vector<int> devices) {
 #endif
   }
   places.emplace_back(platform::CPUPlace());
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   places.emplace_back(platform::CUDAPinnedPlace());
 #endif
   platform::DeviceContextPool::Init(places);
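InitCupti() is now declared and called only under PADDLE_WITH_CUDA: CUPTI is NVIDIA's profiling interface and has no ROCm counterpart, so wrapping both the definition and the call site keeps the HIP build free of dead references while device enumeration itself stays shared between the two back ends.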
paddle/fluid/platform/init_test.cc (+3 −2)

@@ -19,7 +19,8 @@ TEST(InitDevices, CPU) {
   using paddle::framework::InitDevices;
   using paddle::platform::DeviceContextPool;

-#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU)
+#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_XPU) && \
+    !defined(PADDLE_WITH_HIP)
   InitDevices();
   DeviceContextPool& pool = DeviceContextPool::Instance();
   ASSERT_EQ(pool.size(), 1U);

@@ -30,7 +31,7 @@ TEST(InitDevices, CUDA) {
   using paddle::framework::InitDevices;
   using paddle::platform::DeviceContextPool;

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   int count = paddle::platform::GetCUDADeviceCount();
   InitDevices();
   DeviceContextPool& pool = DeviceContextPool::Instance();
paddle/fluid/platform/miopen_helper.h (new file, mode 100644, +552 −0)

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"

// MIOPEN do not have epslion definition
#define CUDNN_BN_MIN_EPSILON 1e-05

namespace paddle {
namespace platform {
struct float16;
}  // namespace platform
}  // namespace paddle

DECLARE_bool(cudnn_deterministic);

namespace paddle {
namespace platform {

// MIOPEN only support NCHW, just for compatibility with CUDNN API
typedef enum {
  MIOPEN_TENSOR_NCHW = 0,
  MIOPEN_TENSOR_NHWC = 1,
} miopenTensorFormat_t;

// MIOPEN do not support indirect function call defined in cudnnWorkspaceHandle
struct miopenWorkspace {
  explicit miopenWorkspace(size_t size) : size(size), data(NULL) {
    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&data, size));
  }
  miopenWorkspace(const miopenWorkspace&) = delete;
  miopenWorkspace(miopenWorkspace&&) = default;
  miopenWorkspace& operator=(miopenWorkspace&&) = default;
  ~miopenWorkspace() {
    if (data) {
      hipFree(data);
    }
  }
  size_t size;
  void* data;
};

inline const char* miopenGetErrorString(miopenStatus_t status) {
  switch (status) {
    case miopenStatusSuccess:
      return "miopenStatusSuccess";
    case miopenStatusNotInitialized:
      return "miopenStatusNotInitialized";
    case miopenStatusAllocFailed:
      return "miopenStatusAllocFailed";
    case miopenStatusBadParm:
      return "miopenStatusBadParm";
    case miopenStatusInternalError:
      return "miopenStatusInternalError";
    case miopenStatusInvalidValue:
      return "miopenStatusInvalidValue";
    case miopenStatusUnknownError:
      return "miopenStatusUnknownError";
    case miopenStatusNotImplemented:
      return "miopenStatusNotImplemented";
    default:
      return "Unknown miopen error number";
  }
}

// no use, but will have compiling error if not defined
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))

enum class DataLayout {  // Not use
  kNHWC,
  kNCHW,
  kNCDHW,
  kNDHWC,  // add, liyamei
  kNCHW_VECT_C,
};

enum class PoolingMode {
  kMaximum,
  kMaximumDeterministic,
  kAverageExclusive,
  kAverageInclusive,
};

enum class ActivationMode {
  kNone,  // activation identity
  kSigmoid,
  kRelu,
  kRelu6,
  kReluX,
  kTanh,
  kBandPass,
};

inline miopenPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
  switch (mode) {
    case PoolingMode::kMaximumDeterministic:
      return miopenPoolingMax;
    case PoolingMode::kAverageExclusive:
      return miopenPoolingAverage;
    case PoolingMode::kAverageInclusive:
      return miopenPoolingAverageInclusive;
    case PoolingMode::kMaximum:
      return miopenPoolingMax;
    default:
      PADDLE_THROW(
          platform::errors::Unimplemented("Unexpected MIOPEN pooling mode."));
  }
}

inline ActivationMode StringToActivationMode(const std::string& str) {
  if (str == "identity") {
    return ActivationMode::kNone;
  } else if (str == "sigmoid") {
    return ActivationMode::kSigmoid;
  } else if (str == "relu") {
    return ActivationMode::kRelu;
  } else if (str == "relu6") {
    return ActivationMode::kRelu6;
  } else if (str == "relux") {
    return ActivationMode::kReluX;
  } else if (str == "tanh") {
    return ActivationMode::kTanh;
  } else if (str == "bandpass") {
    return ActivationMode::kBandPass;
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unknown MIOPEN activation string: %s.", str));
  }
}

template <typename T>
class CudnnDataType;

template <>
class CudnnDataType<float16> {
 public:
  static const miopenDataType_t type = miopenHalf;
  // The scaling param type is float for HALF and FLOAT tensors
  using ScalingParamType = const float;
  using BatchNormParamType = float;
  static ScalingParamType* kOne() {
    static ScalingParamType v = 1.0;
    return &v;
  }
  static ScalingParamType* kZero() {
    static ScalingParamType v = 0.0;
    return &v;
  }
};

template <>
class CudnnDataType<float> {
 public:
  static const miopenDataType_t type = miopenFloat;
  using ScalingParamType = const float;
  using BatchNormParamType = float;
  static ScalingParamType* kOne() {
    static ScalingParamType v = 1.0;
    return &v;
  }
  static ScalingParamType* kZero() {
    static ScalingParamType v = 0.0;
    return &v;
  }
};

inline miopenTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
  switch (order) {
    case DataLayout::kNHWC:
      return MIOPEN_TENSOR_NHWC;
    case DataLayout::kNCHW:
      return MIOPEN_TENSOR_NCHW;
    case DataLayout::kNCDHW:
      return MIOPEN_TENSOR_NCHW;
    case DataLayout::kNDHWC:
      return MIOPEN_TENSOR_NHWC;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "MIOPEN has no equivalent dataLayout for input order."));
  }
  return MIOPEN_TENSOR_NCHW;
}

class ScopedTensorDescriptor {
 public:
  ScopedTensorDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
  }
  ~ScopedTensorDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
  }

  inline miopenTensorDescriptor_t descriptor(const miopenTensorFormat_t format,
                                             const miopenDataType_t type,
                                             const std::vector<int>& dims,
                                             const int groups = 1) {
    // the format is not used now, will add later
    std::vector<int> strides(dims.size());
    strides[dims.size() - 1] = 1;
    for (int i = dims.size() - 2; i >= 0; i--) {
      strides[i] = dims[i + 1] * strides[i + 1];
    }
    // Update tensor descriptor dims setting if groups > 1
    // NOTE: Here, Assume using NCHW or NCDHW order
    std::vector<int> dims_with_group(dims.begin(), dims.end());
    if (groups > 1) {
      dims_with_group[1] = dims_with_group[1] / groups;
    }
    // MIOPEN ONLY support data layout of NCHW
    PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW,
                      platform::errors::InvalidArgument(
                          "format should ONLY be NCHW in MIOPEN."));
    if (dims.size() == 4) {
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
          desc_, type, dims_with_group.size(),
          const_cast<int*>(dims_with_group.data()),
          const_cast<int*>(strides.data())));
    } else if (dims.size() == 5) {
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
          desc_, type, dims_with_group.size(),
          const_cast<int*>(dims_with_group.data()),
          const_cast<int*>(strides.data())));
    }
    return desc_;
  }

  template <typename T>
  inline miopenTensorDescriptor_t descriptor(const DataLayout& order,
                                             const std::vector<int>& dims,
                                             const int groups = 1) {
    return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type,
                      dims, groups);
  }

  inline miopenTensorDescriptor_t descriptor(
      const miopenDataType_t miopen_type, const std::vector<int>& dim,
      const std::vector<int>& stride) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
        desc_, miopen_type, dim.size(), const_cast<int*>(dim.data()),
        const_cast<int*>(stride.data())));
    return desc_;
  }

  template <typename T>
  inline miopenTensorDescriptor_t descriptor(const std::vector<int>& dim,
                                             const std::vector<int>& stride) {
    return descriptor(CudnnDataType<T>::type, dim, stride);
  }

  inline miopenTensorDescriptor_t desc() { return desc_; }

 private:
  miopenTensorDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};

class ScopedDropoutDescriptor {
 public:
  ScopedDropoutDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateDropoutDescriptor(&desc_));
  }
  ~ScopedDropoutDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyDropoutDescriptor(desc_));
  }

  inline miopenDropoutDescriptor_t descriptor(const miopenHandle_t& handle,
                                              const platform::Place& place,
                                              bool initialized,
                                              float dropout_prob_,
                                              framework::Tensor* dropout_state_,
                                              int seed, size_t state_size) {
    if (dropout_state_ == nullptr) {  // for no dropout or test
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetDropoutDescriptor(
          desc_, handle, 0 /* dropout */, nullptr, 0 /* state_size */,
          0 /* seed */, false, false, MIOPEN_RNG_PSEUDO_XORWOW));
      return desc_;
    }
    auto* dropout_state_data = dropout_state_->data<uint8_t>();
    if (!initialized) {
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetDropoutDescriptor(
          desc_, handle, dropout_prob_, dropout_state_data, state_size, seed,
          false, false, MIOPEN_RNG_PSEUDO_XORWOW));
    } else {
      auto dropout_state_dims = dropout_state_->dims();
      state_size = dropout_state_dims[0];
      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenRestoreDropoutDescriptor(
          desc_, handle, dropout_prob_, dropout_state_data, state_size, 0,
          false, false, MIOPEN_RNG_PSEUDO_XORWOW));
    }
    return desc_;
  }
  inline miopenDropoutDescriptor_t desc() { return desc_; }

 private:
  miopenDropoutDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
};

class ScopedRNNDescriptor {
 public:
  ScopedRNNDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateRNNDescriptor(&desc_));
  }
  ~ScopedRNNDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyRNNDescriptor(desc_));
  }

  inline miopenRNNDescriptor_t desc() { return desc_; }

 private:
  miopenRNNDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedRNNDescriptor);
};

class ScopedFilterDescriptor {
 public:
  ScopedFilterDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
  }
  ~ScopedFilterDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
  }

  inline miopenTensorDescriptor_t descriptor(const miopenTensorFormat_t format,
                                             const miopenDataType_t type,
                                             const std::vector<int>& kernel,
                                             const int groups = 1) {
    // filter layout: MCHW(MCDHW), where M is the number of
    // output image channels, C is the number of input image channels,
    // D is the depth of the filter, H is the height of the filter, and W is
    // the width of the filter.
    std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
    if (groups > 1) {
      kernel_with_group[0] /= groups;
      // NOTE: input filter(C) of the filter is already asserted to be C/groups.
    }
    std::vector<int> stride_dim(kernel_with_group.size());
    stride_dim.push_back(1);
    for (int k = kernel_with_group.size() - 2; k >= 0; k--) {
      stride_dim[k] = stride_dim[k + 1] * kernel_with_group[k + 1];
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetTensorDescriptor(
        desc_, type, kernel_with_group.size(),
        const_cast<int*>(kernel_with_group.data()),
        const_cast<int*>(stride_dim.data())));
    return desc_;
  }

  template <typename T>
  inline miopenTensorDescriptor_t descriptor(const DataLayout& order,
                                             const std::vector<int>& kernel,
                                             const int groups = 1) {
    return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type,
                      kernel, groups);
  }

  inline miopenTensorDescriptor_t desc() { return desc_; }

 private:
  miopenTensorDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};

class ScopedConvolutionDescriptor {
 public:
  ScopedConvolutionDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::miopenCreateConvolutionDescriptor(&desc_));
  }
  ~ScopedConvolutionDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::miopenDestroyConvolutionDescriptor(desc_));
  }

  inline miopenConvolutionDescriptor_t descriptor(
      miopenDataType_t type, const std::vector<int>& pads,
      const std::vector<int>& strides, const std::vector<int>& dilations) {
    PADDLE_ENFORCE_EQ(pads.size(), strides.size(),
                      platform::errors::InvalidArgument(
                          "The size of pads and strides should be equal. But "
                          "received size of pads is %d, size of strides is %d.",
                          pads.size(), strides.size()));
    PADDLE_ENFORCE_EQ(
        pads.size(), dilations.size(),
        platform::errors::InvalidArgument(
            "The size of pads and dilations should be equal. But received size "
            "of pads is %d, size of dilations is %d.",
            pads.size(), dilations.size()));
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenInitConvolutionNdDescriptor(
        desc_, pads.size(), const_cast<int*>(pads.data()),
        const_cast<int*>(strides.data()), const_cast<int*>(dilations.data()),
        miopenConvolution));
    return desc_;
  }

  template <typename T>
  inline miopenConvolutionDescriptor_t descriptor(
      const std::vector<int>& pads, const std::vector<int>& strides,
      const std::vector<int>& dilations) {
    return descriptor(CudnnDataType<T>::type, pads, strides, dilations);
  }

 private:
  miopenConvolutionDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
};

class ScopedPoolingDescriptor {
 public:
  ScopedPoolingDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreatePoolingDescriptor(&desc_));
  }
  ~ScopedPoolingDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyPoolingDescriptor(desc_));
  }

  inline miopenPoolingDescriptor_t descriptor(const PoolingMode& mode,
                                              const std::vector<int>& kernel,
                                              const std::vector<int>& pads,
                                              const std::vector<int>& strides) {
    PADDLE_ENFORCE_EQ(kernel.size(), pads.size(),
                      platform::errors::InvalidArgument(
                          "The size of kernel and pads should be equal. But "
                          "received size of kernel is %d, size of pads is %d.",
                          kernel.size(), pads.size()));
    PADDLE_ENFORCE_EQ(
        kernel.size(), strides.size(),
        platform::errors::InvalidArgument(
            "The size of kernel and strides should be equal. But "
            "received size of kernel is %d, size of strides is %d.",
            kernel.size(), strides.size()));
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSet2dPoolingDescriptor(
        desc_, GetPoolingMode(mode), kernel[0], kernel[1], pads[0], pads[1],
        strides[0], strides[1]));
    return desc_;
  }

 private:
  miopenPoolingDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
};

class ScopedActivationDescriptor {
 public:
  ScopedActivationDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::miopenCreateActivationDescriptor(&desc_));
  }
  ~ScopedActivationDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        dynload::miopenDestroyActivationDescriptor(desc_));
  }

  template <typename T>
  inline miopenActivationDescriptor_t descriptor(
      const std::string& act, double value_max = static_cast<double>(0.)) {
    double relu_ceiling = 0.0;
    ActivationMode activation_mode = StringToActivationMode(act);
    miopenActivationMode_t mode;
    switch (activation_mode) {
      case ActivationMode::kNone:
        mode = miopenActivationPASTHRU;
        break;
      case ActivationMode::kRelu6:
        relu_ceiling = 6.0;
        mode = miopenActivationCLIPPEDRELU;
        break;
      case ActivationMode::kReluX:
        relu_ceiling = value_max;
        mode = miopenActivationCLIPPEDRELU;
        break;
      case ActivationMode::kRelu:
        mode = miopenActivationRELU;
        break;
      case ActivationMode::kSigmoid:
        mode = miopenActivationLOGISTIC;
        break;
      case ActivationMode::kTanh:
        mode = miopenActivationTANH;
        break;
      default:
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unrecognized MIOPEN activation mode: %d.",
            static_cast<int>(activation_mode)));
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetActivationDescriptor(
        desc_, mode, relu_ceiling, 0.0, 0.0));
    return desc_;
  }

 private:
  miopenActivationDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};

inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
  use_cudnn &= paddle::platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_HIP
  if (use_cudnn) {
    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
  }
#endif
  return use_cudnn;
}

class ScopedCTCLossDescriptor {
 public:
  ScopedCTCLossDescriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenCreateCTCLossDescriptor(&desc_));
  }
  ~ScopedCTCLossDescriptor() PADDLE_MAY_THROW {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenDestroyCTCLossDescriptor(desc_));
  }

  template <typename T>
  inline miopenCTCLossDescriptor_t descriptor() {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenSetCTCLossDescriptor(
        desc_, CudnnDataType<T>::type, 0, false));
    return desc_;
  }

 private:
  miopenCTCLossDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor);
};

}  // namespace platform
}  // namespace paddle
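A minimal usage sketch for the new header, mirroring the unit test that follows (DescribeTensor is an illustrative name, not part of the patch):

#include <vector>
#include "paddle/fluid/platform/miopen_helper.h"

void DescribeTensor() {
  // Build a MIOpen descriptor for a 2x4x6x6 float tensor in NCHW layout;
  // the strides (144, 36, 6, 1) are derived inside descriptor().
  paddle::platform::ScopedTensorDescriptor tensor_desc;
  std::vector<int> shape = {2, 4, 6, 6};
  miopenTensorDescriptor_t desc = tensor_desc.descriptor<float>(
      paddle::platform::DataLayout::kNCHW, shape);
  (void)desc;  // the handle is destroyed with tensor_desc (RAII)
}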
paddle/fluid/platform/miopen_helper_test.cc (new file, mode 100644, +93 −0)

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL

#include "paddle/fluid/platform/miopen_helper.h"

#include <gtest/gtest.h>

TEST(MIOpenHelper, ScopedTensorDescriptor) {
  using paddle::platform::ScopedTensorDescriptor;
  using paddle::platform::DataLayout;

  ScopedTensorDescriptor tensor_desc;
  std::vector<int> shape = {2, 4, 6, 6};
  auto desc = tensor_desc.descriptor<float>(DataLayout::kNCHW, shape);

  miopenDataType_t type;
  int nd;
  std::vector<int> dims(4);
  std::vector<int> strides(4);
  paddle::platform::dynload::miopenGetTensorDescriptor(desc, &type,
                                                       dims.data(),
                                                       strides.data());
  paddle::platform::dynload::miopenGetTensorDescriptorSize(desc, &nd);

  EXPECT_EQ(nd, 4);
  for (size_t i = 0; i < dims.size(); ++i) {
    EXPECT_EQ(dims[i], shape[i]);
  }
  EXPECT_EQ(strides[3], 1);
  EXPECT_EQ(strides[2], 6);
  EXPECT_EQ(strides[1], 36);
  EXPECT_EQ(strides[0], 144);

  // test tensor5d: ScopedTensorDescriptor
  ScopedTensorDescriptor tensor5d_desc;
  std::vector<int> shape_5d = {2, 4, 6, 6, 6};
  auto desc_5d = tensor5d_desc.descriptor<float>(DataLayout::kNCDHW, shape_5d);

  std::vector<int> dims_5d(5);
  std::vector<int> strides_5d(5);
  paddle::platform::dynload::miopenGetTensorDescriptor(desc_5d, &type,
                                                       dims_5d.data(),
                                                       strides_5d.data());
  paddle::platform::dynload::miopenGetTensorDescriptorSize(desc_5d, &nd);

  EXPECT_EQ(nd, 5);
  for (size_t i = 0; i < dims_5d.size(); ++i) {
    EXPECT_EQ(dims_5d[i], shape_5d[i]);
  }
  EXPECT_EQ(strides_5d[4], 1);
  EXPECT_EQ(strides_5d[3], 6);
  EXPECT_EQ(strides_5d[2], 36);
  EXPECT_EQ(strides_5d[1], 216);
  EXPECT_EQ(strides_5d[0], 864);
}

TEST(MIOpenHelper, ScopedConvolutionDescriptor) {
  using paddle::platform::ScopedConvolutionDescriptor;

  ScopedConvolutionDescriptor conv_desc;
  std::vector<int> src_pads = {2, 2, 2};
  std::vector<int> src_strides = {1, 1, 1};
  std::vector<int> src_dilations = {1, 1, 1};
  auto desc = conv_desc.descriptor<float>(src_pads, src_strides,
                                          src_dilations);

  miopenConvolutionMode_t mode;
  int nd;
  std::vector<int> pads(3);
  std::vector<int> strides(3);
  std::vector<int> dilations(3);
  paddle::platform::dynload::miopenGetConvolutionNdDescriptor(
      desc, 3, &nd, pads.data(), strides.data(), dilations.data(), &mode);

  EXPECT_EQ(nd, 3);
  for (size_t i = 0; i < src_pads.size(); ++i) {
    EXPECT_EQ(pads[i], src_pads[i]);
    EXPECT_EQ(strides[i], src_strides[i]);
    EXPECT_EQ(dilations[i], src_dilations[i]);
  }
  EXPECT_EQ(mode, miopenConvolution);
}