Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f57d706a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f57d706a
编写于
9月 05, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use double to reduce
上级
f94fdeaa
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
7 addition
and
7 deletion
+7
-7
paddle/fluid/operators/layer_norm_op.cu
paddle/fluid/operators/layer_norm_op.cu
+7
-7
未找到文件。
paddle/fluid/operators/layer_norm_op.cu
浏览文件 @
f57d706a
...
@@ -67,27 +67,27 @@ template <typename T, int BlockDim>
...
@@ -67,27 +67,27 @@ template <typename T, int BlockDim>
__global__
void
LayerNormForward
(
const
T
*
x
,
const
T
*
scale
,
const
T
*
bias
,
__global__
void
LayerNormForward
(
const
T
*
x
,
const
T
*
scale
,
const
T
*
bias
,
T
*
y
,
T
*
mean
,
T
*
var
,
float
epsilon
,
T
*
y
,
T
*
mean
,
T
*
var
,
float
epsilon
,
int
feature_size
)
{
int
feature_size
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
PairForLayerNorm
<
T
>
,
BlockDim
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
PairForLayerNorm
<
double
>
,
BlockDim
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
beg_idx
=
blockIdx
.
x
*
feature_size
+
threadIdx
.
x
;
int
beg_idx
=
blockIdx
.
x
*
feature_size
+
threadIdx
.
x
;
int
end_idx
=
(
blockIdx
.
x
+
1
)
*
feature_size
;
int
end_idx
=
(
blockIdx
.
x
+
1
)
*
feature_size
;
// Step 1: Reduce to calculate mean and var
// Step 1: Reduce to calculate mean and var
T
mean_val
=
static_cast
<
T
>
(
0
)
;
double
mean_val
=
0
;
T
var_val
=
static_cast
<
T
>
(
0
)
;
double
var_val
=
0
;
for
(
int
i
=
beg_idx
;
i
<
end_idx
;
i
+=
BlockDim
)
{
for
(
int
i
=
beg_idx
;
i
<
end_idx
;
i
+=
BlockDim
)
{
T
tmp
=
x
[
i
];
T
tmp
=
x
[
i
];
mean_val
+=
tmp
;
mean_val
+=
tmp
;
var_val
+=
(
tmp
*
tmp
);
var_val
+=
(
tmp
*
tmp
);
}
}
auto
pair
=
BlockReduce
(
temp_storage
)
auto
pair
=
BlockReduce
(
temp_storage
)
.
Reduce
(
PairForLayerNorm
<
T
>
(
mean_val
,
var_val
),
.
Reduce
(
PairForLayerNorm
<
double
>
(
mean_val
,
var_val
),
PairForLayerNormAddFunctor
<
T
>
());
PairForLayerNormAddFunctor
<
double
>
());
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
auto
tmp
=
pair
.
first_
/
feature_size
;
auto
tmp
=
pair
.
first_
/
feature_size
;
mean
[
blockIdx
.
x
]
=
tmp
;
mean
[
blockIdx
.
x
]
=
static_cast
<
T
>
(
tmp
)
;
var
[
blockIdx
.
x
]
=
pair
.
second_
/
feature_size
-
tmp
*
tmp
;
var
[
blockIdx
.
x
]
=
static_cast
<
T
>
(
pair
.
second_
/
feature_size
-
tmp
*
tmp
)
;
}
}
__syncthreads
();
__syncthreads
();
mean_val
=
mean
[
blockIdx
.
x
];
mean_val
=
mean
[
blockIdx
.
x
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录