Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
637f27c6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
637f27c6
编写于
7月 15, 2021
作者:
C
ceci3
提交者:
GitHub
7月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bn grad compute when x.stop_gradient=True (#34102)
* fix bn * fix * add unittest * fix cpu
上级
ff97dea4
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
102 addition
and
53 deletion
+102
-53
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+52
-40
paddle/fluid/operators/batch_norm_op.cu
paddle/fluid/operators/batch_norm_op.cu
+43
-13
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+7
-0
未找到文件。
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
637f27c6
...
...
@@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
"BatchNormGrad"
);
// check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"BatchNormGrad"
);
const
bool
has_scale_grad
=
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Scale"
));
const
bool
has_bias_grad
=
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias"
));
const
bool
has_x_grad
=
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
));
PADDLE_ENFORCE_EQ
((
has_scale_grad
==
has_bias_grad
),
true
,
platform
::
errors
::
NotFound
(
...
...
@@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
?
x_dims
[
1
]
:
x_dims
[
x_dims
.
size
()
-
1
]);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
// has_scale_grad == has_bias_grad, judge has_scale_grad is enough
if
(
has_scale_grad
)
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Scale"
),
{
C
});
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
{
C
});
}
if
(
has_x_grad
)
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
}
}
framework
::
OpKernelType
BatchNormGradOp
::
GetExpectedKernelType
(
...
...
@@ -596,15 +596,20 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
if
(
ctx
.
HasInput
(
"Y"
))
{
x
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
is_inplace
=
true
;
PADDLE_ENFORCE_EQ
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD not inplace in inplace mode"
));
// if the input of batch norm is stop_gradient, d_x is null.
if
(
d_x
)
{
PADDLE_ENFORCE_EQ
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD not inplace in inplace mode"
));
}
}
else
{
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
is_inplace
=
false
;
PADDLE_ENFORCE_NE
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"
));
if
(
d_x
)
{
PADDLE_ENFORCE_NE
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"
));
}
}
// Get the size for each dimension.
...
...
@@ -629,7 +634,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const
int
sample_size
=
x
->
numel
()
/
N
/
C
;
// init output
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
d_x
)
{
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
const
T
*
mean_data
=
saved_mean
->
data
<
T
>
();
const
T
*
inv_var_data
=
saved_inv_variance
->
data
<
T
>
();
...
...
@@ -673,7 +680,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr
.
setZero
();
}
if
((
N
*
sample_size
)
==
1
&&
!
use_global_stats
)
{
if
(
d_x
&&
(
N
*
sample_size
)
==
1
&&
!
use_global_stats
)
{
framework
::
TensorCopy
(
*
d_y
,
ctx
.
GetPlace
(),
d_x
);
return
;
}
...
...
@@ -718,8 +725,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
ConstEigenArrayMap
<
T
>
x_arr
(
x
->
data
<
T
>
(),
sample_size
,
N
*
C
);
ConstEigenArrayMap
<
T
>
d_y_arr
(
d_y
->
data
<
T
>
(),
sample_size
,
N
*
C
);
EigenArrayMap
<
T
>
d_x_arr
(
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
sample_size
,
N
*
C
);
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
...
...
@@ -734,19 +739,24 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr
=
dy_mul_x_sub_mean_mul_invstd_sum_arr
;
}
if
(
!
use_global_stats
)
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
d_x_arr
.
col
(
nc
)
=
scale_inv_var_nhw
(
c
)
*
(
d_y_arr
.
col
(
nc
)
*
N
*
sample_size
-
dy_sum_arr
(
c
)
-
(
x_arr
.
col
(
nc
)
-
mean_arr
[
c
])
*
dy_mul_x_sub_mean_mul_invstd_sum_arr
(
c
)
*
inv_var_arr
(
c
));
}
}
else
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
d_x_arr
.
col
(
nc
)
=
scale_inv_var_nhw
(
c
)
*
d_y_arr
.
col
(
nc
);
if
(
d_x
)
{
EigenArrayMap
<
T
>
d_x_arr
(
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
sample_size
,
N
*
C
);
if
(
!
use_global_stats
)
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
d_x_arr
.
col
(
nc
)
=
scale_inv_var_nhw
(
c
)
*
(
d_y_arr
.
col
(
nc
)
*
N
*
sample_size
-
dy_sum_arr
(
c
)
-
(
x_arr
.
col
(
nc
)
-
mean_arr
[
c
])
*
dy_mul_x_sub_mean_mul_invstd_sum_arr
(
c
)
*
inv_var_arr
(
c
));
}
}
else
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
d_x_arr
.
col
(
nc
)
=
scale_inv_var_nhw
(
c
)
*
d_y_arr
.
col
(
nc
);
}
}
}
break
;
...
...
@@ -765,8 +775,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
ConstEigenArrayMap
<
T
>
x_arr
(
x
->
data
<
T
>
(),
C
,
N
*
sample_size
);
ConstEigenArrayMap
<
T
>
d_y_arr
(
d_y
->
data
<
T
>
(),
C
,
N
*
sample_size
);
EigenArrayMap
<
T
>
d_x_arr
(
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
,
N
*
sample_size
);
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
dy_sum_arr
+=
d_y_arr
.
col
(
nhw
);
...
...
@@ -779,17 +787,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
d_scale_arr
=
dy_mul_x_sub_mean_mul_invstd_sum_arr
;
}
if
(
!
use_global_stats
)
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
d_x_arr
.
col
(
nhw
)
=
scale_inv_var_nhw
*
(
d_y_arr
.
col
(
nhw
)
*
N
*
sample_size
-
dy_sum_arr
-
(
x_arr
.
col
(
nhw
)
-
mean_arr
)
*
dy_mul_x_sub_mean_mul_invstd_sum_arr
*
inv_var_arr
);
}
}
else
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
d_x_arr
.
col
(
nhw
)
=
scale_inv_var_nhw
*
d_y_arr
.
col
(
nhw
);
if
(
d_x
)
{
EigenArrayMap
<
T
>
d_x_arr
(
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
,
N
*
sample_size
);
if
(
!
use_global_stats
)
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
d_x_arr
.
col
(
nhw
)
=
scale_inv_var_nhw
*
(
d_y_arr
.
col
(
nhw
)
*
N
*
sample_size
-
dy_sum_arr
-
(
x_arr
.
col
(
nhw
)
-
mean_arr
)
*
dy_mul_x_sub_mean_mul_invstd_sum_arr
*
inv_var_arr
);
}
}
else
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
d_x_arr
.
col
(
nhw
)
=
scale_inv_var_nhw
*
d_y_arr
.
col
(
nhw
);
}
}
}
break
;
...
...
paddle/fluid/operators/batch_norm_op.cu
浏览文件 @
637f27c6
...
...
@@ -840,15 +840,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
if
(
ctx
.
HasInput
(
"Y"
))
{
x
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
is_inplace
=
true
;
PADDLE_ENFORCE_EQ
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD not inplace in inplace mode"
));
if
(
d_x
)
{
PADDLE_ENFORCE_EQ
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD not inplace in inplace mode"
));
}
}
else
{
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
is_inplace
=
false
;
PADDLE_ENFORCE_NE
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"
));
if
(
d_x
)
{
PADDLE_ENFORCE_NE
(
d_x
,
d_y
,
platform
::
errors
::
InvalidArgument
(
"X@GRAD and Y@GRAD inplaced in non-inplace mode"
));
}
}
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
...
...
@@ -867,7 +871,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
ExtractNCWHD
(
x_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
// init output
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
d_x
)
{
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
if
(
d_scale
&&
d_bias
)
{
d_scale
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
...
...
@@ -908,7 +914,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
Tensor
transformed_x
(
x
->
type
());
Tensor
transformed_d_y
(
d_y
->
type
());
Tensor
transformed_d_x
(
d_x
->
type
())
;
Tensor
transformed_d_x
;
if
(
data_layout
==
DataLayout
::
kNHWC
&&
compute_format
==
DataLayout
::
kNCHW
)
{
VLOG
(
3
)
<<
"Transform input tensor from NHWC to NCHW."
;
...
...
@@ -920,12 +926,16 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
&
transformed_d_y
);
TransToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
d_y
,
&
transformed_d_y
);
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
d_x
,
&
transformed_d_x
);
if
(
d_x
)
{
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
d_x
,
&
transformed_d_x
);
}
}
else
{
transformed_x
.
ShareDataWith
(
*
x
);
transformed_d_y
.
ShareDataWith
(
*
d_y
);
transformed_d_x
.
ShareDataWith
(
*
d_x
);
if
(
d_x
)
{
transformed_d_x
.
ShareDataWith
(
*
d_x
);
}
}
std
::
vector
<
int
>
dims
;
...
...
@@ -954,7 +964,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
if
(
!
use_global_stats
)
{
if
((
N
*
H
*
W
*
D
)
==
1
)
{
framework
::
TensorCopy
(
*
d_y
,
ctx
.
GetPlace
(),
d_x
);
if
(
d_x
)
{
framework
::
TensorCopy
(
*
d_y
,
ctx
.
GetPlace
(),
d_x
);
}
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
BatchNormParamType
<
T
>>
functor
;
functor
(
dev_ctx
,
d_scale
,
static_cast
<
BatchNormParamType
<
T
>>
(
0
));
...
...
@@ -1042,7 +1054,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
// This branch calls CUDNN APIs
if
(
d_scale
&&
d_bias
)
{
if
(
d_
x
&&
d_
scale
&&
d_bias
)
{
bool
called
=
false
;
#if CUDNN_VERSION_MIN(7, 4, 1)
called
=
true
;
...
...
@@ -1187,6 +1199,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_mean_data
,
x
->
data
<
T
>
(),
saved_var_data
,
C
,
N
,
H
*
W
*
D
,
d_x
->
data
<
T
>
());
}
if
(
d_scale
&&
d_bias
)
{
KeBNBackwardScaleBias
<
T
,
block
,
framework
::
DataLayout
::
kNCHW
><<<
grid2
,
block
,
0
,
stream
>>>
(
d_y
->
data
<
T
>
(),
x
->
data
<
T
>
(),
saved_mean_data
,
saved_var_data
,
epsilon
,
N
,
C
,
H
*
W
*
D
,
d_scale
->
data
<
BatchNormParamType
<
T
>>
(),
d_bias
->
data
<
BatchNormParamType
<
T
>>
());
}
}
else
{
if
(
d_x
)
{
BNBackwardData
<
T
,
block
,
framework
::
DataLayout
::
kNHWC
><<<
...
...
@@ -1195,6 +1216,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_mean_data
,
x
->
data
<
T
>
(),
saved_var_data
,
C
,
N
,
H
*
W
*
D
,
d_x
->
data
<
T
>
());
}
if
(
d_scale
&&
d_bias
)
{
KeBNBackwardScaleBias
<
T
,
block
,
framework
::
DataLayout
::
kNHWC
><<<
grid2
,
block
,
0
,
stream
>>>
(
d_y
->
data
<
T
>
(),
x
->
data
<
T
>
(),
saved_mean_data
,
saved_var_data
,
epsilon
,
N
,
C
,
H
*
W
*
D
,
d_scale
->
data
<
BatchNormParamType
<
T
>>
(),
d_bias
->
data
<
BatchNormParamType
<
T
>>
());
}
}
}
...
...
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
浏览文件 @
637f27c6
...
...
@@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
os
.
environ
[
'FLAGS_cudnn_batchnorm_spatial_persistent'
]
=
"1"
class
TestBatchNormOpTrainingCase3
(
TestBatchNormOpTraining
):
def
init_test_case
(
self
):
self
.
use_global_stats
=
False
self
.
no_grad_set
=
set
([
'x@GRAD'
])
self
.
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'scale@GRAD'
,
'bias@GRAD'
]
class
TestBatchNormOpTrainingMomentumVariable
(
TestBatchNormOpTraining
):
def
init_test_case
(
self
):
self
.
use_momentum_variable
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录