Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9774f965
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9774f965
编写于
4月 21, 2022
作者:
Z
Zhangjingyu06
提交者:
GitHub
4月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify batch_norm and batch_norm_grad. *test=kunlun (#41976)
上级
c3b0b680
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
101 addition
and
75 deletion
+101
-75
paddle/fluid/operators/batch_norm_op_xpu.cc
paddle/fluid/operators/batch_norm_op_xpu.cc
+101
-75
未找到文件。
paddle/fluid/operators/batch_norm_op_xpu.cc
浏览文件 @
9774f965
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
...
...
@@ -38,15 +37,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
bool
global_stats
=
test_mode
||
use_global_stats
;
const
auto
&
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
const
auto
data_layout
=
framework
::
StringToDataLayout
(
data_layout_str
);
PADDLE_ENFORCE_EQ
(
data_layout_str
==
"NCHW"
||
data_layout_str
==
"NHWC"
,
true
,
platform
::
errors
::
InvalidArgument
(
"The 'data_layout' attribute must be NCHW or NHWC. "
"But recevived 'data_layout' is [%s]."
,
data_layout_str
));
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
&
x_dims
=
x
->
dims
();
int
temp
=
x_dims
[
3
];
temp
=
(
x_dims
.
size
()
!=
4
)
?
1
:
temp
;
bool
is_nchw
=
(
data_layout
==
DataLayout
::
kNCHW
);
const
int
N
=
x_dims
[
0
];
const
int
C
=
is_nchw
?
x_dims
[
1
]
:
temp
;
const
int
H
=
is_nchw
?
x_dims
[
2
]
:
x_dims
[
1
];
const
int
W
=
is_nchw
?
temp
:
x_dims
[
2
];
PADDLE_ENFORCE_EQ
(
x_dims
.
size
()
>=
2
&&
x_dims
.
size
()
<=
5
,
true
,
platform
::
errors
::
InvalidArgument
(
"The size of input's dimensions should be between 2 and 5"
"But received: the size of input's dimensions is [%d]"
,
x_dims
.
size
()));
int
N
,
C
,
H
,
W
,
D
;
ExtractNCWHD
(
x_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
x_data
=
x
->
data
<
T
>
();
...
...
@@ -67,6 +76,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
saved_variance
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
bool
is_nchw
=
data_layout_str
==
"NCHW"
;
if
(
!
global_stats
)
{
auto
*
mean_out_data
=
mean_out
->
data
<
float
>
();
...
...
@@ -83,35 +93,29 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
&
mom_cpu
);
momentum
=
mom_tensor
->
data
<
float
>
()[
0
];
}
if
(
C
==
1
)
{
int
r
=
xpu
::
batch_norm
<
T
>
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
N
,
1
,
H
,
W
,
epsilon
,
momentum
,
scale_data
,
bias_data
,
saved_mean_data
,
saved_variance_data
,
mean_out_data
,
variance_out_data
,
true
);
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
platform
::
errors
::
External
(
"The batch_norm XPU API return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
else
{
int
r
=
xpu
::
batch_norm
<
T
>
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
N
,
C
,
H
,
W
,
epsilon
,
momentum
,
scale_data
,
bias_data
,
saved_mean_data
,
saved_variance_data
,
mean_out_data
,
variance_out_data
,
is_nchw
);
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
platform
::
errors
::
External
(
"The batch_norm XPU API return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
int
r
=
xpu
::
batch_norm
<
T
>
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
N
,
C
,
H
,
W
,
epsilon
,
momentum
,
scale_data
,
bias_data
,
saved_mean_data
,
saved_variance_data
,
mean_out_data
,
variance_out_data
,
is_nchw
);
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
platform
::
errors
::
External
(
"The batch_norm XPU API return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
else
{
PADDLE_ENFORCE_EQ
(
data_layout_str
==
"NCHW"
,
true
,
platform
::
errors
::
InvalidArgument
(
"The batch_norm_infer 'data_layout' attribute must be NCHW. "
"But recevived 'data_layout' is [%s]."
,
data_layout_str
));
const
auto
*
mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
const
auto
*
mean_data
=
mean
->
data
<
float
>
();
const
auto
*
variance_data
=
variance
->
data
<
float
>
();
int
r
=
xpu
::
batch_norm_infer
(
dev_ctx
.
x_context
(),
x_data
,
y_data
,
N
,
C
,
H
,
W
,
epsilon
,
scale_data
,
bias_data
,
mean_data
,
variance_data
,
true
);
mean_data
,
variance_data
,
is_nchw
);
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
platform
::
errors
::
External
(
...
...
@@ -172,6 +176,13 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
auto
data_layout
=
framework
::
StringToDataLayout
(
data_layout_str
);
PADDLE_ENFORCE_EQ
(
data_layout_str
==
"NCHW"
||
data_layout_str
==
"NHWC"
,
true
,
platform
::
errors
::
InvalidArgument
(
"The 'data_layout' attribute must be NCHW or NHWC. "
"But recevived 'data_layout' is [%s]."
,
data_layout_str
));
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
d_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
...
...
@@ -204,13 +215,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
}
const
auto
&
x_dims
=
x
->
dims
();
int
temp
=
x_dims
[
3
];
temp
=
(
x_dims
.
size
()
!=
4
)
?
1
:
temp
;
bool
is_nchw
=
(
data_layout
==
DataLayout
::
kNCHW
);
const
int
N
=
x_dims
[
0
];
const
int
C
=
is_nchw
?
x_dims
[
1
]
:
temp
;
const
int
H
=
is_nchw
?
x_dims
[
2
]
:
x_dims
[
1
];
const
int
W
=
is_nchw
?
temp
:
x_dims
[
2
];
PADDLE_ENFORCE_EQ
(
x_dims
.
size
()
>=
2
&&
x_dims
.
size
()
<=
5
,
true
,
platform
::
errors
::
InvalidArgument
(
"The size of input's dimensions should be between 2 and 5"
"But received: the size of input's dimensions is [%d]"
,
x_dims
.
size
()));
int
N
,
C
,
H
,
W
,
D
;
ExtractNCWHD
(
x_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
const
auto
*
x_data
=
x
->
data
<
T
>
();
const
auto
*
d_y_data
=
d_y
->
data
<
T
>
();
...
...
@@ -235,42 +248,45 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
"the size of scale's dimensions is [%d], the dimensions of scale "
"is [%s]."
,
scale
->
dims
().
size
(),
scale
->
dims
()));
PADDLE_ENFORCE_EQ
(
scale
->
dims
()[
0
],
C
,
platform
::
errors
::
InvalidArgument
(
"The first dimension of scale must equal to Channels[%d]. But "
"received: the first dimension of scale is [%d]"
,
C
,
scale
->
dims
()[
0
]));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
const
T
*
mean_data
=
nullptr
;
const
T
*
inv_var_data
=
nullptr
;
const
auto
*
batch_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
batch_inv_std
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
global_mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
global_var
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
// TODO(guozibin): hadle the situation case of N * H * W = 1
if
(
!
use_global_stats
)
{
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
// SavedVariance have been reverted in forward operator
const
auto
*
saved_inv_variance
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
mean_data
=
saved_mean
->
data
<
float
>
();
inv_var_data
=
saved_inv_variance
->
data
<
float
>
();
}
else
{
const
auto
*
running_mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
running_variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
mean_data
=
running_mean
->
data
<
float
>
();
inv_var_data
=
running_variance
->
data
<
float
>
();
float
*
running_inv_var_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
running_variance
->
numel
());
float
*
epsilon_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
1
);
int
r1
=
calculate_inv_var
(
dev_ctx
.
x_context
(),
inv_var_data
,
epsilon
,
C
,
epsilon_data
,
running_inv_var_data
);
PADDLE_ENFORCE_EQ
(
r1
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(batch_norm_grad "
"calculate_inv_var function) "
"return wrong value[%d %s]"
,
r1
,
XPUAPIErrorMsg
[
r1
]));
inv_var_data
=
running_inv_var_data
;
}
if
(
is_inplace
)
{
float
*
global_inv_std_data
=
nullptr
;
if
(
use_global_stats
)
{
global_inv_std_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
global_var
->
numel
());
float
*
epsilon_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
1
);
int
r1
=
calculate_inv_var
(
dev_ctx
.
x_context
(),
global_var
->
data
<
float
>
(),
epsilon
,
C
,
epsilon_data
,
global_inv_std_data
);
PADDLE_ENFORCE_EQ
(
r1
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(batch_norm_grad "
"calculate_inv_var function) "
"return wrong value[%d %s]"
,
r1
,
XPUAPIErrorMsg
[
r1
]));
}
auto
px
=
*
x
;
auto
*
inv_std_data
=
use_global_stats
?
global_inv_std_data
:
batch_inv_std
->
data
<
float
>
();
auto
mean_data
=
use_global_stats
?
global_mean
->
data
<
float
>
()
:
batch_mean
->
data
<
float
>
();
int
r2
=
calculate_inv_BN_Y
(
dev_ctx
.
x_context
(),
px
.
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
scale
->
data
<
float
>
(),
bias
->
data
<
float
>
(),
mean_data
,
inv_
var
_data
,
N
,
scale
->
data
<
float
>
(),
bias
->
data
<
float
>
(),
mean_data
,
inv_
std
_data
,
N
,
C
,
H
*
W
,
x
->
data
<
T
>
());
PADDLE_ENFORCE_EQ
(
r2
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(batch_norm_grad "
...
...
@@ -278,19 +294,29 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
"return wrong value[%d %s]"
,
r2
,
XPUAPIErrorMsg
[
r2
]));
}
if
(
!
d_x
)
{
d_x_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
T
>
(
x
->
numel
());
}
if
(
!
d_scale
)
{
d_scale_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
C
);
}
if
(
!
d_bias_data
)
{
d_bias_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
C
);
}
int
r3
=
xpu
::
batch_norm_grad
<
T
>
(
dev_ctx
.
x_context
(),
x_data
,
d_y_data
,
d_x_data
,
N
,
C
,
H
,
W
,
scale_data
,
mean_data
,
inv_var_data
,
d_scale_data
,
d_bias_data
,
is_nchw
);
int
r3
;
bool
is_nchw
=
data_layout_str
==
"NCHW"
;
if
(
use_global_stats
)
{
r3
=
xpu
::
batch_norm_grad
<
T
>
(
dev_ctx
.
x_context
(),
x_data
,
d_y_data
,
d_x_data
,
N
,
C
,
H
,
W
,
scale_data
,
nullptr
,
nullptr
,
d_scale_data
,
d_bias_data
,
is_nchw
,
global_mean
->
data
<
float
>
(),
global_var
->
data
<
float
>
(),
epsilon
);
}
else
{
if
(
!
d_x
)
{
d_x_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
T
>
(
x
->
numel
());
}
if
(
!
d_scale
)
{
d_scale_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
C
);
}
if
(
!
d_bias_data
)
{
d_bias_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
C
);
}
r3
=
xpu
::
batch_norm_grad
<
T
>
(
dev_ctx
.
x_context
(),
x_data
,
d_y_data
,
d_x_data
,
N
,
C
,
H
,
W
,
scale_data
,
batch_mean
->
data
<
float
>
(),
batch_inv_std
->
data
<
float
>
(),
d_scale_data
,
d_bias_data
,
is_nchw
);
}
PADDLE_ENFORCE_EQ
(
r3
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(batch_norm_grad) return "
"wrong value[%d %s]"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录