Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7d4002e0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7d4002e0
编写于
4月 20, 2020
作者:
M
mapingshuo
提交者:
GitHub
4月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
restrict block num of layer_norm_grad cuda block to 128 (#23878)
restrict block num of layer_norm_grad cuda kernel to 128, test=develop
上级
20b1b080
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
214 addition
and
78 deletion
+214
-78
paddle/fluid/operators/layer_norm_op.cc
paddle/fluid/operators/layer_norm_op.cc
+7
-2
paddle/fluid/operators/layer_norm_op.cu
paddle/fluid/operators/layer_norm_op.cu
+80
-38
paddle/fluid/operators/layer_norm_op.h
paddle/fluid/operators/layer_norm_op.h
+1
-0
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
+126
-38
未找到文件。
paddle/fluid/operators/layer_norm_op.cc
浏览文件 @
7d4002e0
...
...
@@ -141,7 +141,7 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
ctx
->
GetInputDim
(
"
Scale
"
));
ctx
->
GetInputDim
(
"
Bias
"
));
}
}
...
...
@@ -182,6 +182,7 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
}
if
(
this
->
HasInput
(
"Bias"
))
{
op
->
SetInput
(
"Bias"
,
this
->
Input
(
"Bias"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Bias"
),
this
->
InputGrad
(
"Bias"
));
}
...
...
@@ -191,6 +192,9 @@ class LayerNormGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER
(
LayerNormGradNoNeedBufferVarInference
,
"Bias"
);
}
// namespace operators
}
// namespace paddle
...
...
@@ -198,7 +202,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR
(
layer_norm
,
ops
::
LayerNormOp
,
ops
::
LayerNormOpMaker
,
ops
::
LayerNormGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
LayerNormGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
layer_norm_grad
,
ops
::
LayerNormGradOp
);
REGISTER_OPERATOR
(
layer_norm_grad
,
ops
::
LayerNormGradOp
,
ops
::
LayerNormGradNoNeedBufferVarInference
);
REGISTER_OP_CPU_KERNEL
(
layer_norm
,
ops
::
LayerNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
LayerNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
...
...
paddle/fluid/operators/layer_norm_op.cu
浏览文件 @
7d4002e0
...
...
@@ -45,6 +45,37 @@ inline static int GetDesiredBlockDim(int block_dim) {
FIXED_BLOCK_DIM_CASE_BASE(2, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_CASE_BASE(1, ##__VA_ARGS__)
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE( \
log2_block_dim, feature_size, kMaxBlockNum, ...) \
case (1 << (log2_block_dim)): { \
for (int i = 0; i < std::ceil(feature_size / (1.0 * kMaxBlockNum)); i++) { \
int col_offset = i * kMaxBlockNum; \
int block_num = std::min(feature_size - col_offset, kMaxBlockNum); \
constexpr auto kBlockDim = (1 << (log2_block_dim)); \
__VA_ARGS__; \
} \
} break
#define FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(feature_size, kMaxBlockNum, ...) \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(9, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(8, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(7, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(6, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(5, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(4, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(3, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(2, feature_size, kMaxBlockNum, \
##__VA_ARGS__); \
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE(1, feature_size, kMaxBlockNum, \
##__VA_ARGS__)
static
__device__
__forceinline__
float
real_sqrt
(
float
x
)
{
return
sqrtf
(
x
);
}
static
__device__
__forceinline__
double
real_sqrt
(
double
x
)
{
return
sqrt
(
x
);
}
...
...
@@ -131,12 +162,13 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
T
*
d_scale
,
T
*
d_bias
,
T
*
d_x
,
const
T
*
mean
,
const
T
*
var
,
const
T
*
scale
,
float
epsilon
,
int
batch_size
,
int
feature_size
)
{
int
batch_size
,
int
feature_size
,
int
col_offset
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
PairForLayerNorm
<
T
>
,
BlockDim
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
beg_idx
=
threadIdx
.
x
*
feature_size
+
blockIdx
.
x
;
int
end_idx
=
batch_size
*
feature_size
+
blockIdx
.
x
;
int
beg_idx
=
threadIdx
.
x
*
feature_size
+
(
blockIdx
.
x
+
col_offset
)
;
int
end_idx
=
batch_size
*
feature_size
+
(
blockIdx
.
x
+
col_offset
)
;
int
stride
=
BlockDim
*
feature_size
;
T
d_scale_partial
=
0
,
d_bias_partial
=
0
;
...
...
@@ -147,7 +179,7 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
d_scale_partial
+=
d_y
[
i
]
*
(
x
[
i
]
-
mean
[
row_idx
])
/
var_val
;
d_bias_partial
+=
d_y
[
i
];
if
(
HasDx
)
{
d_x
[
i
]
=
d_y
[
i
]
*
scale
[
blockIdx
.
x
]
/
var_val
;
d_x
[
i
]
=
d_y
[
i
]
*
scale
[
blockIdx
.
x
+
col_offset
]
/
var_val
;
}
}
...
...
@@ -156,8 +188,8 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
PairForLayerNormAddFunctor
<
T
>
());
if
(
threadIdx
.
x
==
0
)
{
d_scale
[
blockIdx
.
x
]
=
pair
.
first_
;
d_bias
[
blockIdx
.
x
]
=
pair
.
second_
;
d_scale
[
blockIdx
.
x
+
col_offset
]
=
pair
.
first_
;
d_bias
[
blockIdx
.
x
+
col_offset
]
=
pair
.
second_
;
}
}
...
...
@@ -168,11 +200,11 @@ template <typename T, int BlockDim, bool HasDx, bool HasDScale>
__global__
void
LayerNormBackwardGradientScaleOrBias
(
const
T
*
x
,
const
T
*
d_y
,
T
*
d_scale
,
T
*
d_bias
,
T
*
d_x
,
const
T
*
mean
,
const
T
*
var
,
const
T
*
scale
,
float
epsilon
,
int
batch_size
,
int
feature_size
)
{
int
feature_size
,
int
col_offset
)
{
using
BlockReduce
=
cub
::
BlockReduce
<
T
,
BlockDim
>
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
int
beg_idx
=
threadIdx
.
x
*
feature_size
+
blockIdx
.
x
;
int
end_idx
=
batch_size
*
feature_size
+
blockIdx
.
x
;
int
beg_idx
=
threadIdx
.
x
*
feature_size
+
blockIdx
.
x
+
col_offset
;
int
end_idx
=
batch_size
*
feature_size
+
blockIdx
.
x
+
col_offset
;
int
stride
=
BlockDim
*
feature_size
;
T
d_scale_or_d_bias_partial
=
0
;
...
...
@@ -187,7 +219,7 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if
(
HasDx
)
{
if
(
scale
!=
nullptr
)
{
d_x
[
i
]
=
d_y
[
i
]
*
scale
[
blockIdx
.
x
]
/
var_val
;
d_x
[
i
]
=
d_y
[
i
]
*
scale
[
blockIdx
.
x
+
col_offset
]
/
var_val
;
}
else
{
d_x
[
i
]
=
d_y
[
i
]
/
var_val
;
}
...
...
@@ -199,9 +231,9 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
if
(
threadIdx
.
x
==
0
)
{
if
(
HasDScale
)
{
d_scale
[
blockIdx
.
x
]
=
d_scale_or_d_bias_partial
;
d_scale
[
blockIdx
.
x
+
col_offset
]
=
d_scale_or_d_bias_partial
;
}
else
{
d_bias
[
blockIdx
.
x
]
=
d_scale_or_d_bias_partial
;
d_bias
[
blockIdx
.
x
+
col_offset
]
=
d_scale_or_d_bias_partial
;
}
}
}
...
...
@@ -322,6 +354,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
T
*
d_bias
,
float
epsilon
,
int
batch_size
,
int
feature_size
,
cudaStream_t
stream
)
{
const
int
kMaxBlockDim
=
512
;
const
int
kMaxBlockNum
=
128
;
int
gradient_flag
=
((
d_x
!=
nullptr
?
1
:
0
)
<<
2
)
|
((
d_scale
!=
nullptr
?
1
:
0
)
<<
1
)
|
((
d_bias
!=
nullptr
?
1
:
0
));
...
...
@@ -347,29 +380,33 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
switch
(
gradient_flag
)
{
case
1
:
// d_x == nulptr, d_scale == nullptr, d_bias != nullptr
switch
(
block_dim
)
{
FIXED_BLOCK_DIM_CASE
(
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
false
,
false
><<<
feature_size
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
));
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
(
feature_size
,
kMaxBlockNum
,
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
false
,
false
><<<
block_num
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
,
col_offset
));
}
break
;
case
2
:
// d_x == nullptr, d_scale != nullptr, d_bias == nullptr
switch
(
block_dim
)
{
FIXED_BLOCK_DIM_CASE
(
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
false
,
true
><<<
feature_size
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
));
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
(
feature_size
,
kMaxBlockNum
,
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
false
,
true
><<<
block_num
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
,
col_offset
));
}
break
;
case
3
:
// d_x == nullptr, d_scale != nulptr, d_bias != nullptr
switch
(
block_dim
)
{
FIXED_BLOCK_DIM_CASE
(
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
(
feature_size
,
kMaxBlockNum
,
LayerNormBackwardGradientAll
<
T
,
kBlockDim
,
false
><<<
feature_size
,
kBlockDim
,
0
,
stream
>>>
(
T
,
kBlockDim
,
false
><<<
block_num
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
));
batch_size
,
feature_size
,
col_offset
));
}
break
;
case
4
:
// d_x != nullptr, d_scale == nullptr, d_bias == nullptr
...
...
@@ -382,11 +419,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
break
;
case
5
:
// d_x != nulptr, d_scale == nullptr, d_bias != nullptr
switch
(
block_dim
)
{
FIXED_BLOCK_DIM_CASE
(
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
true
,
false
><<<
feature_size
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
));
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
(
feature_size
,
kMaxBlockNum
,
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
true
,
false
><<<
block_num
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
,
col_offset
));
}
switch
(
GetDesiredBlockDim
(
feature_size
))
{
FIXED_BLOCK_DIM_CASE
(
...
...
@@ -397,11 +435,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
break
;
case
6
:
// d_x != nullptr, d_scale != nullptr, d_bias == nullptr
switch
(
block_dim
)
{
FIXED_BLOCK_DIM_CASE
(
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
true
,
true
><<<
feature_size
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
));
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
(
feature_size
,
kMaxBlockNum
,
LayerNormBackwardGradientScaleOrBias
<
T
,
kBlockDim
,
true
,
true
><<<
block_num
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
,
col_offset
));
}
switch
(
GetDesiredBlockDim
(
feature_size
))
{
FIXED_BLOCK_DIM_CASE
(
...
...
@@ -412,11 +451,12 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
break
;
case
7
:
// d_x != nullptr, d_scale != nullptr, d_bias != nullptr
switch
(
block_dim
)
{
FIXED_BLOCK_DIM_CASE
(
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
(
feature_size
,
kMaxBlockNum
,
LayerNormBackwardGradientAll
<
T
,
kBlockDim
,
true
><<<
feature_size
,
kBlockDim
,
0
,
stream
>>>
(
T
,
kBlockDim
,
true
><<<
block_num
,
kBlockDim
,
0
,
stream
>>>
(
x
,
d_y
,
d_scale
,
d_bias
,
d_x
,
mean
,
var
,
scale
,
epsilon
,
batch_size
,
feature_size
));
batch_size
,
feature_size
,
col_offset
));
}
switch
(
GetDesiredBlockDim
(
feature_size
))
{
FIXED_BLOCK_DIM_CASE
(
...
...
@@ -539,6 +579,8 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
}
};
template
class
LayerNormDirectCUDAFunctor
<
float
>;
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
#undef FIXED_BLOCK_DIM_CASE
}
// namespace operators
...
...
paddle/fluid/operators/layer_norm_op.h
浏览文件 @
7d4002e0
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
python/paddle/fluid/tests/unittests/test_layer_norm_op.py
浏览文件 @
7d4002e0
...
...
@@ -36,42 +36,72 @@ def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
mean
=
np
.
mean
(
x
,
axis
=
1
)
var
=
np
.
var
(
x
,
axis
=
1
)
+
epsilon
output
=
scale
.
reshape
([
1
,
D
])
*
np
.
divide
(
(
x
-
mean
.
reshape
([
N
,
1
])),
(
np
.
sqrt
(
var
)).
reshape
([
N
,
1
]))
+
beta
.
reshape
([
1
,
D
])
output
=
np
.
divide
((
x
-
mean
.
reshape
([
N
,
1
])),
(
np
.
sqrt
(
var
)).
reshape
([
N
,
1
]))
if
scale
is
not
None
:
output
=
scale
.
reshape
([
1
,
D
])
*
output
if
beta
is
not
None
:
output
=
output
+
beta
.
reshape
([
1
,
D
])
x
.
shape
,
output
.
shape
=
x_shape
,
x_shape
return
output
,
mean
,
var
def
_reference_layer_norm_grad
(
x
,
grad_y
,
scale
,
mean
,
var
,
begin_norm_axis
=
1
):
def
_reference_layer_norm_grad
(
x
,
grad_y
,
scale
,
bias
,
mean
,
var
,
begin_norm_axis
=
1
):
x_shape
=
x
.
shape
scale_shape
=
scale
.
shape
N
=
reduce
(
mul
,
x_shape
[
0
:
begin_norm_axis
],
1
)
D
=
reduce
(
mul
,
x_shape
[
begin_norm_axis
:
len
(
x_shape
)],
1
)
if
scale
is
not
None
:
scale_shape
=
scale
.
shape
scale
.
shape
=
[
1
,
D
]
x
.
shape
,
grad_y
.
shape
=
[
N
,
D
],
[
N
,
D
]
var
.
shape
,
mean
.
shape
=
[
N
,
1
],
[
N
,
1
]
scale
.
shape
=
[
1
,
D
]
# d_bias
d_bias
=
np
.
sum
(
grad_y
,
axis
=
0
).
reshape
([
1
,
D
])
if
bias
is
not
None
:
d_bias
=
np
.
sum
(
grad_y
,
axis
=
0
).
reshape
([
1
,
D
])
else
:
d_bias
=
None
# d_scale
d_scale
=
np
.
sum
(((
x
-
mean
)
*
np
.
sqrt
(
1
/
var
))
*
grad_y
,
axis
=
0
).
reshape
([
1
,
D
])
if
scale
is
not
None
:
d_scale
=
np
.
sum
(((
x
-
mean
)
*
np
.
sqrt
(
1
/
var
))
*
grad_y
,
axis
=
0
).
reshape
([
1
,
D
])
else
:
d_scale
=
None
# dx
dx_end
=
scale
*
np
.
sqrt
(
1.0
/
var
)
*
grad_y
d_mean_0
=
np
.
sum
(
-
np
.
sqrt
(
1.0
/
var
)
*
grad_y
*
scale
,
axis
=
1
).
reshape
(
[
N
,
1
])
# the second part equals to zero.
d_mean
=
1.0
/
D
*
d_mean_0
d_std
=
np
.
sum
(
-
(
1.0
/
var
)
*
(
x
-
mean
)
*
grad_y
*
scale
,
axis
=
1
).
reshape
([
N
,
1
])
*
(
1.0
/
D
*
np
.
sqrt
(
1.0
/
var
).
reshape
([
N
,
1
])
*
(
x
-
mean
))
if
scale
is
not
None
:
dx_end
=
scale
*
np
.
sqrt
(
1.0
/
var
)
*
grad_y
d_mean_0
=
np
.
sum
(
-
np
.
sqrt
(
1.0
/
var
)
*
grad_y
*
scale
,
axis
=
1
).
reshape
(
[
N
,
1
])
# the second part equals to zero.
d_mean
=
1.0
/
D
*
d_mean_0
d_std
=
np
.
sum
(
-
(
1.0
/
var
)
*
(
x
-
mean
)
*
grad_y
*
scale
,
axis
=
1
).
reshape
([
N
,
1
])
*
(
1.0
/
D
*
np
.
sqrt
(
1.0
/
var
).
reshape
([
N
,
1
])
*
(
x
-
mean
))
else
:
dx_end
=
1.0
*
np
.
sqrt
(
1.0
/
var
)
*
grad_y
d_mean_0
=
np
.
sum
(
-
np
.
sqrt
(
1.0
/
var
)
*
grad_y
*
1.0
,
axis
=
1
).
reshape
(
[
N
,
1
])
# the second part equals to zero.
d_mean
=
1.0
/
D
*
d_mean_0
d_std
=
np
.
sum
(
-
(
1.0
/
var
)
*
(
x
-
mean
)
*
grad_y
*
1.0
,
axis
=
1
).
reshape
([
N
,
1
])
*
(
1.0
/
D
*
np
.
sqrt
(
1.0
/
var
).
reshape
([
N
,
1
])
*
(
x
-
mean
))
grad_x
=
dx_end
+
d_mean
+
d_std
grad_x
.
shape
,
x
.
shape
,
grad_y
.
shape
=
x_shape
,
x_shape
,
x_shape
scale
.
shape
=
scale_shape
var
.
shape
,
mean
.
shape
=
[
N
,
],
[
N
,
]
if
scale
is
not
None
:
scale
.
shape
=
scale_shape
return
grad_x
,
d_scale
,
d_bias
...
...
@@ -82,7 +112,12 @@ class TestLayerNormOp(unittest.TestCase):
def
__assert_close
(
self
,
tensor
,
np_array
,
msg
,
atol
=
1e-4
):
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
tensor
),
np_array
,
atol
=
atol
),
msg
)
def
check_forward_backward
(
self
,
shape
,
begin_norm_axis
):
def
check_forward_backward
(
self
,
shape
,
begin_norm_axis
,
has_scale
=
True
,
has_bias
=
True
,
y_grad_scale
=
1.0
):
def
test_with_place
(
place
,
shape
,
begin_norm_axis
):
# attr
epsilon
=
0.00001
...
...
@@ -92,21 +127,26 @@ class TestLayerNormOp(unittest.TestCase):
np
.
random
.
seed
(
123
)
x
=
np
.
random
.
random_sample
(
x_shape
).
astype
(
np
.
float32
)
scale
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
bias
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
y_grad
=
np
.
random
.
random_sample
(
x_shape
).
astype
(
np
.
float32
)
scale
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
if
has_scale
else
None
bias
=
np
.
random
.
random_sample
(
scale_shape
).
astype
(
np
.
float32
)
if
has_bias
else
None
y_grad
=
(
np
.
random
.
random_sample
(
x_shape
)
*
y_grad_scale
).
astype
(
np
.
float32
)
# reference forward & backward
y
,
mean
,
variance
=
_reference_layer_norm_naive
(
x
,
scale
,
bias
,
epsilon
,
begin_norm_axis
)
x_grad
,
scale_grad
,
bias_grad
=
_reference_layer_norm_grad
(
x
,
y_grad
,
scale
,
mean
,
variance
,
begin_norm_axis
)
x
,
y_grad
,
scale
,
bias
,
mean
,
variance
,
begin_norm_axis
)
var_dict
=
locals
()
var_dict
[
'y@GRAD'
]
=
y_grad
var_names
=
[
'x'
,
'scale'
,
'bias'
,
'mean'
,
'variance'
,
'y'
,
'y@GRAD'
]
var_names
=
[
'x'
,
'mean'
,
'variance'
,
'y'
,
'y@GRAD'
]
if
has_scale
:
var_names
+=
[
'scale'
]
if
has_bias
:
var_names
+=
[
'bias'
]
ground_truth
=
{
name
:
var_dict
[
name
]
for
name
in
var_names
}
program
=
fluid
.
Program
()
...
...
@@ -117,13 +157,22 @@ class TestLayerNormOp(unittest.TestCase):
name
=
name
,
dtype
=
'float32'
,
shape
=
ground_truth
[
name
].
shape
)
inputs
=
{
"X"
:
block
.
var
(
'x'
)}
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'x@GRAD'
,
]
if
has_scale
:
inputs
[
"Scale"
]
=
block
.
var
(
'scale'
)
fetch_list
+=
[
'scale@GRAD'
]
if
has_bias
:
inputs
[
"Bias"
]
=
block
.
var
(
'bias'
)
fetch_list
+=
[
'bias@GRAD'
]
layer_norm_op
=
block
.
append_op
(
type
=
"layer_norm"
,
inputs
=
{
"X"
:
block
.
var
(
'x'
),
"Scale"
:
block
.
var
(
'scale'
),
"Bias"
:
block
.
var
(
'bias'
),
},
inputs
=
inputs
,
outputs
=
{
"Y"
:
block
.
var
(
'y'
),
"Mean"
:
block
.
var
(
'mean'
),
# share the same memory
...
...
@@ -134,7 +183,6 @@ class TestLayerNormOp(unittest.TestCase):
"epsilon"
:
epsilon
,
"begin_norm_axis"
:
begin_norm_axis
})
# generate backward op_desc
grad_op_desc_list
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
layer_norm_op
.
desc
,
set
(),
[])
...
...
@@ -150,23 +198,25 @@ class TestLayerNormOp(unittest.TestCase):
grad_var
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
program
.
_sync_with_cpp
()
exe
=
fluid
.
Executor
(
place
)
out
=
exe
.
run
(
program
,
feed
=
{
name
:
var_dict
[
name
]
for
name
in
[
'x'
,
'scale'
,
'bias'
,
'y@GRAD'
]
},
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'x@GRAD'
,
'scale@GRAD'
,
'bias@GRAD'
])
fetch_list
=
fetch_list
)
self
.
__assert_close
(
y
,
out
[
0
],
"y"
)
self
.
__assert_close
(
mean
,
out
[
1
],
"mean"
)
self
.
__assert_close
(
variance
,
out
[
2
],
"variance"
,
1e-3
)
self
.
__assert_close
(
x_grad
,
out
[
3
],
"x_grad"
)
self
.
__assert_close
(
scale_grad
,
out
[
4
],
"scale_grad"
,
1e-3
)
self
.
__assert_close
(
bias_grad
,
out
[
5
],
"bias_grad"
)
if
has_scale
:
self
.
__assert_close
(
scale_grad
,
out
[
fetch_list
.
index
(
'scale@GRAD'
)],
"scale_grad"
,
1e-3
)
if
has_bias
:
self
.
__assert_close
(
bias_grad
,
out
[
fetch_list
.
index
(
'bias@GRAD'
)],
"bias_grad"
)
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
()
and
core
.
op_support_gpu
(
...
...
@@ -178,7 +228,45 @@ class TestLayerNormOp(unittest.TestCase):
def
test_check_forward_backward_with_scale_and_bias
(
self
):
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
,
has_scale
=
False
,
has_bias
=
True
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
,
has_scale
=
True
,
has_bias
=
False
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
1
,
has_scale
=
False
,
has_bias
=
False
)
self
.
check_forward_backward
(
shape
=
[
2
,
3
,
4
,
5
],
begin_norm_axis
=
3
)
self
.
check_forward_backward
(
shape
=
[
92
,
513
,
129
],
begin_norm_axis
=
2
,
y_grad_scale
=
0.1
)
self
.
check_forward_backward
(
shape
=
[
3
,
34
,
1134
],
begin_norm_axis
=
2
)
self
.
check_forward_backward
(
shape
=
[
92
,
513
,
1134
],
begin_norm_axis
=
2
,
y_grad_scale
=
0.1
)
self
.
check_forward_backward
(
shape
=
[
92
,
513
,
1134
],
begin_norm_axis
=
2
,
has_scale
=
False
,
has_bias
=
True
,
y_grad_scale
=
0.1
)
self
.
check_forward_backward
(
shape
=
[
92
,
513
,
1134
],
begin_norm_axis
=
2
,
has_scale
=
True
,
has_bias
=
False
,
y_grad_scale
=
0.1
)
self
.
check_forward_backward
(
shape
=
[
92
,
513
,
1134
],
begin_norm_axis
=
2
,
has_scale
=
False
,
has_bias
=
False
,
y_grad_scale
=
0.1
)
class
TestLayerNormAPI
(
unittest
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录