PaddlePaddle / Paddle · Commit 34f1628c
Unverified · authored Feb 07, 2021 by Qi Li · committed by GitHub on Feb 07, 2021
[ROCM] update fluid platform for rocm39 (part2), test=develop (#30774)
Parent: 5ded39f2
Showing 19 changed files with 530 additions and 110 deletions (+530 −110).
paddle/fluid/platform/float16.h                      +81   -33
paddle/fluid/platform/float16_test.cu                +125  -28
paddle/fluid/platform/gen_comm_id_helper.cc          +3    -2
paddle/fluid/platform/gen_comm_id_helper.h           +2    -1
paddle/fluid/platform/gpu_info.cc                    +167  -16
paddle/fluid/platform/gpu_info.h                     +21   -5
paddle/fluid/platform/gpu_launch_config.h            +5    -1
paddle/fluid/platform/nccl_helper.h                  +7    -2
paddle/fluid/platform/place.h                        +2    -2
paddle/fluid/platform/profiler.cc                    +1    -1
paddle/fluid/platform/profiler.cu                    +23   -0
paddle/fluid/platform/profiler.h                     +2    -2
paddle/fluid/platform/profiler_helper.h              +11   -1
paddle/fluid/platform/profiler_test.cc               +11   -1
paddle/fluid/platform/stream_callback_manager.cc     +13   -3
paddle/fluid/platform/stream_callback_manager.h      +10   -2
paddle/fluid/platform/test_limit_gpu_memory.cu       +28   -7
paddle/fluid/platform/transform.h                    +17   -2
paddle/fluid/platform/variant.h                      +1    -1
paddle/fluid/platform/float16.h
@@ -90,7 +90,7 @@ struct PADDLE_ALIGN(2) float16 {
   // Constructors
 #ifdef PADDLE_CUDA_FP16
   HOSTDEVICE inline explicit float16(const half& h) {
-#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 9000
     x = reinterpret_cast<__half_raw*>(const_cast<half*>(&h))->x;
 #else
@@ -366,10 +366,11 @@ struct PADDLE_ALIGN(2) float16 {
 // CUDA 7.5 and 8.0 do not. The arithmetic operators defined here are
 // for users to write similar CUDA code in CUDA 7.5 and 8.0 as in
 // CUDA 9.0 regarding the half data type.
-// xuan[TODO] change for rocm
-#if defined(PADDLE_CUDA_FP16) && CUDA_VERSION < 9000
+// ROCM has built-in arithmetic operators as not defined
+// __HIP_NO_HALF_OPERATORS__
+#if defined(PADDLE_CUDA_FP16) && !defined(__HIPCC__) && CUDA_VERSION < 9000
 DEVICE inline half operator+(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hadd(a, b);
 #else
   float res = static_cast<float>(float16(a)) + static_cast<float>(float16(b));
@@ -378,7 +379,7 @@ DEVICE inline half operator+(const half& a, const half& b) {
 }

 DEVICE inline half operator-(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hsub(a, b);
 #else
   float res = static_cast<float>(float16(a)) - static_cast<float>(float16(b));
@@ -387,7 +388,7 @@ DEVICE inline half operator-(const half& a, const half& b) {
 }

 DEVICE inline half operator*(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hmul(a, b);
 #else
   float res = static_cast<float>(float16(a)) * static_cast<float>(float16(b));
@@ -396,7 +397,7 @@ DEVICE inline half operator*(const half& a, const half& b) {
 }

 DEVICE inline half operator/(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   float num = __half2float(a);
   float denom = __half2float(b);
   return __float2half(num / denom);
@@ -407,7 +408,7 @@ DEVICE inline half operator/(const half& a, const half& b) {
 }

 DEVICE inline half operator-(const half& a) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hneg(a);
 #else
   float res = -static_cast<float>(float16(a));
@@ -438,7 +439,7 @@ DEVICE inline half& operator/=(half& a, const half& b) {  // NOLINT
 #endif

 DEVICE inline bool operator==(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __heq(a, b);
 #else
   return static_cast<float>(float16(a)) == static_cast<float>(float16(b));
@@ -446,7 +447,7 @@ DEVICE inline bool operator==(const half& a, const half& b) {
 }

 DEVICE inline bool operator!=(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hne(a, b);
 #else
   return static_cast<float>(float16(a)) != static_cast<float>(float16(b));
@@ -454,7 +455,7 @@ DEVICE inline bool operator!=(const half& a, const half& b) {
 }

 DEVICE inline bool operator<(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(a, b);
 #else
   return static_cast<float>(float16(a)) < static_cast<float>(float16(b));
@@ -462,7 +463,7 @@ DEVICE inline bool operator<(const half& a, const half& b) {
 }

 DEVICE inline bool operator<=(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hle(a, b);
 #else
   return static_cast<float>(float16(a)) <= static_cast<float>(float16(b));
@@ -470,7 +471,7 @@ DEVICE inline bool operator<=(const half& a, const half& b) {
 }

 DEVICE inline bool operator>(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(a, b);
 #else
   return static_cast<float>(float16(a)) > static_cast<float>(float16(b));
@@ -478,7 +479,7 @@ DEVICE inline bool operator>(const half& a, const half& b) {
 }

 DEVICE inline bool operator>=(const half& a, const half& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hge(a, b);
 #else
   return static_cast<float>(float16(a)) >= static_cast<float>(float16(b));
@@ -489,9 +490,8 @@ DEVICE inline bool operator>=(const half& a, const half& b) {

 // Arithmetic operators for float16 on GPU
 #if defined(PADDLE_CUDA_FP16)
-// HIPCC has compile error if call __device__ function __hadd in __host__
-// __device__ function
+// HIPCC has compile error if call __device__ function __hadd, __hsub, etc.
+// in __host__ __device__ function
 #if defined(__HIPCC__)
 DEVICE inline float16 operator+(const float16& a, const float16& b) {
   return float16(__hadd(half(a), half(b)));
@@ -509,8 +509,6 @@ HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
 }
 #endif

-// HIPCC has compile error if call __device__ function __hsub in __host__
-// __device__ function
 #if defined(__HIPCC__)
 DEVICE inline float16 operator-(const float16& a, const float16& b) {
   return float16(__hsub(half(a), half(b)));
@@ -528,8 +526,6 @@ HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
 }
 #endif

-// HIPCC has compile error if call __device__ function __hmul in __host__
-// __device__ function
 #if defined(__HIPCC__)
 DEVICE inline float16 operator*(const float16& a, const float16& b) {
   return float16(__hmul(half(a), half(b)));
@@ -547,8 +543,16 @@ HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
 }
 #endif

+#if defined(__HIPCC__)
+DEVICE inline float16 operator/(const float16& a, const float16& b) {
+  return float16(__hdiv(half(a), half(b)));
+}
+HOST inline float16 operator/(const float16& a, const float16& b) {
+  return float16(static_cast<float>(a) / static_cast<float>(b));
+}
+#else
 HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   // TODO(kexinzhao): check which cuda version starts to support __hdiv
   float num = __half2float(half(a));
   float denom = __half2float(half(b));
@@ -557,9 +561,8 @@ HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
   return float16(static_cast<float>(a) / static_cast<float>(b));
 #endif
 }
+#endif

-// HIPCC has compile error if call __device__ function __hneg in __host__
-// __device__ function
 #if defined(__HIPCC__)
 DEVICE inline float16 operator-(const float16& a) {
   return float16(__hneg(half(a)));
@@ -601,8 +604,8 @@ HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) {  // NOLINT
   return a;
 }

-// HIPCC has compile error if call __device__ function __heq in __host__
-// __device__ function
+// HIPCC has compile error if call __device__ function __heq, __hne, etc.
+// in __host__ __device__ function
 #if defined(__HIPCC__)
 DEVICE inline bool operator==(const float16& a, const float16& b) {
   return __heq(half(a), half(b));
@@ -610,7 +613,7 @@ DEVICE inline bool operator==(const float16& a, const float16& b) {
 HOST inline bool operator==(const float16& a, const float16& b) {
   return static_cast<float>(a) == static_cast<float>(b);
 }
-#else  // CUDA
+#else  // __HIPCC__
 HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __heq(half(a), half(b));
@@ -618,47 +621,92 @@ HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
   return static_cast<float>(a) == static_cast<float>(b);
 #endif
 }
-#endif
+#endif  // __HIPCC__

+#if defined(__HIPCC__)
+DEVICE inline bool operator!=(const float16& a, const float16& b) {
+  return __hne(half(a), half(b));
+}
+HOST inline bool operator!=(const float16& a, const float16& b) {
+  return static_cast<float>(a) != static_cast<float>(b);
+}
+#else  // __HIPCC__
 HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hne(half(a), half(b));
 #else
   return static_cast<float>(a) != static_cast<float>(b);
 #endif
 }
+#endif  // __HIPCC__

+#if defined(__HIPCC__)
+DEVICE inline bool operator<(const float16& a, const float16& b) {
+  return __hlt(half(a), half(b));
+}
+HOST inline bool operator<(const float16& a, const float16& b) {
+  return static_cast<float>(a) < static_cast<float>(b);
+}
+#else  // __HIPCC__
 HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hlt(half(a), half(b));
 #else
   return static_cast<float>(a) < static_cast<float>(b);
 #endif
 }
+#endif  // __HIPCC__

+#if defined(__HIPCC__)
+DEVICE inline bool operator<=(const float16& a, const float16& b) {
+  return __hle(half(a), half(b));
+}
+HOST inline bool operator<=(const float16& a, const float16& b) {
+  return static_cast<float>(a) <= static_cast<float>(b);
+}
+#else  // __HIPCC__
 HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hle(half(a), half(b));
 #else
   return static_cast<float>(a) <= static_cast<float>(b);
 #endif
 }
+#endif  // __HIPCC__

+#if defined(__HIPCC__)
+DEVICE inline bool operator>(const float16& a, const float16& b) {
+  return __hgt(half(a), half(b));
+}
+HOST inline bool operator>(const float16& a, const float16& b) {
+  return static_cast<float>(a) > static_cast<float>(b);
+}
+#else  // __HIPCC__
 HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hgt(half(a), half(b));
 #else
   return static_cast<float>(a) > static_cast<float>(b);
 #endif
 }
+#endif  // __HIPCC__

+#if defined(__HIPCC__)
+DEVICE inline bool operator>=(const float16& a, const float16& b) {
+  return __hge(half(a), half(b));
+}
+HOST inline bool operator>=(const float16& a, const float16& b) {
+  return static_cast<float>(a) >= static_cast<float>(b);
+}
+#else  // __HIPCC__
 HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
-#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
   return __hge(half(a), half(b));
 #else
   return static_cast<float>(a) >= static_cast<float>(b);
 #endif
 }
+#endif  // __HIPCC__

 // Arithmetic operators for float16 on ARMv8.2-A CPU
 #elif defined(PADDLE_WITH_NATIVE_FP16)
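A note on the pattern this file now repeats: hipcc rejects calls to a `__device__`-only intrinsic such as `__hadd` from a `__host__ __device__` (`HOSTDEVICE`) body, so each operator is split into a `DEVICE` overload that uses the intrinsic and a `HOST` overload that round-trips through `float`; clang's CUDA/HIP target-based overloading lets both keep the same name. A minimal self-contained sketch of the same split, outside Paddle (names hypothetical, assumes compilation with hipcc):

// Sketch only: mirrors the DEVICE/HOST split above rather than
// reproducing Paddle's float16 type.
#include <hip/hip_fp16.h>

// Device overload: the __device__ intrinsic __hadd is legal here.
__device__ inline __half f16_add(const __half& a, const __half& b) {
  return __hadd(a, b);
}

// Host overload, same signature: convert to float, add, convert back.
__host__ inline __half f16_add(const __half& a, const __half& b) {
  return __float2half(__half2float(a) + __half2float(b));
}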
paddle/fluid/platform/float16_test.cu
@@ -22,30 +22,109 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"

 #define ARITHMETIC_KERNEL(op_type, sign)                                 \
-  __global__ void op_type(const half* in1, const half* in2, half* out) { \
+  __global__ void op_type(const half *in1, const half *in2, half *out) { \
     out[0] = in1[0] sign in2[0];                                         \
   }

 #define COMPOUND_KERNEL(op_type, sign) \
-  __global__ void op_type(half* in1, const half* in2) { in1[0] sign in2[0]; }
+  __global__ void op_type(half *in1, const half *in2) { in1[0] sign in2[0]; }

 #define COMPARISON_KERNEL(op_type, sign)                                 \
-  __global__ void op_type(const half* in1, const half* in2, bool* out) { \
+  __global__ void op_type(const half *in1, const half *in2, bool *out) { \
     out[0] = in1[0] sign in2[0];                                         \
   }

+#ifdef PADDLE_WITH_HIP
+#define ARITHMETIC_KERNEL_LAUNCH(op_type)                                     \
+  void Test##op_type(float v_in1, float v_in2, float v_out) {                 \
+    LOG(INFO) << "Test " << #op_type << " on GPU!";                           \
+    half *in1, *in2, *out;                                                    \
+    half *d_in1, *d_in2, *d_out;                                              \
+    int size = sizeof(half);                                                  \
+    hipMalloc(reinterpret_cast<void **>(&d_in1), size);                       \
+    hipMalloc(reinterpret_cast<void **>(&d_in2), size);                       \
+    hipMalloc(reinterpret_cast<void **>(&d_out), size);                       \
+    in1 = reinterpret_cast<half *>(malloc(size));                             \
+    in2 = reinterpret_cast<half *>(malloc(size));                             \
+    out = reinterpret_cast<half *>(malloc(size));                             \
+    in1[0] = half(float16(v_in1));                                            \
+    in2[0] = half(float16(v_in2));                                            \
+    hipMemcpy(d_in1, in1, size, hipMemcpyHostToDevice);                       \
+    hipMemcpy(d_in2, in2, size, hipMemcpyHostToDevice);                       \
+    hipLaunchKernelGGL(op_type, dim3(1), dim3(1), 0, 0, d_in1, d_in2, d_out); \
+    hipMemcpy(out, d_out, size, hipMemcpyDeviceToHost);                       \
+    EXPECT_EQ(static_cast<float>(float16(out[0])), v_out);                    \
+    free(in1);                                                                \
+    free(in2);                                                                \
+    free(out);                                                                \
+    hipFree(d_in1);                                                           \
+    hipFree(d_in2);                                                           \
+    hipFree(d_out);                                                           \
+  }
+
+#define COMPOUND_KERNEL_LAUNCH(op_type)                                       \
+  void Test##op_type(float v_in1, float v_in2, float v_out) {                 \
+    LOG(INFO) << "Test " << #op_type << " on GPU!";                           \
+    half *in1, *in2;                                                          \
+    half *d_in1, *d_in2;                                                      \
+    int size = sizeof(half);                                                  \
+    hipMalloc(reinterpret_cast<void **>(&d_in1), size);                       \
+    hipMalloc(reinterpret_cast<void **>(&d_in2), size);                       \
+    in1 = reinterpret_cast<half *>(malloc(size));                             \
+    in2 = reinterpret_cast<half *>(malloc(size));                             \
+    in1[0] = half(float16(v_in1));                                            \
+    in2[0] = half(float16(v_in2));                                            \
+    hipMemcpy(d_in1, in1, size, hipMemcpyHostToDevice);                       \
+    hipMemcpy(d_in2, in2, size, hipMemcpyHostToDevice);                       \
+    hipLaunchKernelGGL(op_type, dim3(1), dim3(1), 0, 0, d_in1, d_in2);        \
+    hipMemcpy(in1, d_in1, size, hipMemcpyDeviceToHost);                       \
+    EXPECT_EQ(static_cast<float>(float16(in1[0])), v_out);                    \
+    free(in1);                                                                \
+    free(in2);                                                                \
+    hipFree(d_in1);                                                           \
+    hipFree(d_in2);                                                           \
+  }
+
+#define COMPARISON_KERNEL_LAUNCH(op_type)                                     \
+  void Test##op_type(float v_in1, float v_in2, bool v_out) {                  \
+    LOG(INFO) << "Test " << #op_type << " on GPU!";                           \
+    half *in1, *in2;                                                          \
+    half *d_in1, *d_in2;                                                      \
+    bool *out, *d_out;                                                        \
+    int size = sizeof(half);                                                  \
+    hipMalloc(reinterpret_cast<void **>(&d_in1), size);                       \
+    hipMalloc(reinterpret_cast<void **>(&d_in2), size);                       \
+    hipMalloc(reinterpret_cast<void **>(&d_out), 1);                          \
+    in1 = reinterpret_cast<half *>(malloc(size));                             \
+    in2 = reinterpret_cast<half *>(malloc(size));                             \
+    out = reinterpret_cast<bool *>(malloc(1));                                \
+    in1[0] = half(float16(v_in1));                                            \
+    in2[0] = half(float16(v_in2));                                            \
+    hipMemcpy(d_in1, in1, size, hipMemcpyHostToDevice);                       \
+    hipMemcpy(d_in2, in2, size, hipMemcpyHostToDevice);                       \
+    hipLaunchKernelGGL(op_type, dim3(1), dim3(1), 0, 0, d_in1, d_in2, d_out); \
+    hipMemcpy(out, d_out, 1, hipMemcpyDeviceToHost);                          \
+    EXPECT_EQ(out[0], v_out);                                                 \
+    free(in1);                                                                \
+    free(in2);                                                                \
+    free(out);                                                                \
+    hipFree(d_in1);                                                           \
+    hipFree(d_in2);                                                           \
+    hipFree(d_out);                                                           \
+  }
+#else
 #define ARITHMETIC_KERNEL_LAUNCH(op_type)                     \
   void Test##op_type(float v_in1, float v_in2, float v_out) { \
     LOG(INFO) << "Test " << #op_type << " on GPU!";           \
     half *in1, *in2, *out;                                    \
     half *d_in1, *d_in2, *d_out;                              \
     int size = sizeof(half);                                  \
-    cudaMalloc(reinterpret_cast<void**>(&d_in1), size);       \
-    cudaMalloc(reinterpret_cast<void**>(&d_in2), size);       \
-    cudaMalloc(reinterpret_cast<void**>(&d_out), size);       \
-    in1 = reinterpret_cast<half*>(malloc(size));              \
-    in2 = reinterpret_cast<half*>(malloc(size));              \
-    out = reinterpret_cast<half*>(malloc(size));              \
+    cudaMalloc(reinterpret_cast<void **>(&d_in1), size);      \
+    cudaMalloc(reinterpret_cast<void **>(&d_in2), size);      \
+    cudaMalloc(reinterpret_cast<void **>(&d_out), size);      \
+    in1 = reinterpret_cast<half *>(malloc(size));             \
+    in2 = reinterpret_cast<half *>(malloc(size));             \
+    out = reinterpret_cast<half *>(malloc(size));             \
     in1[0] = half(float16(v_in1));                            \
     in2[0] = half(float16(v_in2));                            \
     cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);     \
@@ -67,10 +146,10 @@ limitations under the License. */
     half *in1, *in2;                                          \
     half *d_in1, *d_in2;                                      \
     int size = sizeof(half);                                  \
-    cudaMalloc(reinterpret_cast<void**>(&d_in1), size);       \
-    cudaMalloc(reinterpret_cast<void**>(&d_in2), size);       \
-    in1 = reinterpret_cast<half*>(malloc(size));              \
-    in2 = reinterpret_cast<half*>(malloc(size));              \
+    cudaMalloc(reinterpret_cast<void **>(&d_in1), size);      \
+    cudaMalloc(reinterpret_cast<void **>(&d_in2), size);      \
+    in1 = reinterpret_cast<half *>(malloc(size));             \
+    in2 = reinterpret_cast<half *>(malloc(size));             \
     in1[0] = half(float16(v_in1));                            \
     in2[0] = half(float16(v_in2));                            \
     cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);     \
@@ -91,12 +170,12 @@ limitations under the License. */
     half *d_in1, *d_in2;                                      \
     bool *out, *d_out;                                        \
     int size = sizeof(half);                                  \
-    cudaMalloc(reinterpret_cast<void**>(&d_in1), size);       \
-    cudaMalloc(reinterpret_cast<void**>(&d_in2), size);       \
-    cudaMalloc(reinterpret_cast<void**>(&d_out), 1);          \
-    in1 = reinterpret_cast<half*>(malloc(size));              \
-    in2 = reinterpret_cast<half*>(malloc(size));              \
-    out = reinterpret_cast<bool*>(malloc(1));                 \
+    cudaMalloc(reinterpret_cast<void **>(&d_in1), size);      \
+    cudaMalloc(reinterpret_cast<void **>(&d_in2), size);      \
+    cudaMalloc(reinterpret_cast<void **>(&d_out), 1);         \
+    in1 = reinterpret_cast<half *>(malloc(size));             \
+    in2 = reinterpret_cast<half *>(malloc(size));             \
+    out = reinterpret_cast<bool *>(malloc(1));                \
     in1[0] = half(float16(v_in1));                            \
     in2[0] = half(float16(v_in2));                            \
     cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice);     \
@@ -111,12 +190,14 @@ limitations under the License. */
     cudaFree(d_in2);                                          \
     cudaFree(d_out);                                          \
   }
+#endif

 #ifdef PADDLE_CUDA_FP16
 namespace paddle {
 namespace platform {

-#if CUDA_VERSION < 9000
+#if defined(PADDLE_WITH_HIP) || \
+    (defined(PADDLE_WITH_CUDA) && CUDA_VERSION < 9000)
 ARITHMETIC_KERNEL(Add, +)
 ARITHMETIC_KERNEL(Sub, -)
 ARITHMETIC_KERNEL(Mul, *)
@@ -128,21 +209,37 @@ ARITHMETIC_KERNEL_LAUNCH(Mul)
 ARITHMETIC_KERNEL_LAUNCH(Div)

 // Negative sign kernel
-__global__ void Neg(half* in) { in[0] = -in[0]; }
+__global__ void Neg(half *in) { in[0] = -in[0]; }

 void TestNeg(float v_in, float v_out) {
   LOG(INFO) << "Test Neg on GPU!";
   half *in, *d_in;
   int size = sizeof(half);
-  cudaMalloc(reinterpret_cast<void**>(&d_in), size);
-  in = reinterpret_cast<half*>(malloc(size));
+#ifdef PADDLE_WITH_HIP
+  hipMalloc(reinterpret_cast<void **>(&d_in), size);
+#else
+  cudaMalloc(reinterpret_cast<void **>(&d_in), size);
+#endif
+  in = reinterpret_cast<half *>(malloc(size));
   in[0] = half(float16(v_in));
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(d_in, in, size, hipMemcpyHostToDevice);
+#else
   cudaMemcpy(d_in, in, size, cudaMemcpyHostToDevice);
+#endif
   Neg<<<1, 1>>>(d_in);
+#ifdef PADDLE_WITH_HIP
+  hipMemcpy(in, d_in, size, hipMemcpyDeviceToHost);
+#else
   cudaMemcpy(in, d_in, size, cudaMemcpyDeviceToHost);
+#endif
   EXPECT_EQ(static_cast<float>(float16(in[0])), v_out);
   free(in);
+#ifdef PADDLE_WITH_HIP
+  hipFree(d_in);
+#else
   cudaFree(d_in);
+#endif
 }

 COMPOUND_KERNEL(AddAssign, +=)
@@ -221,7 +318,7 @@ TEST(float16, lod_tensor_on_gpu) {
   framework::LoDTensor gpu_tensor;
   framework::LoDTensor dst_tensor;

-  float16* src_ptr = src_tensor.mutable_data<float16>(
+  float16 *src_ptr = src_tensor.mutable_data<float16>(
       framework::make_ddim({2, 2}), CPUPlace());

   float16 arr[4] = {float16(1.0f), float16(0.5f), float16(0.33333f),
@@ -238,7 +335,7 @@ TEST(float16, lod_tensor_on_gpu) {
   // Sync before comparing LoDTensors
   gpu_ctx.Wait();
-  const float16* dst_ptr = dst_tensor.data<float16>();
+  const float16 *dst_ptr = dst_tensor.data<float16>();
   ASSERT_NE(src_ptr, dst_ptr);
   for (size_t i = 0; i < 4; ++i) {
     EXPECT_EQ(src_ptr[i].x, dst_ptr[i].x);
@@ -247,7 +344,7 @@ TEST(float16, lod_tensor_on_gpu) {
 template <typename T>
 struct Functor {
-  bool operator()(const T& val) {
+  bool operator()(const T &val) {
     return std::type_index(typeid(T)) ==
            std::type_index(typeid(platform::float16));
   }
auto
b
=
a
;
{
// change semantic, keep the same value
float16
c
=
reinterpret_cast
<
float16
&>
(
reinterpret_cast
<
unsigned
&>
(
b
));
float16
c
=
reinterpret_cast
<
float16
&>
(
reinterpret_cast
<
unsigned
&>
(
b
));
EXPECT_EQ
(
b
,
c
);
}
{
// use uint32 low 16 bit store float16
uint32_t
c
=
reinterpret_cast
<
uint32_t
&>
(
b
);
uint32_t
c
=
reinterpret_cast
<
uint32_t
&>
(
b
);
float16
d
;
d
.
x
=
c
;
EXPECT_EQ
(
b
,
d
);
...
...
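The HIP launch macros above replace CUDA's triple-chevron launch with `hipLaunchKernelGGL(kernel, grid, block, sharedMemBytes, stream, args...)`. The two spellings are equivalent; a minimal sketch with a hypothetical kernel (error handling elided):

#include <hip/hip_runtime.h>

__global__ void Square(float *x) { x[0] *= x[0]; }

void LaunchSquare(float *d_x, hipStream_t stream) {
  // CUDA spelling:  Square<<<1, 1, 0, stream>>>(d_x);
  // HIP spelling:   grid dim3(1), block dim3(1), 0 bytes of dynamic shared
  //                 memory, enqueued on `stream`, then the kernel arguments.
  hipLaunchKernelGGL(Square, dim3(1), dim3(1), 0, stream, d_x);
}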
paddle/fluid/platform/gen_comm_id_helper.cc
@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/gen_comm_id_helper.h"

 #include <arpa/inet.h>
@@ -336,7 +337,7 @@ void RecvBroadCastCommID(int server_fd, std::string endpoint,
   template void RecvBroadCastCommID<Type>(std::string endpoint, \
                                           std::vector<Type> * nccl_ids);

-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 INSTANT_TEMPLATE(ncclUniqueId)
 #endif
 #ifdef PADDLE_WITH_XPU_BKCL
paddle/fluid/platform/gen_comm_id_helper.h
@@ -14,7 +14,8 @@ limitations under the License. */

 #pragma once

-#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 #include <functional>
 #include <string>
 #include <vector>
paddle/fluid/platform/gpu_info.cc
@@ -17,7 +17,11 @@ limitations under the License. */
 #include "gflags/gflags.h"
 #include "paddle/fluid/platform/cuda_device_guard.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/dynload/miopen.h"
+#else
 #include "paddle/fluid/platform/dynload/cudnn.h"
+#endif
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/lock_guard_ptr.h"
 #include "paddle/fluid/platform/macros.h"
@@ -40,19 +44,34 @@ namespace platform {

 int CudnnVersion() {
   if (!dynload::HasCUDNN()) return -1;

+#ifdef PADDLE_WITH_HIP
+  size_t version_major, version_minor, version_patch;
+  PADDLE_ENFORCE_CUDA_SUCCESS(dynload::miopenGetVersion(
+      &version_major, &version_minor, &version_patch));
+  return version_major * 100 + version_minor * 10 + version_patch;
+#else
   return dynload::cudnnGetVersion();
+#endif
 }

 static int GetCUDADeviceCountImpl() {
   int driverVersion = 0;
+#ifdef PADDLE_WITH_HIP
+  hipError_t status = hipDriverGetVersion(&driverVersion);
+#else
   cudaError_t status = cudaDriverGetVersion(&driverVersion);
+#endif

-  if (!(status == cudaSuccess && driverVersion != 0)) {
+  if (!(status == gpuSuccess && driverVersion != 0)) {
     // No GPU driver
     VLOG(2) << "GPU Driver Version can't be detected. No GPU driver!";
     return 0;
   }

+#ifdef PADDLE_WITH_HIP
+  const auto *cuda_visible_devices = std::getenv("HIP_VISIBLE_DEVICES");
+#else
   const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
+#endif
   if (cuda_visible_devices != nullptr) {
     std::string cuda_visible_devices_str(cuda_visible_devices);
     if (!cuda_visible_devices_str.empty()) {
@@ -68,12 +87,17 @@ static int GetCUDADeviceCountImpl() {
       if (std::all_of(cuda_visible_devices_str.begin(),
                       cuda_visible_devices_str.end(),
                       [](char ch) { return ch == ' '; })) {
-        VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
+        VLOG(2) << "CUDA_VISIBLE_DEVICES or HIP_VISIBLE_DEVICES is set to be "
+                   "empty. No GPU detected.";
         return 0;
       }
     }
   }
   int count;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDeviceCount(&count));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDeviceCount(&count));
+#endif
   return count;
 }
@@ -94,13 +118,24 @@ int GetCUDAComputeCapability(int id) {
           id, GetCUDADeviceCount()));
   int major, minor;

+#ifdef PADDLE_WITH_HIP
+  auto major_error_code = hipDeviceGetAttribute(
+      &major, hipDeviceAttributeComputeCapabilityMajor, id);
+  auto minor_error_code = hipDeviceGetAttribute(
+      &minor, hipDeviceAttributeComputeCapabilityMinor, id);
+#else
   auto major_error_code =
       cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, id);
   auto minor_error_code =
       cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, id);
+#endif
   PADDLE_ENFORCE_CUDA_SUCCESS(major_error_code);
   PADDLE_ENFORCE_CUDA_SUCCESS(minor_error_code);
+#ifdef PADDLE_WITH_HIP
+  return major * 100 + minor;
+#else
   return major * 10 + minor;
+#endif
 }

 dim3 GetGpuMaxGridDimSize(int id) {
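One behavioral caveat in the hunk above: the packed capability value is encoded differently per runtime — `major * 10 + minor` under CUDA (7.0 → 70) but `major * 100 + minor` under HIP (9.0 → 900) — so threshold checks such as `TensorCoreAvailable()`'s `>= 70` are only meaningful for the runtime that produced the value (which is why that check is compiled out for HIP later in this file). Illustrative helpers that merely restate the packing used above:

inline int PackCudaCapability(int major, int minor) { return major * 10 + minor; }
inline int PackHipCapability(int major, int minor) { return major * 100 + minor; }

// PackCudaCapability(7, 0) == 70
// PackHipCapability(9, 0)  == 900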
@@ -111,15 +146,30 @@ dim3 GetGpuMaxGridDimSize(int id) {
           id, GetCUDADeviceCount()));
   dim3 ret;
   int size;
+#ifdef PADDLE_WITH_HIP
+  auto error_code_x =
+      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimX, id);
+#else
   auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id);
+#endif
   PADDLE_ENFORCE_CUDA_SUCCESS(error_code_x);
   ret.x = size;

+#ifdef PADDLE_WITH_HIP
+  auto error_code_y =
+      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimY, id);
+#else
   auto error_code_y = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimY, id);
+#endif
   PADDLE_ENFORCE_CUDA_SUCCESS(error_code_y);
   ret.y = size;

+#ifdef PADDLE_WITH_HIP
+  auto error_code_z =
+      hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimZ, id);
+#else
   auto error_code_z = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimZ, id);
+#endif
   PADDLE_ENFORCE_CUDA_SUCCESS(error_code_z);
   ret.z = size;
   return ret;
@@ -132,7 +182,11 @@ int GetCUDARuntimeVersion(int id) {
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
   int runtime_version = 0;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipRuntimeGetVersion(&runtime_version));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaRuntimeGetVersion(&runtime_version));
+#endif
   return runtime_version;
 }
@@ -143,12 +197,16 @@ int GetCUDADriverVersion(int id) {
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
   int driver_version = 0;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipDriverGetVersion(&driver_version));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaDriverGetVersion(&driver_version));
+#endif
   return driver_version;
 }

 bool TensorCoreAvailable() {
-#if CUDA_VERSION >= 9000
+#if !defined(PADDLE_WITH_HIP) && CUDA_VERSION >= 9000
   int device = GetCurrentDeviceId();
   int driver_version = GetCUDAComputeCapability(device);
   return driver_version >= 70;
@@ -164,8 +222,13 @@ int GetCUDAMultiProcessors(int id) {
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
   int count;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(
+      &count, hipDeviceAttributeMultiprocessorCount, id));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id));
+#endif
   return count;
 }
@@ -176,8 +239,13 @@ int GetCUDAMaxThreadsPerMultiProcessor(int id) {
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
   int count;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceGetAttribute(
+      &count, hipDeviceAttributeMaxThreadsPerMultiProcessor, id));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceGetAttribute(
       &count, cudaDevAttrMaxThreadsPerMultiProcessor, id));
+#endif
   return count;
 }
@@ -188,14 +256,23 @@ int GetCUDAMaxThreadsPerBlock(int id) {
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
   int count;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      hipDeviceGetAttribute(&count, hipDeviceAttributeMaxThreadsPerBlock, id));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaDeviceGetAttribute(&count, cudaDevAttrMaxThreadsPerBlock, id));
+#endif
   return count;
 }

 int GetCurrentDeviceId() {
   int device_id;
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipGetDevice(&device_id));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetDevice(&device_id));
+#endif
   return device_id;
 }
@@ -224,7 +301,11 @@ void SetDeviceId(int id) {
                         "Device id must be less than GPU count, "
                         "but received id is: %d. GPU count is: %d.",
                         id, GetCUDADeviceCount()));
+#ifdef PADDLE_WITH_HIP
+  PADDLE_RETRY_CUDA_SUCCESS(hipSetDevice(id));
+#else
   PADDLE_RETRY_CUDA_SUCCESS(cudaSetDevice(id));
+#endif
 }

 void GpuMemoryUsage(size_t *available, size_t *total) {
return
max_chunk_size
;
}
#ifdef PADDLE_WITH_HIP
void
GpuMemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
hipMemcpyKind
kind
,
hipStream_t
stream
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemcpyAsync
(
dst
,
src
,
count
,
kind
,
stream
));
}
#else
void
GpuMemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
cudaMemcpyKind
kind
,
cudaStream_t
stream
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpyAsync
(
dst
,
src
,
count
,
kind
,
stream
));
}
#endif
#ifdef PADDLE_WITH_HIP
void
GpuMemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
hipMemcpyKind
kind
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemcpy
(
dst
,
src
,
count
,
kind
));
}
#else
void
GpuMemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
cudaMemcpyKind
kind
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpy
(
dst
,
src
,
count
,
kind
));
}
#endif
void
GpuMemcpyPeerAsync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
)
{
int
src_device
,
size_t
count
,
gpuStream_t
stream
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemcpyPeerAsync
(
dst
,
dst_device
,
src
,
src_device
,
count
,
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpyPeerAsync
(
dst
,
dst_device
,
src
,
src_device
,
count
,
stream
));
#endif
}
void
GpuMemcpyPeerSync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemcpyPeer
(
dst
,
dst_device
,
src
,
src_device
,
count
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemcpyPeer
(
dst
,
dst_device
,
src
,
src_device
,
count
));
#endif
}
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
)
{
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
gpuStream_t
stream
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipMemsetAsync
(
dst
,
value
,
count
,
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaMemsetAsync
(
dst
,
value
,
count
,
stream
));
#endif
}
void
GpuStreamSync
(
cudaStream_t
stream
)
{
void
GpuStreamSync
(
gpuStream_t
stream
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS
(
hipStreamSynchronize
(
stream
));
#else
PADDLE_ENFORCE_CUDA_SUCCESS
(
cudaStreamSynchronize
(
stream
));
#endif
}
static
void
RaiseNonOutOfMemoryError
(
cudaError_t
*
status
)
{
static
void
RaiseNonOutOfMemoryError
(
gpuError_t
*
status
)
{
#ifdef PADDLE_WITH_HIP
if
(
*
status
==
hipErrorOutOfMemory
)
{
*
status
=
hipSuccess
;
}
#else
if
(
*
status
==
cudaErrorMemoryAllocation
)
{
*
status
=
cudaSuccess
;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS
(
*
status
);
#ifdef PADDLE_WITH_HIP
*
status
=
hipGetLastError
();
if
(
*
status
==
hipErrorOutOfMemory
)
{
*
status
=
hipSuccess
;
}
#else
*
status
=
cudaGetLastError
();
if
(
*
status
==
cudaErrorMemoryAllocation
)
{
*
status
=
cudaSuccess
;
}
#endif
PADDLE_ENFORCE_CUDA_SUCCESS
(
*
status
);
}
...
...
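With these wrappers, call sites can hold a single `gpuStream_t` regardless of backend; only the memcpy-kind enum still differs per runtime, as the wrapper signatures above show. A hedged usage sketch (hypothetical call site, assuming the headers changed in this commit):

#include "paddle/fluid/platform/gpu_info.h"

// Copies n bytes device-to-device, then blocks until the stream drains.
void CopyAndSync(void *dst, const void *src, size_t n, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  paddle::platform::GpuMemcpyAsync(dst, src, n, hipMemcpyDeviceToDevice,
                                   stream);
#else
  paddle::platform::GpuMemcpyAsync(dst, src, n, cudaMemcpyDeviceToDevice,
                                   stream);
#endif
  paddle::platform::GpuStreamSync(stream);
}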
@@ -370,26 +496,38 @@ class RecordedCudaMallocHelper {
    * or cudaSuccess would be returned, and the cudaGetLastError() flag
    * would be clear.
    */
-  cudaError_t Malloc(void **ptr, size_t size) {
+  gpuError_t Malloc(void **ptr, size_t size) {
     LockGuardPtr<std::mutex> lock(mtx_);
     if (UNLIKELY(NeedRecord() && cur_size_ + size > limit_size_)) {
+#ifdef PADDLE_WITH_HIP
+      return hipErrorOutOfMemory;
+#else
       return cudaErrorMemoryAllocation;
+#endif
     }

     CUDADeviceGuard guard(dev_id_);
+#ifdef PADDLE_WITH_HIP
+    auto result = hipMalloc(ptr, size);
+#else
     auto result = cudaMalloc(ptr, size);
-    if (result == cudaSuccess) {
+#endif
+    if (result == gpuSuccess) {
       if (NeedRecord()) {
         cur_size_ += size;
       }
       STAT_INT_ADD("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
-      return cudaSuccess;
+      return gpuSuccess;
     } else {
       RaiseNonOutOfMemoryError(&result);
-      // Non out of memory error would be raised inside
-      // RaiseNonOutOfMemoryError. Therefore, we can
-      // return cudaErrorMemoryAllocation directly here.
+// Non out of memory error would be raised inside
+// RaiseNonOutOfMemoryError. Therefore, we can
+// return cudaErrorMemoryAllocation directly here.
+#ifdef PADDLE_WITH_HIP
+      return hipErrorOutOfMemory;
+#else
       return cudaErrorMemoryAllocation;
+#endif
     }
   }
@@ -404,8 +542,13 @@ class RecordedCudaMallocHelper {
     // process is terminating, in which case we don't care if
     // cudaFree succeeds.
     CUDADeviceGuard guard(dev_id_);
+#ifdef PADDLE_WITH_HIP
+    auto err = hipFree(ptr);
+    if (err != hipErrorDeinitialized) {
+#else
     auto err = cudaFree(ptr);
     if (err != cudaErrorCudartUnloading) {
+#endif
       PADDLE_ENFORCE_CUDA_SUCCESS(err);
       if (NeedRecord()) {
         std::lock_guard<std::mutex> guard(*mtx_);
@@ -413,7 +556,11 @@ class RecordedCudaMallocHelper {
       }
       STAT_INT_SUB("STAT_gpu" + std::to_string(dev_id_) + "_mem_size", size);
     } else {
+#ifdef PADDLE_WITH_HIP
+      hipGetLastError();  // clear the error flag when hipErrorDeinitialized
+#else
       cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
+#endif
     }
   }
@@ -421,8 +568,12 @@ class RecordedCudaMallocHelper {
                   size_t *actual_total) {
     {
       CUDADeviceGuard guard(dev_id_);
+#ifdef PADDLE_WITH_HIP
+      auto result = hipMemGetInfo(actual_avail, actual_total);
+#else
       auto result = cudaMemGetInfo(actual_avail, actual_total);
-      if (result != cudaSuccess) {
+#endif
+      if (result != gpuSuccess) {
         *actual_avail = 0;
       }
       RaiseNonOutOfMemoryError(&result);
@@ -458,13 +609,13 @@ class RecordedCudaMallocHelper {
   static std::once_flag once_flag_;
   static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
-};
+};  // NOLINT

 std::once_flag RecordedCudaMallocHelper::once_flag_;
 std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
     RecordedCudaMallocHelper::instances_;

-cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
+gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
   return RecordedCudaMallocHelper::Instance(dev_id)->Malloc(ptr, size);
 }
paddle/fluid/platform/gpu_info.h
@@ -15,11 +15,19 @@ limitations under the License. */

 #pragma once

 #ifdef PADDLE_WITH_CUDA
 #include <cuda_runtime.h>
+#endif
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#endif

+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+// Note: this header for simplify HIP and CUDA type string
 #include <stddef.h>
 #include <string>
 #include <vector>

+#include "paddle/fluid/platform/type_defs.h"
+
 namespace paddle {
 namespace platform {
@@ -86,28 +94,36 @@ size_t GpuMaxChunkSize();

 //! Copy memory from address src to dst asynchronously.
 void GpuMemcpyAsync(void *dst, const void *src, size_t count,
+#ifdef PADDLE_WITH_HIP
+                    enum hipMemcpyKind kind, hipStream_t stream);
+#else
                     enum cudaMemcpyKind kind, cudaStream_t stream);
+#endif

 //! Copy memory from address src to dst synchronously.
 void GpuMemcpySync(void *dst, const void *src, size_t count,
+#ifdef PADDLE_WITH_HIP
+                   enum hipMemcpyKind kind);
+#else
                    enum cudaMemcpyKind kind);
+#endif

 //! Copy memory from one device to another device asynchronously.
 void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
-                        int src_device, size_t count, cudaStream_t stream);
+                        int src_device, size_t count, gpuStream_t stream);

 //! Copy memory from one device to another device synchronously.
 void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count);

 //! Set memory dst with value count size asynchronously
-void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
+void GpuMemsetAsync(void *dst, int value, size_t count, gpuStream_t stream);

 //! Blocks until stream has completed all operations.
-void GpuStreamSync(cudaStream_t stream);
+void GpuStreamSync(gpuStream_t stream);

 //! CudaMalloc with recorded info
-cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id);
+gpuError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id);

 //! CudaFree with recorded info
 void RecordedCudaFree(void *p, size_t size, int dev_id);
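The new `gpuStream_t`/`gpuError_t`/`gpuSuccess` spellings used throughout this commit come from the `type_defs.h` header included above. A minimal sketch of what such a header provides — illustrative only, the actual file in the Paddle tree is the authority:

// Sketch of the unified-alias idea behind type_defs.h (assumption, not a
// verbatim copy of the Paddle header).
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;  // one spelling for both runtimes
using gpuError_t = hipError_t;
#define gpuSuccess hipSuccess
#else
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
using gpuError_t = cudaError_t;
#define gpuSuccess cudaSuccess
#endif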
paddle/fluid/platform/gpu_launch_config.h
@@ -16,9 +16,13 @@

 #pragma once

-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

+#ifdef PADDLE_WITH_CUDA
 #include <cuda_runtime.h>
+#else
+#include <hip/hip_runtime.h>
+#endif
 #include <stddef.h>
 #include <algorithm>
 #include <string>
paddle/fluid/platform/nccl_helper.h
@@ -14,7 +14,7 @@

 #pragma once

-#ifdef PADDLE_WITH_NCCL
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include <stdio.h>
 #include <memory>
 #include <string>
@@ -25,7 +25,12 @@
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/collective_helper.h"
+#ifdef PADDLE_WITH_NCCL
 #include "paddle/fluid/platform/dynload/nccl.h"
+#endif
+#ifdef PADDLE_WITH_RCCL
+#include "paddle/fluid/platform/dynload/rccl.h"
+#endif
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
@@ -81,7 +86,7 @@ struct NCCLContext {
   explicit NCCLContext(int dev_id)
       : ctx_(new CUDADeviceContext(CUDAPlace(dev_id))), comm_{nullptr} {}

-  cudaStream_t stream() const { return ctx_->stream(); }
+  gpuStream_t stream() const { return ctx_->stream(); }
   ncclComm_t comm() const { return comm_; }

   int device_id() const {
paddle/fluid/platform/place.h
@@ -154,7 +154,7 @@ struct PlaceVisitorWrapper
   }

   typename Visitor::result_type operator()(const CUDAPlace &cuda) const {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     return visitor_(cuda);
 #else
     PADDLE_THROW(platform::errors::Unavailable(
@@ -165,7 +165,7 @@ struct PlaceVisitorWrapper
   typename Visitor::result_type operator()(
       const CUDAPinnedPlace &cuda_pinned) const {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     return visitor_(cuda_pinned);
 #else
     PADDLE_THROW(platform::errors::Unavailable(
paddle/fluid/platform/profiler.cc
@@ -206,7 +206,7 @@ void EnableProfiler(ProfilerState state) {
   g_state = state;
   should_send_profile_state = true;
   GetDeviceTracer()->Enable();
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   if (g_state == ProfilerState::kCUDA || g_state == ProfilerState::kAll ||
       g_state == ProfilerState::kCPU) {
     // Generate some dummy events first to reduce the startup overhead.
paddle/fluid/platform/profiler.cu
@@ -12,7 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#endif
+
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {
@@ -31,6 +38,21 @@ static void ForEachDevice(std::function<void(int)> func) {
 }

 void DummyKernelAndEvent() {
+#ifdef PADDLE_WITH_HIP
+  for (int i = 0; i < 5; i++) {
+    ForEachDevice([](int d) {
+      platform::SetDeviceId(d);
+      hipStream_t stream;
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream));
+      Mark("_cuda_startup_");
+      int *ptr;
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&ptr, sizeof(int)));
+      hipLaunchKernelGGL(DummyKernel, dim3(1), dim3(1), 0, stream, ptr);
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipFree(ptr));
+    });
+  }
+#else
   for (int i = 0; i < 5; i++) {
     ForEachDevice([](int d) {
       platform::SetDeviceId(d);
@@ -44,6 +66,7 @@ void DummyKernelAndEvent() {
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaFree(ptr));
     });
   }
+#endif
 }

 }  // namespace platform
paddle/fluid/platform/profiler.h
@@ -28,7 +28,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/event.h"
 #include "paddle/fluid/platform/place.h"
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/platform/gpu_info.h"
 #endif

 namespace paddle {
@@ -220,7 +220,7 @@ std::string OpName(const framework::VariableNameMap& name_map,
                    const std::string& type_name);
 void SetTracerOption(TracerOption option);
 platform::TracerOption GetTracerOption();
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 void DummyKernelAndEvent();
 #endif
paddle/fluid/platform/profiler_helper.h
@@ -31,6 +31,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #endif  // PADDLE_WITH_CUDA
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#endif

 namespace paddle {
 namespace platform {
@@ -122,6 +125,13 @@ void SynchronizeAllDevice() {
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaDeviceSynchronize());
   }
 #endif
+#ifdef PADDLE_WITH_HIP
+  int count = GetCUDADeviceCount();
+  for (int i = 0; i < count; i++) {
+    SetDeviceId(i);
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipDeviceSynchronize());
+  }
+#endif
 }

 // Print results
@@ -300,7 +310,7 @@ void SetEvent(bool merge_thread, const Event &analyze_event,
     if (rit != pushed_events->rend()) {
       double event_time = 0;
       double gpu_time = 0.0f;
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       gpu_time = rit->CudaElapsedMs(analyze_event);
 #endif
       double cpu_time = rit->CpuElapsedMs(analyze_event);
paddle/fluid/platform/profiler_test.cc
@@ -122,7 +122,7 @@ TEST(RecordEvent, RecordEvent) {
       if (events[i][j].name() == "_start_profiler_") ++start_profiler_count;
       if (events[i][j].name() == "push") {
         EXPECT_EQ(events[i][j + 1].name(), "pop");
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
         EXPECT_GT(events[i][j].CudaElapsedMs(events[i][j + 1]), 0);
 #else
         EXPECT_GT(events[i][j].CpuElapsedMs(events[i][j + 1]), 0);
@@ -146,3 +146,13 @@ TEST(TMP, stream_wait) {
   cudaStreamSynchronize(stream);
 }
 #endif
+
+#ifdef PADDLE_WITH_HIP
+TEST(TMP, stream_wait) {
+  hipStream_t stream;
+  hipStreamCreate(&stream);
+  hipStreamSynchronize(stream);
+  hipStreamSynchronize(stream);
+  hipStreamSynchronize(stream);
+}
+#endif
paddle/fluid/platform/stream_callback_manager.cc
@@ -18,7 +18,10 @@
 namespace paddle {
 namespace platform {

-#if CUDA_VERSION >= 10000
+#ifdef PADDLE_WITH_HIP
+static void StreamCallbackFunc(gpuStream_t stream, gpuError_t status,
+                               void *user_data)
+#elif CUDA_VERSION >= 10000
 static void CUDART_CB StreamCallbackFunc(void *user_data)
 #else
 static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
@@ -30,7 +33,7 @@ static void CUDART_CB StreamCallbackFunc(cudaStream_t stream,
   (*func)();
 }

-StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream)
+StreamCallbackManager::StreamCallbackManager(const gpuStream_t stream)
     : stream_(stream), thread_pool_(1) {}

 void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
@@ -42,7 +45,10 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
       (*callback_func)();
     });
   });
-#if CUDA_VERSION >= 10000
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(
+      hipStreamAddCallback(stream_, StreamCallbackFunc, func, 0));
+#elif CUDA_VERSION >= 10000
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaLaunchHostFunc(stream_, StreamCallbackFunc, func));
 #else
@@ -52,7 +58,11 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
 }

 void StreamCallbackManager::Wait() const {
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream_));
+#else
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
+#endif
   {
     std::lock_guard<std::mutex> lock(mtx_);
     if (last_future_.valid()) {
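The three enqueue paths above have different callback shapes: `hipStreamAddCallback` invokes `callback(stream, status, user_data)` and requires flags to be 0, CUDA 10+'s `cudaLaunchHostFunc` invokes `callback(user_data)` only, and pre-10.0 CUDA uses `cudaStreamAddCallback` with the stream/status/user_data/flags shape. A condensed standalone sketch of the same dispatch (illustrative, not Paddle code; `HostNote`/`EnqueueNote` are hypothetical names):

#include <cstdio>
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
static void HostNote(hipStream_t, hipError_t, void *msg) {
  std::printf("%s\n", static_cast<const char *>(msg));
}
void EnqueueNote(hipStream_t stream, char *msg) {
  hipStreamAddCallback(stream, HostNote, msg, 0);  // flags must be 0
}
#else
#include <cuda_runtime.h>
#if CUDA_VERSION >= 10000
static void CUDART_CB HostNote(void *msg) {
  std::printf("%s\n", static_cast<const char *>(msg));
}
void EnqueueNote(cudaStream_t stream, char *msg) {
  cudaLaunchHostFunc(stream, HostNote, msg);  // callback gets user_data only
}
#else
static void CUDART_CB HostNote(cudaStream_t, cudaError_t, void *msg) {
  std::printf("%s\n", static_cast<const char *>(msg));
}
void EnqueueNote(cudaStream_t stream, char *msg) {
  cudaStreamAddCallback(stream, HostNote, msg, 0);
}
#endif
#endif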
paddle/fluid/platform/stream_callback_manager.h
@@ -15,8 +15,16 @@
 #pragma once

 #include <ThreadPool.h>
+
+#ifdef PADDLE_WITH_CUDA
 #include <cuda.h>
 #include <cuda_runtime.h>
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include <hip/hip_runtime.h>
+#endif
+
 #include <functional>
 #include <future>  // NOLINT
 #include <memory>
@@ -31,7 +39,7 @@ namespace platform {
 // Make StreamCallbackManager thread-safe
 class StreamCallbackManager {
  public:
-  explicit StreamCallbackManager(const cudaStream_t stream);
+  explicit StreamCallbackManager(const gpuStream_t stream);

   ~StreamCallbackManager() = default;
@@ -40,7 +48,7 @@ class StreamCallbackManager {
   void Wait() const;

  private:
-  const cudaStream_t stream_;
+  const gpuStream_t stream_;
   mutable ::ThreadPool thread_pool_;
   mutable std::mutex mtx_;
   mutable std::future<void> last_future_;
paddle/fluid/platform/test_limit_gpu_memory.cu
@@ -40,24 +40,36 @@ TEST(test_record_malloc, test_limit_gpu_memory) {
     RecordedCudaMemGetInfo(&avail, &total, &actual_avail, &actual_total,
                            DEVICE_ID);
     ASSERT_EQ(total, limit);
-    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+#ifdef PADDLE_WITH_HIP
+    ASSERT_EQ(hipGetLastError(), gpuSuccess);
+#else
+    ASSERT_EQ(cudaGetLastError(), gpuSuccess);
+#endif
   }

   {
     CUDADeviceGuard guard(DEVICE_ID);
     GpuMemoryUsage(&avail, &total);
     ASSERT_EQ(total, limit);
-    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+#ifdef PADDLE_WITH_HIP
+    ASSERT_EQ(hipGetLastError(), gpuSuccess);
+#else
+    ASSERT_EQ(cudaGetLastError(), gpuSuccess);
+#endif
   }

-  cudaError_t err = cudaSuccess;
+  gpuError_t err = gpuSuccess;

   void *p1 = nullptr;
   size_t size1 = limit / 4 * 3;

   {
     err = platform::RecordedCudaMalloc(&p1, size1, DEVICE_ID);
-    ASSERT_EQ(err, cudaSuccess);
-    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+    ASSERT_EQ(err, gpuSuccess);
+#ifdef PADDLE_WITH_HIP
+    ASSERT_EQ(hipGetLastError(), gpuSuccess);
+#else
+    ASSERT_EQ(cudaGetLastError(), gpuSuccess);
+#endif
     ASSERT_NE(p1, nullptr);

     ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size1);
@@ -67,8 +79,13 @@ TEST(test_record_malloc, test_limit_gpu_memory) {
   size_t size2 = limit / 2;
   {
     err = platform::RecordedCudaMalloc(&p2, size2, DEVICE_ID);
+#ifdef PADDLE_WITH_HIP
+    ASSERT_EQ(err, hipErrorOutOfMemory);
+    ASSERT_EQ(hipGetLastError(), gpuSuccess);
+#else
     ASSERT_EQ(err, cudaErrorMemoryAllocation);
-    ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+    ASSERT_EQ(cudaGetLastError(), gpuSuccess);
+#endif
     ASSERT_EQ(p2, nullptr);
     ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size1);
@@ -81,8 +98,12 @@ TEST(test_record_malloc, test_limit_gpu_memory) {
   {
     err = platform::RecordedCudaMalloc(&p2, size2, DEVICE_ID);
-    ASSERT_EQ(err, cudaSuccess);
+    ASSERT_EQ(err, gpuSuccess);
+#ifdef PADDLE_WITH_HIP
+    ASSERT_EQ(hipGetLastError(), hipSuccess);
+#else
     ASSERT_EQ(cudaGetLastError(), cudaSuccess);
+#endif
     ASSERT_NE(p2, nullptr);
     ASSERT_EQ(RecordedCudaMallocSize(DEVICE_ID), size2);
   }
paddle/fluid/platform/transform.h
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/hostdevice.h"
 #include "paddle/fluid/platform/place.h"

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 #include <thrust/execution_policy.h>
 #include <thrust/transform.h>
 #include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h"
@@ -76,7 +76,7 @@ struct Transform<platform::CPUDeviceContext> {
   }
 };

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIPCC__)
 template <>
 struct Transform<platform::CUDADeviceContext> {
   template <typename InputIter, typename OutputIter, typename UnaryOperation>
@@ -86,10 +86,17 @@ struct Transform<platform::CUDADeviceContext> {
     PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                       platform::errors::PreconditionNotMet(
                           "The CUDA Transform must be used in GPU place."));
+#ifdef __HIPCC__
+    thrust::transform(thrust::hip::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first),
+                      details::CastToCUDATransformIterator(last),
+                      details::CastToCUDATransformIterator(result), op);
+#else
     thrust::transform(thrust::cuda::par.on(context.stream()),
                       details::CastToCUDATransformIterator(first),
                       details::CastToCUDATransformIterator(last),
                       details::CastToCUDATransformIterator(result), op);
+#endif
   }

   template <typename InputIter1, typename InputIter2, typename OutputIter,
@@ -101,11 +108,19 @@ struct Transform<platform::CUDADeviceContext> {
     PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                       platform::errors::PreconditionNotMet(
                           "The CUDA Transform must be used in GPU place."));
+#ifdef __HIPCC__
+    thrust::transform(thrust::hip::par.on(context.stream()),
+                      details::CastToCUDATransformIterator(first1),
+                      details::CastToCUDATransformIterator(last1),
+                      details::CastToCUDATransformIterator(first2),
+                      details::CastToCUDATransformIterator(result), op);
+#else
     thrust::transform(thrust::cuda::par.on(context.stream()),
                       details::CastToCUDATransformIterator(first1),
                       details::CastToCUDATransformIterator(last1),
                       details::CastToCUDATransformIterator(first2),
                       details::CastToCUDATransformIterator(result), op);
+#endif
   }
 };
 #endif
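rocThrust keeps the `thrust::transform` front end; only the execution-policy namespace changes (`thrust::hip::par` vs `thrust::cuda::par`). A minimal standalone sketch of the same selection, assuming rocThrust is installed for the HIP build:

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/transform.h>

struct Doubler {
  __host__ __device__ float operator()(float x) const { return 2.0f * x; }
};

// Doubles every element in place on the device; the functor and iterators
// are identical, only the policy namespace differs per backend.
void DoubleInPlace(thrust::device_vector<float> &v) {
#ifdef __HIPCC__
  thrust::transform(thrust::hip::par, v.begin(), v.end(), v.begin(), Doubler());
#else
  thrust::transform(thrust::cuda::par, v.begin(), v.end(), v.begin(), Doubler());
#endif
}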
paddle/fluid/platform/variant.h
@@ -32,7 +32,7 @@ limitations under the License. */
 // BOOST_NO_CXX11_VARIADIC_TEMPLATES on gcc/clang to generate same
 // function symbols. For details,
 // https://github.com/PaddlePaddle/Paddle/issues/3386
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #ifndef BOOST_NO_CXX11_VARIADIC_TEMPLATES
 #define BOOST_NO_CXX11_VARIADIC_TEMPLATES
 #endif