Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a0aff194
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a0aff194
编写于
4月 24, 2023
作者:
Z
Zhang Zheng
提交者:
GitHub
4月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the calculation of layer_norm_bwd (#53224)
* Fix the calculation of layer_norm_bwd * fix
上级
bfa5d6b8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
15 deletion
+14
-15
paddle/phi/kernels/funcs/layer_norm_impl.cu.h
paddle/phi/kernels/funcs/layer_norm_impl.cu.h
+14
-15
未找到文件。
paddle/phi/kernels/funcs/layer_norm_impl.cu.h
浏览文件 @
a0aff194
...
@@ -1603,13 +1603,13 @@ __global__ void LayerNormBackwardGradientAll(
...
@@ -1603,13 +1603,13 @@ __global__ void LayerNormBackwardGradientAll(
for
(
int64_t
i
=
beg_idx
;
i
<
end_idx
;
i
+=
stride
)
{
for
(
int64_t
i
=
beg_idx
;
i
<
end_idx
;
i
+=
stride
)
{
int
row_idx
=
i
/
feature_size
;
int
row_idx
=
i
/
feature_size
;
auto
var_val
=
r
eal_sqrt
(
static_cast
<
U
>
(
var
[
row_idx
])
+
epsilon
);
auto
var_val
=
r
sqrt_
(
static_cast
<
U
>
(
var
[
row_idx
])
+
epsilon
);
d_scale_partial
+=
static_cast
<
U
>
(
d_y
[
i
])
*
d_scale_partial
+=
static_cast
<
U
>
(
d_y
[
i
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean
[
row_idx
])
/
var_val
;
(
static_cast
<
U
>
(
x
[
i
])
-
mean
[
row_idx
])
*
var_val
;
d_bias_partial
+=
static_cast
<
U
>
(
d_y
[
i
]);
d_bias_partial
+=
static_cast
<
U
>
(
d_y
[
i
]);
if
(
HasDx
)
{
if
(
HasDx
)
{
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
static_cast
<
U
>
(
scale
[
blockIdx
.
x
+
col_offset
])
/
static_cast
<
U
>
(
scale
[
blockIdx
.
x
+
col_offset
])
*
var_val
);
var_val
);
}
}
}
}
...
@@ -1659,10 +1659,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
...
@@ -1659,10 +1659,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
for
(
int64_t
i
=
beg_idx
;
i
<
end_idx
;
i
+=
stride
)
{
for
(
int64_t
i
=
beg_idx
;
i
<
end_idx
;
i
+=
stride
)
{
int
row_idx
=
i
/
feature_size
;
int
row_idx
=
i
/
feature_size
;
auto
var_val
=
auto
var_val
=
static_cast
<
U
>
(
r
eal_sqrt
(
static_cast
<
float
>
(
var
[
row_idx
])
+
epsilon
));
static_cast
<
U
>
(
r
sqrt_
(
static_cast
<
float
>
(
var
[
row_idx
])
+
epsilon
));
if
(
HasDScale
)
{
if
(
HasDScale
)
{
d_scale_or_d_bias_partial
+=
static_cast
<
U
>
(
d_y
[
i
])
*
d_scale_or_d_bias_partial
+=
static_cast
<
U
>
(
d_y
[
i
])
*
(
static_cast
<
U
>
(
x
[
i
])
-
mean
[
row_idx
])
/
(
static_cast
<
U
>
(
x
[
i
])
-
mean
[
row_idx
])
*
var_val
;
var_val
;
}
else
{
// d_bias != nullptr
}
else
{
// d_bias != nullptr
d_scale_or_d_bias_partial
+=
static_cast
<
U
>
(
d_y
[
i
]);
d_scale_or_d_bias_partial
+=
static_cast
<
U
>
(
d_y
[
i
]);
...
@@ -1671,10 +1671,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
...
@@ -1671,10 +1671,10 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if
(
HasDx
)
{
if
(
HasDx
)
{
if
(
scale
!=
nullptr
)
{
if
(
scale
!=
nullptr
)
{
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
static_cast
<
U
>
(
scale
[
blockIdx
.
x
+
col_offset
])
/
static_cast
<
U
>
(
scale
[
blockIdx
.
x
+
col_offset
])
*
var_val
);
var_val
);
}
else
{
}
else
{
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
/
var_val
);
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
var_val
);
}
}
}
}
}
}
...
@@ -1762,13 +1762,13 @@ __global__ void LayerNormBackwardGradientOnlyDX(
...
@@ -1762,13 +1762,13 @@ __global__ void LayerNormBackwardGradientOnlyDX(
U
d_x_mean_partial
=
static_cast
<
U
>
(
0
),
d_x_var_partial
=
static_cast
<
U
>
(
0
);
U
d_x_mean_partial
=
static_cast
<
U
>
(
0
),
d_x_var_partial
=
static_cast
<
U
>
(
0
);
for
(
int64_t
i
=
beg_idx
;
i
<
end_idx
;
i
+=
BlockDim
)
{
for
(
int64_t
i
=
beg_idx
;
i
<
end_idx
;
i
+=
BlockDim
)
{
auto
var_val
=
auto
var_val
=
static_cast
<
U
>
(
r
eal_sqrt
(
static_cast
<
float
>
(
block_var
)
+
epsilon
));
static_cast
<
U
>
(
r
sqrt_
(
static_cast
<
float
>
(
block_var
)
+
epsilon
));
if
(
scale
!=
nullptr
)
{
if
(
scale
!=
nullptr
)
{
int
col_idx
=
i
%
feature_size
;
int
col_idx
=
i
%
feature_size
;
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
static_cast
<
U
>
(
scale
[
col_idx
])
/
var_val
);
static_cast
<
U
>
(
scale
[
col_idx
])
*
var_val
);
}
else
{
}
else
{
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
/
var_val
);
d_x
[
i
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
i
])
*
var_val
);
}
}
d_x_mean_partial
+=
static_cast
<
U
>
(
d_x
[
i
]);
d_x_mean_partial
+=
static_cast
<
U
>
(
d_x
[
i
]);
d_x_var_partial
+=
d_x_var_partial
+=
...
@@ -1812,21 +1812,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
...
@@ -1812,21 +1812,20 @@ __global__ void LayerNormBackwardWhenBatchSizeIsOne(
int64_t
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int64_t
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
using
ScaleBiasT
=
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
;
using
ScaleBiasT
=
LayerNormScaleBiasT
<
T
,
U
,
ScaleBiasWithSameTypeX
>
;
if
(
idx
<
feature_size
)
{
if
(
idx
<
feature_size
)
{
auto
var_val
=
auto
var_val
=
static_cast
<
U
>
(
rsqrt_
(
static_cast
<
float
>
(
var
[
0
])
+
epsilon
));
static_cast
<
U
>
(
real_sqrt
(
static_cast
<
float
>
(
var
[
0
])
+
epsilon
));
if
(
d_x
!=
nullptr
)
{
if
(
d_x
!=
nullptr
)
{
if
(
d_scale
==
nullptr
)
{
if
(
d_scale
==
nullptr
)
{
d_x
[
idx
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
idx
])
/
var_val
);
d_x
[
idx
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
idx
])
*
var_val
);
}
else
{
}
else
{
d_x
[
idx
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
idx
])
*
d_x
[
idx
]
=
static_cast
<
T
>
(
static_cast
<
U
>
(
d_y
[
idx
])
*
static_cast
<
U
>
(
scale
[
idx
])
/
var_val
);
static_cast
<
U
>
(
scale
[
idx
])
*
var_val
);
}
}
}
}
if
(
d_scale
!=
nullptr
)
{
if
(
d_scale
!=
nullptr
)
{
d_scale
[
idx
]
=
d_scale
[
idx
]
=
static_cast
<
ScaleBiasT
>
(
static_cast
<
U
>
(
d_y
[
idx
])
*
static_cast
<
ScaleBiasT
>
(
static_cast
<
U
>
(
d_y
[
idx
])
*
(
static_cast
<
U
>
(
x
[
idx
])
-
mean
[
0
])
/
var_val
);
(
static_cast
<
U
>
(
x
[
idx
])
-
mean
[
0
])
*
var_val
);
}
}
if
(
d_bias
!=
nullptr
)
{
if
(
d_bias
!=
nullptr
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录