Commit 637f27c6

fix bn grad compute when x.stop_gradient=True (#34102)

* fix bn
* fix
* add unittest
* fix cpu

Authored by ceci3 on Jul 15, 2021; committed by GitHub on Jul 15, 2021
Parent: ff97dea4
Showing 3 changed files with 102 additions and 53 deletions (+102 -53)
paddle/fluid/operators/batch_norm_op.cc  +52 -40
paddle/fluid/operators/batch_norm_op.cu  +43 -13
python/paddle/fluid/tests/unittests/test_batch_norm_op.py  +7 -0
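For context, a minimal dygraph-style sketch of the situation this commit handles. It is not part of the commit, and the Python API names (paddle.randn, paddle.nn.BatchNorm2D) are assumptions based on the Paddle 2.x interface. When the input of batch norm has stop_gradient=True, no X@GRAD is requested, so the grad kernels must tolerate a null d_x while still producing scale@GRAD and bias@GRAD; before this change the kernels unconditionally initialized d_x (e.g. d_x->mutable_data<T>(...)).

import paddle
import paddle.nn as nn

# Minimal repro sketch (illustrative only): the BN input requires no gradient,
# so the backward pass must work without producing X@GRAD (d_x is null in the
# grad kernels).
x = paddle.randn([4, 8, 16, 16])
x.stop_gradient = True             # no gradient is requested for x
bn = nn.BatchNorm2D(8)
loss = bn(x).mean()
loss.backward()                    # only scale@GRAD / bias@GRAD are computed
print(x.grad)                      # None: no d_x was produced
print(bn.weight.grad is not None)  # True: parameter gradients still exist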
paddle/fluid/operators/batch_norm_op.cc
@@ -464,11 +464,9 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
                     "BatchNormGrad");
 
   // check output
-  OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
-                 framework::GradVarName("X"), "BatchNormGrad");
-
   const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
   const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
+  const bool has_x_grad = ctx->HasOutput(framework::GradVarName("X"));
 
   PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
                     platform::errors::NotFound(
@@ -496,12 +494,14 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
                         ? x_dims[1]
                         : x_dims[x_dims.size() - 1]);
 
-  ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
   // has_scale_grad == has_bias_grad, judge has_scale_grad is enough
   if (has_scale_grad) {
     ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
     ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
   }
+  if (has_x_grad) {
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+  }
 }
 
 framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
@@ -596,15 +596,20 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     if (ctx.HasInput("Y")) {
       x = ctx.Input<Tensor>("Y");
       is_inplace = true;
-      PADDLE_ENFORCE_EQ(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      // if the input of batch norm is stop_gradient, d_x is null.
+      if (d_x) {
+        PADDLE_ENFORCE_EQ(d_x, d_y,
+                          platform::errors::InvalidArgument(
+                              "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      }
     } else {
       x = ctx.Input<Tensor>("X");
       is_inplace = false;
-      PADDLE_ENFORCE_NE(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      if (d_x) {
+        PADDLE_ENFORCE_NE(
+            d_x, d_y,
+            platform::errors::InvalidArgument(
+                "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      }
     }
 
     // Get the size for each dimension.
@@ -629,7 +634,9 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     const int sample_size = x->numel() / N / C;
 
     // init output
-    d_x->mutable_data<T>(ctx.GetPlace());
+    if (d_x) {
+      d_x->mutable_data<T>(ctx.GetPlace());
+    }
 
     const T *mean_data = saved_mean->data<T>();
     const T *inv_var_data = saved_inv_variance->data<T>();
@@ -673,7 +680,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
       d_scale_arr.setZero();
     }
 
-    if ((N * sample_size) == 1 && !use_global_stats) {
+    if (d_x && (N * sample_size) == 1 && !use_global_stats) {
      framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
      return;
    }
@@ -718,8 +725,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
         }
         ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
         ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
-        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
-                                 sample_size, N * C);
 
         for (int nc = 0; nc < N * C; ++nc) {
           int c = nc % C;
@@ -734,19 +739,24 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
           d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
         }
 
-        if (!use_global_stats) {
-          for (int nc = 0; nc < N * C; ++nc) {
-            int c = nc % C;
-            d_x_arr.col(nc) =
-                scale_inv_var_nhw(c) *
-                (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
-                 (x_arr.col(nc) - mean_arr[c]) *
-                     dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c));
-          }
-        } else {
-          for (int nc = 0; nc < N * C; ++nc) {
-            int c = nc % C;
-            d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
-          }
-        }
+        if (d_x) {
+          EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
+                                   sample_size, N * C);
+          if (!use_global_stats) {
+            for (int nc = 0; nc < N * C; ++nc) {
+              int c = nc % C;
+              d_x_arr.col(nc) =
+                  scale_inv_var_nhw(c) *
+                  (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
+                   (x_arr.col(nc) - mean_arr[c]) *
+                       dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c));
+            }
+          } else {
+            for (int nc = 0; nc < N * C; ++nc) {
+              int c = nc % C;
+              d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc);
+            }
+          }
+        }
         break;
@@ -765,8 +775,6 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
         }
         ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
         ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
-        EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
-                                 N * sample_size);
 
         for (int nhw = 0; nhw < N * sample_size; ++nhw) {
           dy_sum_arr += d_y_arr.col(nhw);
@@ -779,17 +787,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
           d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
         }
 
-        if (!use_global_stats) {
-          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
-            d_x_arr.col(nhw) =
-                scale_inv_var_nhw *
-                (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
-                 (x_arr.col(nhw) - mean_arr) *
-                     dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr);
-          }
-        } else {
-          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
-            d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
-          }
-        }
+        if (d_x) {
+          EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
+                                   N * sample_size);
+          if (!use_global_stats) {
+            for (int nhw = 0; nhw < N * sample_size; ++nhw) {
+              d_x_arr.col(nhw) =
+                  scale_inv_var_nhw *
+                  (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
+                   (x_arr.col(nhw) - mean_arr) *
+                       dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr);
+            }
+          } else {
+            for (int nhw = 0; nhw < N * sample_size; ++nhw) {
+              d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw);
+            }
+          }
+        }
         break;
paddle/fluid/operators/batch_norm_op.cu
@@ -840,15 +840,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     if (ctx.HasInput("Y")) {
       x = ctx.Input<Tensor>("Y");
       is_inplace = true;
-      PADDLE_ENFORCE_EQ(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      if (d_x) {
+        PADDLE_ENFORCE_EQ(
+            d_x, d_y,
+            platform::errors::InvalidArgument(
+                "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      }
     } else {
       x = ctx.Input<Tensor>("X");
       is_inplace = false;
-      PADDLE_ENFORCE_NE(d_x, d_y,
-                        platform::errors::InvalidArgument(
-                            "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      if (d_x) {
+        PADDLE_ENFORCE_NE(
+            d_x, d_y,
+            platform::errors::InvalidArgument(
+                "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      }
     }
 
     const bool is_test = ctx.Attr<bool>("is_test");
@@ -867,7 +871,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
     // init output
-    d_x->mutable_data<T>(ctx.GetPlace());
+    if (d_x) {
+      d_x->mutable_data<T>(ctx.GetPlace());
+    }
 
     if (d_scale && d_bias) {
       d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
@@ -908,7 +914,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     Tensor transformed_x(x->type());
     Tensor transformed_d_y(d_y->type());
-    Tensor transformed_d_x(d_x->type());
+    Tensor transformed_d_x;
     if (data_layout == DataLayout::kNHWC &&
         compute_format == DataLayout::kNCHW) {
       VLOG(3) << "Transform input tensor from NHWC to NCHW.";
@@ -920,12 +926,16 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                                                            &transformed_d_y);
       TransToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_y,
                                                           &transformed_d_y);
-      ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
-                                                           &transformed_d_x);
+      if (d_x) {
+        ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, d_x,
+                                                             &transformed_d_x);
+      }
     } else {
       transformed_x.ShareDataWith(*x);
       transformed_d_y.ShareDataWith(*d_y);
-      transformed_d_x.ShareDataWith(*d_x);
+      if (d_x) {
+        transformed_d_x.ShareDataWith(*d_x);
+      }
     }
     std::vector<int> dims;
@@ -954,7 +964,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     if (!use_global_stats) {
       if ((N * H * W * D) == 1) {
-        framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
+        if (d_x) {
+          framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
+        }
         math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
             functor;
         functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
@@ -1042,7 +1054,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       }
       // This branch calls CUDNN APIs
-      if (d_scale && d_bias) {
+      if (d_x && d_scale && d_bias) {
         bool called = false;
 #if CUDNN_VERSION_MIN(7, 4, 1)
         called = true;
@@ -1187,6 +1199,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
               saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
               d_x->data<T>());
         }
+        if (d_scale && d_bias) {
+          KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
+              grid2, block, 0, stream>>>(
+              d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
+              epsilon, N, C, H * W * D,
+              d_scale->data<BatchNormParamType<T>>(),
+              d_bias->data<BatchNormParamType<T>>());
+        }
       } else {
         if (d_x) {
           BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
@@ -1195,6 +1216,15 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
               saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
               d_x->data<T>());
         }
+        if (d_scale && d_bias) {
+          KeBNBackwardScaleBias<T, block, framework::DataLayout::kNHWC><<<
+              grid2, block, 0, stream>>>(
+              d_y->data<T>(), x->data<T>(), saved_mean_data, saved_var_data,
+              epsilon, N, C, H * W * D,
+              d_scale->data<BatchNormParamType<T>>(),
+              d_bias->data<BatchNormParamType<T>>());
+        }
       }
     }
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -515,6 +515,13 @@ class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining):
         os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1"
 
 
+class TestBatchNormOpTrainingCase3(TestBatchNormOpTraining):
+    def init_test_case(self):
+        self.use_global_stats = False
+        self.no_grad_set = set(['x@GRAD'])
+        self.fetch_list = ['y', 'mean', 'variance', 'scale@GRAD', 'bias@GRAD']
+
+
 class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
     def init_test_case(self):
         self.use_momentum_variable = True
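The added TestBatchNormOpTrainingCase3 drops 'x@GRAD' from the gradient set via no_grad_set, which is the static-graph counterpart of the dygraph sketch above. Below is a rough sketch of the same setup through the static-graph API; it is illustrative only, not part of the commit, and the exact calls (paddle.static.nn.batch_norm, paddle.static.append_backward) are assumptions based on the Paddle 2.x static interface.

import numpy as np
import paddle
import paddle.static as static

paddle.enable_static()
main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name='x', shape=[4, 8, 16, 16], dtype='float32')
    x.stop_gradient = True                   # mirrors no_grad_set = {'x@GRAD'}
    y = static.nn.batch_norm(input=x)
    loss = paddle.mean(y)
    paddle.static.append_backward(loss)      # no x@GRAD variable is created

exe = static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
out = exe.run(main_prog,
              feed={'x': np.random.rand(4, 8, 16, 16).astype('float32')},
              fetch_list=[loss.name])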