Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8daccc9e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8daccc9e
编写于
9月 25, 2020
作者:
C
ceci3
提交者:
GitHub
9月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix batch norm double grad compute (#27549)
* fix bn double grad, test=develop * update, test=develop
上级
c143326d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
131 addition
and
41 deletion
+131
-41
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+34
-21
paddle/fluid/operators/instance_norm_op.cc
paddle/fluid/operators/instance_norm_op.cc
+3
-3
paddle/fluid/operators/norm_utils.cu.h
paddle/fluid/operators/norm_utils.cu.h
+58
-17
python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
+36
-0
未找到文件。
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
8daccc9e
...
...
@@ -839,6 +839,7 @@ void BatchNormDoubleGradMaker<T>::Apply(GradOpPtr<T> op) const {
op
->
SetInput
(
"SavedMean"
,
this
->
Input
(
"SavedMean"
));
op
->
SetInput
(
"SavedVariance"
,
this
->
Input
(
"SavedVariance"
));
if
(
BOOST_GET_CONST
(
bool
,
this
->
GetAttr
(
"use_global_stats"
)))
{
op
->
SetInput
(
"Mean"
,
this
->
Input
(
"Mean"
));
op
->
SetInput
(
"Variance"
,
this
->
Input
(
"Variance"
));
}
op
->
SetInput
(
"DDX"
,
this
->
OutputGrad
(
framework
::
GradVarName
(
"X"
)));
...
...
@@ -868,14 +869,19 @@ void BatchNormDoubleGradOp::InferShape(
"BatchNormDoubleGrad"
);
}
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"DDX"
),
"Input"
,
"DDX"
,
"BatchNormDoubleGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"DY"
),
"Input"
,
"DY"
,
"BatchNormDoubleGrad"
);
// check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"DX"
),
"Output"
,
"DX"
,
"BatchNormDoubleGrad"
);
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
int
C
=
x_dims
[
1
];
const
DataLayout
data_layout
=
framework
::
StringToDataLayout
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"data_layout"
));
const
int
C
=
((
this
->
IsMKLDNNType
()
==
true
)
||
(
data_layout
==
DataLayout
::
kNCHW
)
?
x_dims
[
1
]
:
x_dims
[
x_dims
.
size
()
-
1
]);
if
(
ctx
->
HasOutput
(
"DX"
))
{
ctx
->
SetOutputDim
(
"DX"
,
x_dims
);
}
...
...
@@ -957,7 +963,9 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
Tensor
inv_var_tensor
;
if
(
use_global_stats
)
{
const
auto
*
running_mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
running_variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
mean_data
=
running_mean
->
data
<
T
>
();
inv_var_tensor
.
Resize
({
C
});
T
*
running_inv_var_data
=
inv_var_tensor
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -1077,12 +1085,12 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
// (np.mean(ddx, axis=(n,h,w)) - ddx)
)
+ ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w)))
)
// axis=(n,h,w)))
if
(
ddX
)
{
dx_arr
+=
...
...
@@ -1176,7 +1184,8 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
C
,
sample_size
);
ddy_arr
.
setZero
();
if
(
use_global_stats
)
{
// math: ddy = r * ddx * inv_var
// math: ddy = r * ddx * inv_var + ddbias +
// ddscale * (x - mean) * inv_var
if
(
ddX
)
{
ddy_arr
=
scale_tile_data
*
ddx_arr
*
inv_var_tile_data
;
}
...
...
@@ -1196,25 +1205,29 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
.
replicate
(
1
,
sample_size
)
/
sample_size
);
}
if
(
ddScale
&&
ddBias
)
{
ConstEigenVectorArrayMap
<
T
>
ddscale_arr
(
ddScale
->
data
<
T
>
(),
C
);
Tensor
ddscale_tile
;
ddscale_tile
.
Resize
({
C
,
sample_size
});
EigenArrayMap
<
T
>
ddscale_tile_data
(
ddscale_tile
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
,
sample_size
);
ddscale_tile_data
=
ddscale_arr
.
replicate
(
1
,
sample_size
);
}
if
(
ddScale
)
{
ConstEigenVectorArrayMap
<
T
>
ddscale_arr
(
ddScale
->
data
<
T
>
(),
C
);
Tensor
ddscale_tile
;
ddscale_tile
.
Resize
({
C
,
sample_size
});
EigenArrayMap
<
T
>
ddscale_tile_data
(
ddscale_tile
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
,
sample_size
);
ddscale_tile_data
=
ddscale_arr
.
replicate
(
1
,
sample_size
);
ddy_arr
+=
x_sub_mean_mul_invstd_arr
*
ddscale_tile_data
;
}
ConstEigenVectorArrayMap
<
T
>
ddbias_arr
(
ddBias
->
data
<
T
>
(),
C
);
Tensor
ddbias_tile
;
ddbias_tile
.
Resize
({
C
,
sample_size
});
EigenArrayMap
<
T
>
ddbias_tile_data
(
ddbias_tile
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
,
sample_size
);
ddbias_tile_data
=
ddbias_arr
.
replicate
(
1
,
sample_size
);
if
(
ddBias
)
{
ConstEigenVectorArrayMap
<
T
>
ddbias_arr
(
ddBias
->
data
<
T
>
(),
C
);
Tensor
ddbias_tile
;
ddbias_tile
.
Resize
({
C
,
sample_size
});
EigenArrayMap
<
T
>
ddbias_tile_data
(
ddbias_tile
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
,
sample_size
);
ddbias_tile_data
=
ddbias_arr
.
replicate
(
1
,
sample_size
);
ddy_arr
+=
x_sub_mean_mul_invstd_arr
*
ddscale_tile_data
;
ddy_arr
+=
ddbias_tile_data
;
}
ddy_arr
+=
ddbias_tile_data
;
}
if
(
data_layout
==
DataLayout
::
kNCHW
)
{
VLOG
(
3
)
<<
"Transform batchnorm output from NHWC to NCHW"
;
TransToChannelFirst
<
paddle
::
platform
::
CPUDeviceContext
,
T
>
(
...
...
paddle/fluid/operators/instance_norm_op.cc
浏览文件 @
8daccc9e
...
...
@@ -520,11 +520,11 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
// (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW *
// np.sum(dy,
// axis=(h,w)) * (x - mean) *
// (np.mean(ddx, axis=(h,w)) - ddx)
+ ddr * (dy * inv_var - inv_var
// *
// (np.mean(ddx, axis=(h,w)) - ddx)
) + ddr * (dy * inv_var -
//
inv_var
*
// np.mean(dy, axis=(h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(h,w)))
)
// axis=(h,w)))
Tensor
x_sub_mean_mul_invstd
;
x_sub_mean_mul_invstd
.
Resize
({
sample_size
,
NxC
});
...
...
paddle/fluid/operators/norm_utils.cu.h
浏览文件 @
8daccc9e
...
...
@@ -40,12 +40,12 @@ using DataLayout = framework::DataLayout;
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
// (np.mean(ddx, axis=(n,h,w)) - ddx)
)
+ ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w)))
)
// axis=(n,h,w)))
template
<
typename
T
,
int
BlockDim
,
framework
::
DataLayout
layout
>
__global__
void
DoubleGradComputeDX
(
const
T
*
x
,
const
T
*
mean
,
...
...
@@ -138,7 +138,7 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
?
(
j
/
sample_size
*
C
+
i
)
*
sample_size
+
j
%
sample_size
:
j
*
outer_size
+
i
;
dx
[
index
]
+=
(
dy
[
index
]
*
var_val
-
dy_sum_val
/
inner_size
*
var_val
-
(
x
[
index
]
-
mean_val
)
*
var_val
*
(
x
[
index
]
-
mean_val
)
*
var_val
*
var_val
*
dy_mul_x_sub_mean_sum_val
*
var_val
/
inner_size
)
*
ddscale
[
i
];
}
...
...
@@ -326,19 +326,57 @@ __global__ void DoubleGradComputeDScaleWithGlobal(
}
// math: dx = ddscale * dy * inv_var
// math: ddy = scale * ddx * inv_var
template
<
typename
T
,
framework
::
DataLayout
layout
>
__global__
void
DoubleGradComputeDataWithGlobal
(
const
T
*
dy
,
const
T
*
scale
,
const
T
*
variance
,
const
double
epsilon
,
const
int
C
,
const
int
sample_size
,
const
int
num
,
T
*
dx
)
{
__global__
void
DoubleGradComputeDXWithGlobal
(
const
T
*
dy
,
const
T
*
ddscale
,
const
T
*
variance
,
const
double
epsilon
,
const
int
C
,
const
int
sample_size
,
const
int
num
,
T
*
dx
)
{
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
if
(
scale
!=
nullptr
)
{
if
(
dd
scale
!=
nullptr
)
{
for
(
int
i
=
gid
;
i
<
num
;
i
+=
stride
)
{
const
int
c
=
layout
==
framework
::
DataLayout
::
kNCHW
?
i
/
sample_size
%
C
:
i
%
C
;
T
inv_var
=
1.0
/
sqrt
(
variance
[
c
]
+
epsilon
);
dx
[
i
]
=
dy
[
i
]
*
scale
[
c
]
*
inv_var
;
dx
[
i
]
=
dy
[
i
]
*
ddscale
[
c
]
*
inv_var
;
}
}
}
// math: ddy = scale * ddx * inv_var + ddbias +
// ddscale * (x - mean) * inv_var
template
<
typename
T
,
framework
::
DataLayout
layout
>
__global__
void
DoubleGradComputeDDYWithGlobal
(
const
T
*
ddx
,
const
T
*
scale
,
const
T
*
mean
,
const
T
*
variance
,
const
T
*
x
,
const
T
*
ddbias
,
const
T
*
ddscale
,
const
double
epsilon
,
const
int
C
,
const
int
sample_size
,
const
int
num
,
T
*
ddy
)
{
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
if
(
ddx
!=
nullptr
)
{
for
(
int
i
=
gid
;
i
<
num
;
i
+=
stride
)
{
const
int
c
=
layout
==
framework
::
DataLayout
::
kNCHW
?
i
/
sample_size
%
C
:
i
%
C
;
T
inv_var
=
1.0
/
sqrt
(
variance
[
c
]
+
epsilon
);
ddy
[
i
]
+=
ddx
[
i
]
*
scale
[
c
]
*
inv_var
;
}
}
__syncthreads
();
if
(
ddscale
!=
nullptr
)
{
for
(
int
i
=
gid
;
i
<
num
;
i
+=
stride
)
{
const
int
c
=
layout
==
framework
::
DataLayout
::
kNCHW
?
i
/
sample_size
%
C
:
i
%
C
;
T
inv_var
=
1.0
/
sqrt
(
variance
[
c
]
+
epsilon
);
ddy
[
i
]
+=
(
x
[
i
]
-
mean
[
c
])
*
inv_var
*
ddscale
[
c
];
}
}
__syncthreads
();
if
(
ddbias
!=
nullptr
)
{
for
(
int
i
=
gid
;
i
<
num
;
i
+=
stride
)
{
const
int
c
=
layout
==
framework
::
DataLayout
::
kNCHW
?
i
/
sample_size
%
C
:
i
%
C
;
ddy
[
i
]
+=
ddbias
[
c
];
}
}
}
...
...
@@ -383,8 +421,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
const
T
*
mean_data
,
*
variance_data
;
if
(
use_global_stats
)
{
const
auto
*
running_mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
running_var
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
const
auto
*
running_mean_data
=
running_mean
->
template
data
<
T
>();
const
auto
*
running_var_data
=
running_var
->
template
data
<
T
>();
mean_data
=
running_mean_data
;
variance_data
=
running_var_data
;
}
else
{
const
T
*
smean_data
=
Saved_mean
->
data
<
T
>
();
...
...
@@ -398,12 +439,12 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant
(
dev_ctx
,
dX
,
static_cast
<
T
>
(
0
));
if
(
use_global_stats
)
{
if
(
data_layout
==
DataLayout
::
kNHWC
)
{
DoubleGradComputeD
ata
WithGlobal
<
DoubleGradComputeD
X
WithGlobal
<
T
,
DataLayout
::
kNHWC
><<<
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
dy_data
,
ddscale_data
,
variance_data
,
epsilon
,
C
,
sample_size
,
num
,
dx_data
);
}
else
{
DoubleGradComputeD
ata
WithGlobal
<
DoubleGradComputeD
X
WithGlobal
<
T
,
DataLayout
::
kNCHW
><<<
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
dy_data
,
ddscale_data
,
variance_data
,
epsilon
,
C
,
sample_size
,
num
,
dx_data
);
...
...
@@ -456,15 +497,15 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant
(
dev_ctx
,
ddY
,
static_cast
<
T
>
(
0
));
if
(
use_global_stats
)
{
if
(
data_layout
==
DataLayout
::
kNHWC
)
{
DoubleGradComputeD
ata
WithGlobal
<
DoubleGradComputeD
DY
WithGlobal
<
T
,
DataLayout
::
kNHWC
><<<
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
ddx_data
,
scale_data
,
variance_data
,
epsilon
,
C
,
sample_size
,
num
,
ddy_data
);
ddx_data
,
scale_data
,
mean_data
,
variance_data
,
x_data
,
ddbias_data
,
dd
scale_data
,
epsilon
,
C
,
sample_size
,
num
,
dd
y_data
);
}
else
{
DoubleGradComputeD
ata
WithGlobal
<
DoubleGradComputeD
DY
WithGlobal
<
T
,
DataLayout
::
kNCHW
><<<
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
ddx_data
,
scale_data
,
variance_data
,
epsilon
,
C
,
sample_size
,
num
,
ddy_data
);
ddx_data
,
scale_data
,
mean_data
,
variance_data
,
x_data
,
ddbias_data
,
dd
scale_data
,
epsilon
,
C
,
sample_size
,
num
,
dd
y_data
);
}
}
else
{
if
(
data_layout
==
DataLayout
::
kNHWC
)
{
...
...
python/paddle/fluid/tests/unittests/test_norm_nn_grad.py
浏览文件 @
8daccc9e
...
...
@@ -130,5 +130,41 @@ class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck):
self
.
shape
=
[
2
,
2
,
3
,
4
,
5
]
class
TestBatchNormDoubleGradCheckCase5
(
TestBatchNormDoubleGradCheck
):
@
prog_scope
()
def
func
(
self
,
place
):
prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
prog
):
np
.
random
.
seed
()
dtype
=
"float32"
eps
=
0.005
atol
=
2e-4
chn
=
self
.
shape
[
1
]
if
self
.
data_layout
==
'NCHW'
else
self
.
shape
[
-
1
]
x
=
layers
.
create_parameter
(
dtype
=
dtype
,
shape
=
self
.
shape
,
name
=
'x'
)
z
=
fluid
.
layers
.
batch_norm
(
input
=
x
,
data_layout
=
self
.
data_layout
,
use_global_stats
=
self
.
use_global_stats
)
x_arr
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
dtype
)
w
,
b
=
prog
.
global_block
().
all_parameters
()[
1
:
3
]
w_arr
=
np
.
ones
(
chn
).
astype
(
dtype
)
b_arr
=
np
.
zeros
(
chn
).
astype
(
dtype
)
gradient_checker
.
double_grad_check
(
[
x
,
w
,
b
],
z
,
x_init
=
[
x_arr
,
w_arr
,
b_arr
],
atol
=
atol
,
place
=
place
,
eps
=
eps
)
class
TestBatchNormDoubleGradCheckCase6
(
TestBatchNormDoubleGradCheckCase5
):
def
init_test
(
self
):
self
.
data_layout
=
'NCHW'
self
.
use_global_stats
=
True
self
.
shape
=
[
2
,
3
,
4
,
5
]
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录