Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
619a241b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
619a241b
编写于
9月 25, 2019
作者:
L
lvmengsi
提交者:
GitHub
9月 25, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix OpTest of bn (#19062)
* fix bn
上级
5920d69d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
235 addition
and
60 deletion
+235
-60
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+39
-23
paddle/fluid/operators/batch_norm_op.cu
paddle/fluid/operators/batch_norm_op.cu
+98
-19
paddle/fluid/operators/instance_norm_op.cu
paddle/fluid/operators/instance_norm_op.cu
+75
-13
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+16
-5
python/paddle/fluid/tests/unittests/test_instance_norm_op.py
python/paddle/fluid/tests/unittests/test_instance_norm_op.py
+7
-0
未找到文件。
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
619a241b
...
...
@@ -496,6 +496,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
int
scale_coefff
=
use_global_stats
?
1
:
N
*
sample_size
;
const
auto
scale_inv_var_nhw
=
scale_arr
*
inv_var_arr
/
scale_coefff
;
Tensor
dy_sum
;
dy_sum
.
Resize
({
C
});
dy_sum
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
EigenVectorArrayMap
<
T
>
dy_sum_arr
(
dy_sum
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
Tensor
dy_mul_x_sub_mean_mul_invstd_sum
;
dy_mul_x_sub_mean_mul_invstd_sum
.
Resize
({
C
});
dy_mul_x_sub_mean_mul_invstd_sum
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
EigenVectorArrayMap
<
T
>
dy_mul_x_sub_mean_mul_invstd_sum_arr
(
dy_mul_x_sub_mean_mul_invstd_sum
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
dy_sum_arr
.
setZero
();
dy_mul_x_sub_mean_mul_invstd_sum_arr
.
setZero
();
switch
(
data_layout
)
{
case
DataLayout
::
kNCHW
:
{
ConstEigenArrayMap
<
T
>
x_arr
(
x
->
data
<
T
>
(),
sample_size
,
N
*
C
);
...
...
@@ -504,23 +519,27 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
sample_size
,
N
*
C
);
d_x_arr
.
setZero
();
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
dy_sum_arr
(
c
)
+=
d_y_arr
.
col
(
nc
).
sum
();
dy_mul_x_sub_mean_mul_invstd_sum_arr
(
c
)
+=
((
x_arr
.
col
(
nc
)
-
mean_arr
(
c
))
*
inv_var_arr
(
c
)
*
d_y_arr
.
col
(
nc
))
.
sum
();
}
if
(
d_scale
&&
d_bias
)
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
d_bias_arr
(
c
)
+=
d_y_arr
.
col
(
nc
).
sum
();
d_scale_arr
(
c
)
+=
((
x_arr
.
col
(
nc
)
-
mean_arr
(
c
))
*
inv_var_arr
(
c
)
*
d_y_arr
.
col
(
nc
))
.
sum
();
}
d_bias_arr
=
dy_sum_arr
;
d_scale_arr
=
dy_mul_x_sub_mean_mul_invstd_sum_arr
;
}
if
(
!
use_global_stats
)
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
int
c
=
nc
%
C
;
d_x_arr
.
col
(
nc
)
+=
scale_inv_var_nhw
(
c
)
*
(
d_y_arr
.
col
(
nc
)
*
N
*
sample_size
-
d
_bias
_arr
(
c
)
-
(
x_arr
.
col
(
nc
)
-
mean_arr
[
c
])
*
d_scale_arr
(
c
)
*
inv_var_arr
(
c
));
(
d_y_arr
.
col
(
nc
)
*
N
*
sample_size
-
d
y_sum
_arr
(
c
)
-
(
x_arr
.
col
(
nc
)
-
mean_arr
[
c
])
*
dy_mul_x_sub_mean_mul_invstd_sum_arr
(
c
)
*
inv_var_arr
(
c
));
}
}
else
{
for
(
int
nc
=
0
;
nc
<
N
*
C
;
++
nc
)
{
...
...
@@ -537,27 +556,24 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
N
*
sample_size
);
d_x_arr
.
setZero
();
const
auto
d_y_row_sum
=
d_y_arr
.
rowwise
().
sum
();
const
auto
x_minus_mean
=
x_arr
.
colwise
()
-
mean_arr
;
const
auto
d_y_mul_x_minus_mean_row_sum
=
(
d_y_arr
*
x_minus_mean
).
rowwise
().
sum
(
);
const
auto
inv_var_sqr
=
inv_var_arr
*
inv_var_arr
;
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
dy_sum_arr
+=
d_y_arr
.
col
(
nhw
)
;
dy_mul_x_sub_mean_mul_invstd_sum_arr
+
=
(
x_arr
.
col
(
nhw
)
-
mean_arr
)
*
inv_var_arr
*
d_y_arr
.
col
(
nhw
);
}
if
(
d_scale
&&
d_bias
)
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
d_bias_arr
+=
d_y_arr
.
col
(
nhw
);
d_scale_arr
+=
(
x_arr
.
col
(
nhw
)
-
mean_arr
)
*
inv_var_arr
*
d_y_arr
.
col
(
nhw
);
}
d_bias_arr
=
dy_sum_arr
;
d_scale_arr
=
dy_mul_x_sub_mean_mul_invstd_sum_arr
;
}
if
(
!
use_global_stats
)
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
d_x_arr
.
col
(
nhw
)
+=
scale_inv_var_nhw
*
(
d_y_arr
.
col
(
nhw
)
*
N
*
sample_size
-
d
_y_row_sum
-
x_minus_mean
.
col
(
nhw
)
*
inv_var_sqr
*
d
_y_mul_x_minus_mean_row_sum
);
(
d_y_arr
.
col
(
nhw
)
*
N
*
sample_size
-
d
y_sum_arr
-
(
x_arr
.
col
(
nhw
)
-
mean_arr
)
*
d
y_mul_x_sub_mean_mul_invstd_sum_arr
*
inv_var_arr
);
}
}
else
{
for
(
int
nhw
=
0
;
nhw
<
N
*
sample_size
;
++
nhw
)
{
...
...
paddle/fluid/operators/batch_norm_op.cu
浏览文件 @
619a241b
...
...
@@ -234,6 +234,63 @@ static __global__ void KeBNBackwardData(const T *dy,
}
}
template
<
typename
T
,
int
BlockDim
,
framework
::
DataLayout
layout
>
static
__global__
void
BNBackwardData
(
const
T
*
dy
,
const
BatchNormParamType
<
T
>
*
scale
,
const
BatchNormParamType
<
T
>
*
mean
,
const
T
*
x
,
const
BatchNormParamType
<
T
>
*
variance
,
const
int
C
,
const
int
N
,
const
int
HxW
,
T
*
dx
)
{
const
int
outer_size
=
C
;
const
int
inner_size
=
N
*
HxW
;
typedef
cub
::
BlockReduce
<
BatchNormParamType
<
T
>
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
dy_storage
;
__shared__
typename
BlockReduce
::
TempStorage
dy_x_sub_mean_storage
;
__shared__
BatchNormParamType
<
T
>
dy_sum_val
;
__shared__
BatchNormParamType
<
T
>
dy_x_sub_mean_sum_val
;
for
(
int
i
=
blockIdx
.
x
;
i
<
outer_size
;
i
+=
gridDim
.
x
)
{
BatchNormParamType
<
T
>
inv_var_i
=
variance
[
i
];
BatchNormParamType
<
T
>
mean_i
=
mean
[
i
];
BatchNormParamType
<
T
>
dy_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
BatchNormParamType
<
T
>
dy_x_sub_mean_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
for
(
int
j
=
threadIdx
.
x
;
j
<
inner_size
;
j
+=
blockDim
.
x
)
{
const
int
index
=
layout
==
framework
::
DataLayout
::
kNCHW
?
(
j
/
HxW
*
C
+
i
)
*
HxW
+
j
%
HxW
:
j
*
outer_size
+
i
;
BatchNormParamType
<
T
>
dy_i
=
static_cast
<
BatchNormParamType
<
T
>>
(
dy
[
index
]);
dy_sum
+=
dy_i
;
dy_x_sub_mean_sum
+=
dy_i
*
(
static_cast
<
BatchNormParamType
<
T
>>
(
x
[
index
])
-
mean_i
);
}
dy_sum
=
BlockReduce
(
dy_storage
).
Reduce
(
dy_sum
,
cub
::
Sum
());
dy_x_sub_mean_sum
=
BlockReduce
(
dy_x_sub_mean_storage
)
.
Reduce
(
dy_x_sub_mean_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
dy_sum_val
=
dy_sum
;
dy_x_sub_mean_sum_val
=
dy_x_sub_mean_sum
;
}
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
inner_size
;
j
+=
blockDim
.
x
)
{
const
int
index
=
layout
==
framework
::
DataLayout
::
kNCHW
?
(
j
/
HxW
*
C
+
i
)
*
HxW
+
j
%
HxW
:
j
*
outer_size
+
i
;
dx
[
index
]
=
(
static_cast
<
BatchNormParamType
<
T
>>
(
dy
[
index
])
-
dy_sum_val
/
static_cast
<
BatchNormParamType
<
T
>>
(
inner_size
)
-
(
static_cast
<
BatchNormParamType
<
T
>>
(
x
[
index
])
-
mean_i
)
*
dy_x_sub_mean_sum_val
*
inv_var_i
*
inv_var_i
/
inner_size
)
*
scale
[
i
]
*
inv_var_i
;
}
}
}
template
<
typename
T
>
class
BatchNormGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -282,6 +339,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
const
int
num
=
x
->
numel
();
const
int
block
=
512
;
int
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int
max_blocks
=
std
::
max
(
max_threads
/
block
,
1
);
int
grid1
=
(
num
+
block
-
1
)
/
block
;
int
grid2
=
std
::
min
(
C
,
max_blocks
);
if
(
!
use_global_stats
)
{
if
((
N
*
H
*
W
*
D
)
==
1
)
{
framework
::
TensorCopy
(
*
d_y
,
ctx
.
GetPlace
(),
d_x
);
...
...
@@ -325,21 +389,43 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
saved_var
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
void
*
saved_mean_data
=
const
auto
*
saved_mean_data
=
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
();
const
void
*
saved_var_data
=
const
auto
*
saved_var_data
=
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
d_y
->
template
data
<
T
>(),
data_desc_
,
d_x
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
if
(
d_scale
&&
d_bias
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
d_y
->
template
data
<
T
>(),
data_desc_
,
d_x
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
}
else
{
if
(
data_layout
==
framework
::
DataLayout
::
kNCHW
)
{
if
(
d_x
)
{
BNBackwardData
<
T
,
block
,
framework
::
DataLayout
::
kNCHW
><<<
grid2
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
d_y
->
data
<
T
>
(),
scale
->
data
<
BatchNormParamType
<
T
>>
(),
saved_mean_data
,
x
->
data
<
T
>
(),
saved_var_data
,
C
,
N
,
H
*
W
*
D
,
d_x
->
data
<
T
>
());
}
}
else
{
if
(
d_x
)
{
BNBackwardData
<
T
,
block
,
framework
::
DataLayout
::
kNCHW
><<<
grid2
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
d_y
->
data
<
T
>
(),
scale
->
data
<
BatchNormParamType
<
T
>>
(),
saved_mean_data
,
x
->
data
<
T
>
(),
saved_var_data
,
C
,
N
,
H
*
W
*
D
,
d_x
->
data
<
T
>
());
}
}
}
// clean when exit.
CUDNN_ENFORCE
(
...
...
@@ -355,13 +441,6 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const
auto
*
running_var_data
=
running_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
const
int
num
=
x
->
numel
();
const
int
block
=
512
;
int
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int
max_blocks
=
std
::
max
(
max_threads
/
block
,
1
);
int
grid1
=
(
num
+
block
-
1
)
/
block
;
int
grid2
=
std
::
min
(
C
,
max_blocks
);
if
(
data_layout
==
framework
::
DataLayout
::
kNCHW
)
{
if
(
d_x
)
{
KeBNBackwardData
<
T
,
framework
::
DataLayout
::
kNCHW
><<<
...
...
paddle/fluid/operators/instance_norm_op.cu
浏览文件 @
619a241b
...
...
@@ -170,6 +170,58 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
}
};
template
<
typename
T
,
int
BlockDim
>
static
__global__
void
GradComputeDX
(
const
T
*
dy
,
const
BatchNormParamType
<
T
>
*
scale
,
const
BatchNormParamType
<
T
>
*
mean
,
const
T
*
x
,
const
BatchNormParamType
<
T
>
*
variance
,
const
int
C
,
const
int
sample_size
,
T
*
dx
)
{
int
beg_idx
=
blockIdx
.
x
*
sample_size
+
threadIdx
.
x
;
int
end_idx
=
(
blockIdx
.
x
+
1
)
*
sample_size
;
int
ncid
=
blockIdx
.
x
;
int
c
=
ncid
%
C
;
BatchNormParamType
<
T
>
mean_val
=
mean
[
ncid
];
BatchNormParamType
<
T
>
inv_var_val
=
variance
[
ncid
];
typedef
cub
::
BlockReduce
<
BatchNormParamType
<
T
>
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
dy_storage
;
__shared__
typename
BlockReduce
::
TempStorage
dy_x_sub_mean_storage
;
__shared__
BatchNormParamType
<
T
>
dy_sum_val
;
__shared__
BatchNormParamType
<
T
>
dy_x_sub_mean_sum_val
;
BatchNormParamType
<
T
>
dy_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
BatchNormParamType
<
T
>
dy_x_sub_mean_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
for
(
int
i
=
beg_idx
;
i
<
end_idx
;
i
+=
BlockDim
)
{
BatchNormParamType
<
T
>
dy_i
=
static_cast
<
BatchNormParamType
<
T
>>
(
dy
[
i
]);
dy_sum
+=
dy_i
;
dy_x_sub_mean_sum
+=
dy_i
*
(
static_cast
<
BatchNormParamType
<
T
>>
(
x
[
i
])
-
mean_val
);
}
dy_sum
=
BlockReduce
(
dy_storage
).
Reduce
(
dy_sum
,
cub
::
Sum
());
dy_x_sub_mean_sum
=
BlockReduce
(
dy_x_sub_mean_storage
).
Reduce
(
dy_x_sub_mean_sum
,
cub
::
Sum
());
if
(
threadIdx
.
x
==
0
)
{
dy_sum_val
=
dy_sum
;
dy_x_sub_mean_sum_val
=
dy_x_sub_mean_sum
;
}
__syncthreads
();
for
(
int
i
=
beg_idx
;
i
<
end_idx
;
i
+=
BlockDim
)
{
dx
[
i
]
=
(
static_cast
<
BatchNormParamType
<
T
>>
(
dy
[
i
])
-
dy_sum_val
/
static_cast
<
BatchNormParamType
<
T
>>
(
sample_size
)
-
(
static_cast
<
BatchNormParamType
<
T
>>
(
x
[
i
])
-
mean_val
)
*
dy_x_sub_mean_sum_val
*
inv_var_val
*
inv_var_val
/
sample_size
)
*
scale
[
c
]
*
inv_var_val
;
}
}
template
<
typename
T
>
class
InstanceNormGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
...
...
@@ -258,21 +310,31 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
saved_var
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
void
*
saved_mean_data
=
const
auto
*
saved_mean_data
=
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
();
const
void
*
saved_var_data
=
const
auto
*
saved_var_data
=
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x_tmp
.
template
data
<
T
>(),
data_desc_
,
d_y_tmp
.
template
data
<
T
>(),
data_desc_
,
d_x
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
in_param_desc_
,
scale_tmp
.
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
if
(
d_scale
&&
d_bias
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x_tmp
.
template
data
<
T
>(),
data_desc_
,
d_y_tmp
.
template
data
<
T
>(),
data_desc_
,
d_x
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
in_param_desc_
,
scale_tmp
.
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
d_bias_tmp
.
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
}
else
{
if
(
d_x
)
{
GradComputeDX
<
T
,
block
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
d_y
->
data
<
T
>
(),
scale
->
data
<
BatchNormParamType
<
T
>>
(),
saved_mean_data
,
x
->
data
<
T
>
(),
saved_var_data
,
C
,
H
*
W
*
D
,
d_x
->
data
<
T
>
());
}
}
if
(
d_scale
&&
d_bias
)
{
add_param
<
T
,
block
,
false
><<<
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
...
...
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
浏览文件 @
619a241b
...
...
@@ -338,14 +338,14 @@ class TestBatchNormOpTraining(unittest.TestCase):
return
y
,
mean_out
,
variance_out
,
saved_mean
,
saved_variance
,
x_grad
,
scale_grad
,
bias_grad
def
set_mean_variance
(
self
,
scale_shape
,
x
,
data_layout
):
mean
=
np
.
zeros
(
scale_shape
).
astype
(
np
.
float32
)
variance
=
np
.
ones
(
scale_shape
).
astype
(
np
.
float32
)
mean
,
variance
=
_cal_mean_variance
(
x
,
self
.
epsilon
,
data_layout
)
mean_pre
=
np
.
zeros
(
scale_shape
).
astype
(
np
.
float32
)
variance_pre
=
np
.
ones
(
scale_shape
).
astype
(
np
.
float32
)
# computing global mean/variance for one step
if
self
.
use_global_stats
:
mom
=
self
.
momentum
x_mean
,
x_var
=
_cal_mean_variance
(
x
,
self
.
epsilon
,
data_layout
)
mean
=
x_mean
*
(
1.
-
mom
)
+
mom
*
mean
variance
=
x_var
*
(
1.
-
mom
)
+
mom
*
variance
mean
=
mean
*
(
1.
-
mom
)
+
mom
*
mean_pre
variance
=
variance
*
(
1.
-
mom
)
+
mom
*
variance_pre
return
mean
,
variance
def
test_forward_backward
(
self
):
...
...
@@ -442,6 +442,10 @@ class TestBatchNormOpTraining(unittest.TestCase):
fetch_list
=
self
.
fetch_list
)
for
id
,
name
in
enumerate
(
self
.
fetch_list
):
if
name
==
'variance'
:
self
.
__assert_close
(
var_dict
[
name
],
out
[
id
],
name
,
atol
=
1e-3
)
continue
self
.
__assert_close
(
var_dict
[
name
],
out
[
id
],
name
)
print
(
"op test forward passed: "
,
str
(
place
),
data_layout
)
...
...
@@ -458,6 +462,13 @@ class TestBatchNormOpTraining(unittest.TestCase):
pass
class
TestBatchNormOpTrainingCase1
(
TestBatchNormOpTraining
):
def
init_test_case
(
self
):
self
.
use_global_stats
=
False
self
.
no_grad_set
=
set
([
'scale@GRAD'
,
'bias@GRAD'
])
self
.
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'x@GRAD'
]
class
TestBatchNormOpFreezeStatsTraining
(
TestBatchNormOpTraining
):
def
init_test_case
(
self
):
self
.
use_global_stats
=
True
...
...
python/paddle/fluid/tests/unittests/test_instance_norm_op.py
浏览文件 @
619a241b
...
...
@@ -184,5 +184,12 @@ class TestInstanceNormOpTraining(unittest.TestCase):
test_with_place
(
place
,
[
2
,
3
,
4
,
5
])
class
TestInstanceNormOpTrainingCase1
(
TestInstanceNormOpTraining
):
def
init_test_case
(
self
):
self
.
use_global_stats
=
False
self
.
no_grad_set
=
set
([
'scale@GRAD'
,
'bias@GRAD'
])
self
.
fetch_list
=
[
'y'
,
'saved_mean'
,
'saved_variance'
,
'x@GRAD'
]
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录