Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c0163837
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c0163837
编写于
12月 14, 2020
作者:
L
Leo Chen
提交者:
GitHub
12月 14, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix compile problem when cuda_arch < 6000 (#29576)
* fix compile problem when cuda_arch < 6000 * refine code * refine code
上级
79a41a9e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
8 deletion
+14
-8
paddle/fluid/operators/layer_norm_op.cu
paddle/fluid/operators/layer_norm_op.cu
+14
-8
未找到文件。
paddle/fluid/operators/layer_norm_op.cu
浏览文件 @
c0163837
...
...
@@ -109,7 +109,7 @@ struct PairForLayerNormAddFunctor {
template
<
typename
T
>
__inline__
__device__
T
rsqrt
(
const
T
val
)
{
return
::
r
sqrt
(
val
);
return
static_cast
<
T
>
(
1
)
/
sqrt
(
val
);
}
template
<
>
...
...
@@ -117,10 +117,17 @@ __inline__ __device__ float rsqrt(const float val) {
return
rsqrtf
(
val
);
}
template
<
>
__inline__
__device__
double
rsqrt
(
const
double
val
)
{
return
rsqrt
(
val
);
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template
<
>
__inline__
__device__
half
rsqrt
(
const
half
val
)
{
return
hrsqrt
(
val
);
}
#endif
template
<
typename
T
,
typename
U
,
int
BlockDim
>
__global__
void
LayerNormForward
(
const
T
*
x
,
const
U
*
scale
,
const
U
*
bias
,
...
...
@@ -841,6 +848,7 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
U
=
LayerNormParamType
<
T
>
;
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
...
...
@@ -854,12 +862,10 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
const
auto
x_dims
=
x
->
dims
();
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
mean_data
=
mean
->
mutable_data
<
LayerNormParamType
<
T
>>
(
ctx
.
GetPlace
());
auto
*
var_data
=
var
->
mutable_data
<
LayerNormParamType
<
T
>>
(
ctx
.
GetPlace
());
auto
*
scale_data
=
(
scale
==
nullptr
?
nullptr
:
scale
->
data
<
LayerNormParamType
<
T
>>
());
auto
*
bias_data
=
(
bias
==
nullptr
?
nullptr
:
bias
->
data
<
LayerNormParamType
<
T
>>
());
auto
*
mean_data
=
mean
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
var_data
=
var
->
mutable_data
<
U
>
(
ctx
.
GetPlace
());
auto
*
scale_data
=
(
scale
==
nullptr
?
nullptr
:
scale
->
data
<
U
>
());
auto
*
bias_data
=
(
bias
==
nullptr
?
nullptr
:
bias
->
data
<
U
>
());
auto
matrix_dim
=
framework
::
flatten_to_2d
(
x_dims
,
begin_norm_axis
);
int
batch_size
=
static_cast
<
int
>
(
matrix_dim
[
0
]);
...
...
@@ -869,7 +875,7 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
switch
(
GetDesiredBlockDim
(
feature_size
))
{
FIXED_BLOCK_DIM_CASE
(
LayerNormForward
<
T
,
LayerNormParamType
<
T
>
,
LayerNormForward
<
T
,
U
,
kBlockDim
><<<
batch_size
,
kBlockDim
,
0
,
stream
>>>
(
x_data
,
scale_data
,
bias_data
,
y_data
,
mean_data
,
var_data
,
epsilon
,
feature_size
));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录