Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8c20d668
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c20d668
编写于
1月 18, 2022
作者:
S
sneaxiy
提交者:
GitHub
1月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speedup FP16 Gelu op using fast math and vectorized 8 kernel (#38980)
* speedup gelu using fast math * add bwd part
上级
55e9087f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
206 addition
and
0 deletion
+206
-0
paddle/fluid/operators/gelu_op.cu
paddle/fluid/operators/gelu_op.cu
+171
-0
paddle/fluid/platform/flags.cc
paddle/fluid/platform/flags.cc
+3
-0
python/paddle/fluid/tests/unittests/test_gelu_op.py
python/paddle/fluid/tests/unittests/test_gelu_op.py
+32
-0
未找到文件。
paddle/fluid/operators/gelu_op.cu
浏览文件 @
8c20d668
...
@@ -16,9 +16,156 @@ limitations under the License. */
...
@@ -16,9 +16,156 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/gelu_op.h"
#include "paddle/fluid/operators/gelu_op.h"
DECLARE_bool
(
use_fast_math
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
#ifdef __NVCC__
template
<
bool
FastMode
>
static
__device__
__forceinline__
float
FP32FastTanh
(
float
x
)
{
#if __CUDA_ARCH__ >= 750 && !defined(_WIN32)
if
(
FastMode
)
{
float
y
;
asm
(
"tanh.approx.f32 %0,%1;
\n\t
"
:
"=f"
(
y
)
:
"f"
(
x
));
return
y
;
}
#endif
return
tanhf
(
x
);
}
template
<
bool
FastMode
>
static
__device__
__forceinline__
float
FP32GeluFwd
(
float
x
)
{
auto
tanh_out
=
FP32FastTanh
<
FastMode
>
(
0.79788456
f
*
x
*
(
1.0
f
+
0.044715
f
*
x
*
x
));
return
x
*
0.5
f
*
(
1.0
f
+
tanh_out
);
}
template
<
bool
FastMode
>
static
__device__
__forceinline__
float
FP32GeluBwd
(
float
x
,
float
y_g
)
{
auto
tanh_out
=
FP32FastTanh
<
FastMode
>
(
0.79788456
f
*
x
*
(
1.0
f
+
0.044715
f
*
x
*
x
));
auto
tmp
=
0.5
f
*
x
*
((
1.0
f
-
tanh_out
*
tanh_out
)
*
(
0.79788456
f
+
0.1070322243
f
*
x
*
x
))
+
0.5
f
*
(
1.0
f
+
tanh_out
);
return
tmp
*
y_g
;
}
template
<
int
VecSize
,
bool
FastMode
>
static
__global__
void
FP16FastGeluFwdCUDAKernel
(
const
__half
*
x
,
__half
*
y
,
size_t
n
)
{
size_t
offset
=
static_cast
<
size_t
>
(
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
)
*
VecSize
;
size_t
stride
=
static_cast
<
size_t
>
(
blockDim
.
x
*
gridDim
.
x
)
*
VecSize
;
for
(;
offset
<
n
;
offset
+=
stride
)
{
using
ArrT
=
platform
::
AlignedVector
<
__half
,
VecSize
>
;
ArrT
in_arr
=
*
reinterpret_cast
<
const
ArrT
*>
(
x
+
offset
);
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
float
tmp
=
__half2float
(
in_arr
[
i
]);
in_arr
[
i
]
=
__float2half
(
FP32GeluFwd
<
FastMode
>
(
tmp
));
}
*
reinterpret_cast
<
ArrT
*>
(
y
+
offset
)
=
in_arr
;
}
}
template
<
int
VecSize
,
bool
FastMode
>
static
__global__
void
FP16FastGeluBwdCUDAKernel
(
const
__half
*
x
,
const
__half
*
y_g
,
__half
*
x_g
,
size_t
n
)
{
size_t
offset
=
static_cast
<
size_t
>
(
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
)
*
VecSize
;
size_t
stride
=
static_cast
<
size_t
>
(
blockDim
.
x
*
gridDim
.
x
)
*
VecSize
;
for
(;
offset
<
n
;
offset
+=
stride
)
{
using
ArrT
=
platform
::
AlignedVector
<
__half
,
VecSize
>
;
ArrT
x_in_arr
=
*
reinterpret_cast
<
const
ArrT
*>
(
x
+
offset
);
ArrT
y_g_in_arr
=
*
reinterpret_cast
<
const
ArrT
*>
(
y_g
+
offset
);
#pragma unroll
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
__half2
tmp_fp16_2
;
tmp_fp16_2
.
x
=
x_in_arr
[
i
];
tmp_fp16_2
.
y
=
y_g_in_arr
[
i
];
float2
tmp_fp32_2
=
__half22float2
(
tmp_fp16_2
);
x_in_arr
[
i
]
=
__float2half
(
FP32GeluBwd
<
FastMode
>
(
tmp_fp32_2
.
x
,
tmp_fp32_2
.
y
));
}
*
reinterpret_cast
<
ArrT
*>
(
x_g
+
offset
)
=
x_in_arr
;
}
}
static
bool
TryLaunchFP16FastGeluFwdVectorizeCUDAKernel
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
__half
*
x
,
__half
*
y
,
size_t
n
)
{
auto
is_aligned
=
[](
const
void
*
p
,
size_t
alignment
)
{
return
reinterpret_cast
<
uintptr_t
>
(
p
)
%
alignment
==
0
;
};
#define PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL(__vec_size, __use_fast_math) \
do { \
constexpr auto kAlignment = \
alignof(platform::AlignedVector<__half, __vec_size>); \
if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \
is_aligned(y, kAlignment)) { \
size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
size_t block = (n / __vec_size + thread - 1) / thread; \
block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x); \
VLOG(10) << "Use FP16 fast gelu fwd kernel, block = " << block \
<< " , thread = " << thread; \
FP16FastGeluFwdCUDAKernel< \
__vec_size, \
__use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y, n); \
return true; \
} \
} while (0)
if
(
FLAGS_use_fast_math
)
{
PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
(
8
,
true
);
}
else
{
PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
(
8
,
false
);
}
#undef PD_LAUNCH_FP16_FAST_GELU_FWD_KERNEL
return
false
;
}
static
bool
TryLaunchFP16FastGeluBwdVectorizeCUDAKernel
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
__half
*
x
,
const
__half
*
y_g
,
__half
*
x_g
,
size_t
n
)
{
auto
is_aligned
=
[](
const
void
*
p
,
size_t
alignment
)
{
return
reinterpret_cast
<
uintptr_t
>
(
p
)
%
alignment
==
0
;
};
#define PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL(__vec_size, __use_fast_math) \
do { \
constexpr auto kAlignment = \
alignof(platform::AlignedVector<__half, __vec_size>); \
if (n % __vec_size == 0 && is_aligned(x, kAlignment) && \
is_aligned(x, kAlignment) && is_aligned(y_g, kAlignment) && \
is_aligned(x_g, kAlignment)) { \
size_t thread = std::min<size_t>(512, dev_ctx.GetMaxThreadsPerBlock()); \
size_t block = (n / __vec_size + thread - 1) / thread; \
block = std::min<size_t>(block, dev_ctx.GetCUDAMaxGridDimSize().x); \
VLOG(10) << "Use FP16 fast gelu bwd kernel, block = " << block \
<< " , thread = " << thread; \
FP16FastGeluBwdCUDAKernel< \
__vec_size, \
__use_fast_math><<<block, thread, 0, dev_ctx.stream()>>>(x, y_g, \
x_g, n); \
return true; \
} \
} while (0)
if
(
FLAGS_use_fast_math
)
{
PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
(
8
,
true
);
}
else
{
PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
(
8
,
false
);
}
#undef PD_LAUNCH_FP16_FAST_GELU_BWD_KERNEL
return
false
;
}
#endif
template
<
typename
T
>
template
<
typename
T
>
struct
GeluWithApproximateFunctor
{
struct
GeluWithApproximateFunctor
{
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
using
MPType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
...
@@ -59,7 +206,19 @@ class GeluKernel<platform::CUDADeviceContext, T>
...
@@ -59,7 +206,19 @@ class GeluKernel<platform::CUDADeviceContext, T>
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out
};
const
auto
&
dev_ctx
=
const
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
if
(
approximate
)
{
if
(
approximate
)
{
#ifdef __NVCC__
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
size_t
n
=
in
->
numel
();
const
auto
*
in_ptr
=
reinterpret_cast
<
const
__half
*>
(
in
->
data
<
T
>
());
auto
*
out_ptr
=
reinterpret_cast
<
__half
*>
(
out
->
data
<
T
>
());
if
(
TryLaunchFP16FastGeluFwdVectorizeCUDAKernel
(
dev_ctx
,
in_ptr
,
out_ptr
,
n
))
{
return
;
}
}
#endif
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
0
,
GeluWithApproximateFunctor
<
T
>
());
dev_ctx
,
ins
,
&
outs
,
0
,
GeluWithApproximateFunctor
<
T
>
());
}
else
{
}
else
{
...
@@ -120,6 +279,18 @@ class GeluGradKernel<platform::CUDADeviceContext, T>
...
@@ -120,6 +279,18 @@ class GeluGradKernel<platform::CUDADeviceContext, T>
const
auto
&
dev_ctx
=
const
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
if
(
approximate
)
{
if
(
approximate
)
{
#ifdef __NVCC__
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
size_t
n
=
x
->
numel
();
const
auto
*
x_ptr
=
reinterpret_cast
<
const
__half
*>
(
x
->
data
<
T
>
());
const
auto
*
y_g_ptr
=
reinterpret_cast
<
const
__half
*>
(
dout
->
data
<
T
>
());
auto
*
x_g_ptr
=
reinterpret_cast
<
__half
*>
(
dx
->
data
<
T
>
());
if
(
TryLaunchFP16FastGeluBwdVectorizeCUDAKernel
(
dev_ctx
,
x_ptr
,
y_g_ptr
,
x_g_ptr
,
n
))
{
return
;
}
}
#endif
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
0
,
GeluWithApproximateGradFunctor
<
T
>
());
dev_ctx
,
ins
,
&
outs
,
0
,
GeluWithApproximateGradFunctor
<
T
>
());
}
else
{
}
else
{
...
...
paddle/fluid/platform/flags.cc
浏览文件 @
8c20d668
...
@@ -652,6 +652,9 @@ PADDLE_DEFINE_EXPORTED_bool(
...
@@ -652,6 +652,9 @@ PADDLE_DEFINE_EXPORTED_bool(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_DEFINE_EXPORTED_bool
(
conv2d_disable_cudnn
,
false
,
PADDLE_DEFINE_EXPORTED_bool
(
conv2d_disable_cudnn
,
false
,
"Disable cudnn in conv2d"
);
"Disable cudnn in conv2d"
);
PADDLE_DEFINE_EXPORTED_bool
(
use_fast_math
,
false
,
"Whether to use fast math GPU functions."
);
#endif
#endif
/**
/**
...
...
python/paddle/fluid/tests/unittests/test_gelu_op.py
浏览文件 @
8c20d668
...
@@ -19,6 +19,8 @@ import numpy as np
...
@@ -19,6 +19,8 @@ import numpy as np
from
scipy.special
import
erf
from
scipy.special
import
erf
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.dygraph
as
dg
import
paddle.fluid.dygraph
as
dg
import
paddle
import
paddle.nn.functional
as
F
def
gelu
(
x
,
approximate
):
def
gelu
(
x
,
approximate
):
...
@@ -59,6 +61,36 @@ class TestGeluOp(unittest.TestCase):
...
@@ -59,6 +61,36 @@ class TestGeluOp(unittest.TestCase):
if
fluid
.
is_compiled_with_cuda
():
if
fluid
.
is_compiled_with_cuda
():
self
.
_test_case1_gpu
(
approximate
)
self
.
_test_case1_gpu
(
approximate
)
def
test_fast_math
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
def
use_fast_math
(
enabled
):
paddle
.
set_flags
({
'FLAGS_use_fast_math'
:
enabled
})
shape
=
[
11
,
17
,
8
]
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
shape
).
astype
(
np
.
float16
)
y_g_np
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
shape
).
astype
(
np
.
float16
)
def
run_gelu_op
(
approximate
):
with
dg
.
guard
():
x
=
paddle
.
to_tensor
(
x_np
)
x
.
stop_gradient
=
False
y
=
F
.
gelu
(
x
,
approximate
=
approximate
)
x_grad
=
paddle
.
grad
([
y
],
[
x
],
[
paddle
.
to_tensor
(
y_g_np
)])[
0
]
return
y
.
numpy
(),
x_grad
.
numpy
()
use_fast_math
(
True
)
y_fast_math
,
x_g_fast_math
=
run_gelu_op
(
True
)
use_fast_math
(
False
)
y_ref
,
x_g_ref
=
run_gelu_op
(
True
)
self
.
assertTrue
(
np
.
allclose
(
y_ref
,
y_fast_math
,
rtol
=
1e-5
,
atol
=
5e-4
))
self
.
assertTrue
(
np
.
allclose
(
x_g_ref
,
x_g_fast_math
,
rtol
=
1e-5
,
atol
=
5e-4
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录