Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7d564356
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7d564356
编写于
6月 10, 2018
作者:
M
mozga-intel
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN layout: Support for batch norm operator
上级
b7c683b8
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
193 addition
and
157 deletion
+193
-157
paddle/fluid/operators/batch_norm_mkldnn_op.cc
paddle/fluid/operators/batch_norm_mkldnn_op.cc
+180
-146
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+13
-11
未找到文件。
paddle/fluid/operators/batch_norm_mkldnn_op.cc
浏览文件 @
7d564356
...
@@ -19,10 +19,17 @@ limitations under the License. */
...
@@ -19,10 +19,17 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
batch_norm_bwd
=
mkldnn
::
batch_normalization_backward
;
using
batch_norm_fwd
=
mkldnn
::
batch_normalization_forward
;
using
framework
::
DataLayout
;
using
framework
::
Tensor
;
using
mkldnn
::
memory
;
using
mkldnn
::
primitive
;
using
mkldnn
::
reorder
;
using
mkldnn
::
stream
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNMemDesc
;
using
paddle
::
platform
::
MKLDNNMemDesc
;
using
mkldnn
::
memory
;
using
platform
::
to_void_cast
;
template
<
typename
T
>
template
<
typename
T
>
using
EigenArrayMap
=
using
EigenArrayMap
=
...
@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) {
...
@@ -64,21 +71,12 @@ void run_batch_norm_op(Args &&... args) {
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
template
<
typename
T
>
inline
void
*
cast_const_to_void
(
const
T
*
t
)
{
return
static_cast
<
void
*>
(
const_cast
<
T
*>
(
t
));
}
}
// namespace
}
// namespace
template
<
typename
T
>
template
<
typename
T
>
class
BatchNormMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
BatchNormMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
auto
data_layout
=
framework
::
StringToDataLayout
(
data_layout_str
);
PADDLE_ENFORCE
(
data_layout
==
framework
::
DataLayout
::
kNCHW
,
"MKLDNN batch normalization handles only NCHW data layout"
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
float
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
const
float
momentum
=
ctx
.
Attr
<
float
>
(
"momentum"
);
const
float
momentum
=
ctx
.
Attr
<
float
>
(
"momentum"
);
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
...
@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -99,41 +97,53 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE
(
x
->
layout
()
==
DataLayout
::
kMKLDNN
&&
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
x
->
format
()
!=
memory
::
format
::
format_undef
,
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
"Wrong layout/format set for Input x tensor"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
mean_data
=
mean
->
data
<
T
>
();
const
T
*
variance_data
=
variance
->
data
<
T
>
();
T
*
y_data
=
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
mean_out_data
=
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
variance_out_data
=
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batch_mean_data
=
nullptr
;
T
*
batch_variance_data
=
nullptr
;
if
(
!
is_test
)
{
if
(
!
is_test
)
{
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_mean
_data
=
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_variance
_data
=
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
}
auto
propagation
=
is_test
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
auto
propagation
=
is_test
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
:
mkldnn
::
prop_kind
::
forward_training
;
:
mkldnn
::
prop_kind
::
forward_training
;
auto
dims
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
auto
src_tz
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
auto
scale_tz
=
paddle
::
framework
::
vectorize2int
(
scale
->
dims
());
auto
src_md
=
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
const
unsigned
int
ic
=
scale_tz
[
0
];
auto
dst_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
src_pd
=
mkldnn
::
memory
::
primitive_desc
{
src_md
,
mkldnn_engine
};
auto
dst_pd
=
mkldnn
::
memory
::
primitive_desc
{
dst_md
,
mkldnn_engine
};
auto
src
=
mkldnn
::
memory
{
src_pd
,
cast_const_to_void
(
x
->
data
<
T
>
())};
auto
dst
=
mkldnn
::
memory
{
dst_pd
,
y
->
data
<
T
>
()};
unsigned
flags
=
mkldnn
::
use_scale_shift
;
unsigned
flags
=
mkldnn
::
use_scale_shift
;
if
(
is_test
)
flags
|=
mkldnn
::
use_global_stats
;
if
(
is_test
)
flags
|=
mkldnn
::
use_global_stats
;
// create mkldnn memory from input x tensor
auto
src_memory
=
memory
({{{
src_tz
},
memory
::
data_type
::
f32
,
x
->
format
()},
mkldnn_engine
},
to_void_cast
(
x_data
));
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
batch_norm_fwd_desc
=
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
bn_fwd_types
::
op_desc
{
propagation
,
src_md
,
epsilon
,
flags
};
propagation
,
src_memory
.
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
auto
batch_norm_fwd_pd
=
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_fwd_pd
=
bn_fwd_types
::
op_prim
{
batch_norm_fwd_desc
,
mkldnn_engine
};
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
(
new
batch_norm_fwd
::
primitive_desc
(
batch_norm_fwd_desc
,
mkldnn_engine
));
const
unsigned
int
ic
=
dims
[
1
];
// Save the pd to be used in backward pass
const
std
::
string
key
=
ctx
.
op
().
Output
(
"SavedMean"
);
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
dev_ctx
.
SetBlob
(
key_batch_norm_fwd_pd
,
batch_norm_fwd_pd
);
// MKLDNN requires a single piece of memory for scale and shift/bias data
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
const
size_t
scaleshift_size
=
2
*
ic
;
...
@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -143,73 +153,58 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
auto
scaleshift_memory
=
mkldnn
::
memory
{
// crate mkldnn memory for weights(scale/shift)
batch_norm_fwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
()};
auto
scaleshift_memory
=
memory
(
batch_norm_fwd_pd
->
weights_primitive_desc
(),
scaleshift_data
.
data
());
if
(
is_test
)
{
// create mkldnn memory for output y tensor
auto
mean_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
mean_primitive_desc
(),
auto
dst_memory
=
memory
(
batch_norm_fwd_pd
->
dst_primitive_desc
(),
y_data
);
cast_const_to_void
(
mean
->
data
<
T
>
())};
if
(
is_test
)
{
// create mkldnn memory for stats (as input)
auto
mean_memory
=
memory
(
batch_norm_fwd_pd
->
mean_primitive_desc
(),
to_void_cast
(
mean_data
));
auto
variance_memory
=
auto
variance_memory
=
m
kldnn
::
memory
{
batch_norm_fwd_pd
.
variance_primitive_desc
(),
m
emory
(
batch_norm_fwd_pd
->
variance_primitive_desc
(),
cast_const_to_void
(
variance
->
data
<
T
>
())}
;
to_void_cast
(
variance_data
))
;
run_batch_norm_op
<
typename
bn_fwd_types
::
op_type
>
(
run_batch_norm_op
<
typename
bn_fwd_types
::
op_type
>
(
batch_norm_fwd_pd
,
src
,
(
const
mkldnn
::
primitive
::
at
&
)
mean_memory
,
*
batch_norm_fwd_pd
,
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
variance_memory
,
scaleshift_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
variance_memory
,
scaleshift_memory
,
dst
);
dst
_memory
);
}
else
{
}
else
{
// create mkldnn memory for stats (as output)
auto
mean_memory
=
auto
mean_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
mean_primitive_desc
(),
memory
(
batch_norm_fwd_pd
->
mean_primitive_desc
(),
batch_mean_data
);
cast_const_to_void
(
batch_mean
->
data
<
T
>
())};
auto
variance_memory
=
memory
(
batch_norm_fwd_pd
->
variance_primitive_desc
(),
batch_variance_data
);
auto
variance_memory
=
mkldnn
::
memory
{
batch_norm_fwd_pd
.
variance_primitive_desc
(),
cast_const_to_void
(
batch_variance
->
data
<
T
>
())};
run_batch_norm_op
<
bn_fwd_types
::
op_type
>
(
batch_norm_fwd_pd
,
src
,
run_batch_norm_op
<
bn_fwd_types
::
op_type
>
(
*
batch_norm_fwd_pd
,
src_memory
,
scaleshift_memory
,
dst
,
scaleshift_memory
,
dst
_memory
,
mean_memory
,
variance_memory
);
mean_memory
,
variance_memory
);
}
}
if
(
!
is_test
)
{
if
(
!
is_test
)
{
const
unsigned
int
in
=
dims
[
0
];
// mkldnn only compute stats for current batch
const
unsigned
int
sample_size
=
x
->
numel
()
/
in
/
ic
;
// so we need compute momentum stats via Eigen lib
EigenVectorArrayMap
<
T
>
batch_mean_e
(
batch_mean_data
,
ic
);
// saved_xx is use just in this batch of data
EigenVectorArrayMap
<
T
>
batch_variance_e
(
batch_variance_data
,
ic
);
EigenVectorArrayMap
<
T
>
saved_mean_e
(
ConstEigenVectorArrayMap
<
T
>
mean_e
(
mean_data
,
ic
);
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
ConstEigenVectorArrayMap
<
T
>
variance_e
{
variance_data
,
ic
};
EigenVectorArrayMap
<
T
>
saved_variance_e
(
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
saved_mean_e
.
setZero
();
saved_variance_e
.
setZero
();
const
unsigned
int
x_arr_size
=
in
*
ic
;
ConstEigenArrayMap
<
T
>
x_arr
(
x
->
data
<
T
>
(),
sample_size
,
x_arr_size
);
for
(
unsigned
int
nc
=
0
;
nc
<
x_arr_size
;
++
nc
)
{
saved_mean_e
(
nc
%
ic
)
+=
x_arr
.
col
(
nc
).
sum
();
}
saved_mean_e
/=
in
*
sample_size
;
for
(
unsigned
int
nc
=
0
;
nc
<
x_arr_size
;
++
nc
)
{
saved_variance_e
(
nc
%
ic
)
+=
(
x_arr
.
col
(
nc
)
-
saved_mean_e
(
nc
%
ic
)).
matrix
().
squaredNorm
();
}
saved_variance_e
/=
in
*
sample_size
;
ConstEigenVectorArrayMap
<
T
>
mean_arr
{
mean
->
data
<
T
>
(),
ic
};
ConstEigenVectorArrayMap
<
T
>
variance_arr
{
variance
->
data
<
T
>
(),
ic
};
EigenVectorArrayMap
<
T
>
running_mean_arr
(
EigenVectorArrayMap
<
T
>
running_mean_e
(
mean_out_data
,
ic
);
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
EigenVectorArrayMap
<
T
>
running_variance_e
(
variance_out_data
,
ic
);
EigenVectorArrayMap
<
T
>
running_var_arr
(
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ic
);
auto
one_minus_momentum
=
1.
-
momentum
;
auto
one_minus_momentum
=
1.
-
momentum
;
running_mean_arr
=
running_mean_e
=
mean_e
*
momentum
+
batch_mean_e
*
one_minus_momentum
;
mean_arr
*
momentum
+
saved_mean_e
*
one_minus_momentum
;
running_variance_e
=
running_var_arr
=
variance_e
*
momentum
+
batch_variance_e
*
one_minus_momentum
;
variance_arr
*
momentum
+
saved_variance_e
*
one_minus_momentum
;
}
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
(
memory
::
format
)
dst_memory
.
get_primitive_desc
().
desc
().
data
.
format
);
}
}
};
};
...
@@ -217,11 +212,6 @@ template <typename T>
...
@@ -217,11 +212,6 @@ template <typename T>
class
BatchNormMKLDNNGradOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
BatchNormMKLDNNGradOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_layout_str
=
ctx
.
Attr
<
std
::
string
>
(
"data_layout"
);
auto
data_layout
=
framework
::
StringToDataLayout
(
data_layout_str
);
PADDLE_ENFORCE
(
data_layout
==
framework
::
DataLayout
::
kNCHW
,
"MKLDNN batch normalization handles only NCHW data layout"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
...
@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -238,88 +228,132 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto
*
diff_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
diff_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
diff_shift
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
auto
*
diff_shift
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
diff_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE
(
diff_y
->
layout
()
==
DataLayout
::
kMKLDNN
&&
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
diff_y
->
format
()
!=
memory
::
format
::
format_undef
,
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
"Wrong layout/format set for Input diff_y tensor"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
diff_y_data
=
diff_y
->
data
<
T
>
();
const
T
*
batch_mean_data
=
batch_mean
->
data
<
T
>
();
const
T
*
batch_variance_data
=
batch_variance
->
data
<
T
>
();
const
T
*
scale_data
=
scale
->
data
<
T
>
();
const
T
*
shift_data
=
shift
->
data
<
T
>
();
T
*
diff_x_data
=
diff_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_scale_data
=
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_shift_data
=
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
src_tz
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
auto
diff_src_tz
=
src_tz
;
auto
dst_tz
=
src_tz
;
auto
diff_dst_tz
=
dst_tz
;
auto
scale_tz
=
paddle
::
framework
::
vectorize2int
(
scale
->
dims
());
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
const
unsigned
int
ic
=
scale_tz
[
0
];
// Retrieve bn_fwd_pd from device context
const
std
::
string
key
=
ctx
.
op
().
Input
(
"SavedMean"
);
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
auto
batch_norm_fwd_pd
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_batch_norm_fwd_pd
));
PADDLE_ENFORCE
(
batch_norm_fwd_pd
!=
nullptr
,
"Fail to find batch_norm_fwd_pd in device context"
);
auto
dims
=
paddle
::
framework
::
vectorize2int
(
x
->
dims
());
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
unsigned
flags
=
mkldnn
::
use_scale_shift
|
!
mkldnn
::
use_global_stats
;
auto
src_md
=
// create mkldnn memory from input diff_y tensor
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
user_diff_dst_memory
=
auto
dst_md
=
memory
({{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
diff_y
->
format
()},
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
mkldnn_engine
},
auto
diff_src_md
=
to_void_cast
(
diff_y_data
));
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
auto
diff_dst_md
=
MKLDNNMemDesc
(
dims
,
memory
::
data_type
::
f32
,
memory
::
format
::
nchw
);
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
// create mkldnn memory from input x tensor
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
src_memory
=
memory
({{{
src_tz
},
memory
::
data_type
::
f32
,
x
->
format
()},
mkldnn_engine
},
to_void_cast
(
x_data
));
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
// for diff_dst, try to use same format as dst in forward pass
mkldnn
::
prop_kind
::
forward_training
,
src_md
,
epsilon
,
flags
};
auto
diff_dst_pd
=
batch_norm_fwd_pd
.
get
()
->
dst_primitive_desc
();
auto
batch_norm_fwd_pd
=
auto
diff_dst_md
=
diff_dst_pd
.
desc
();
bn_fwd_types
::
op_prim
{
batch_norm_fwd_desc
,
mkldnn_engine
};
// create primitive descriptor for batch norm backward
unsigned
flags
=
mkldnn
::
use_scale_shift
;
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
dst_md
,
epsilon
,
flags
};
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_memory
.
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
batch_norm_bwd_desc
,
mkldnn_engine
,
batch_norm_fwd_pd
};
batch_norm_bwd_desc
,
mkldnn_engine
,
*
batch_norm_fwd_pd
};
auto
src
=
mkldnn
::
memory
{{
src_md
,
mkldnn_engine
},
// reorder user_diff_dst if it's not in preferred format
cast_const_to_void
(
x
->
data
<
T
>
())};
auto
diff_dst_memory
=
user_diff_dst_memory
;
primitive
reorder_diff_dst
;
auto
mean
=
mkldnn
::
memory
{
batch_norm_bwd_pd
.
mean_primitive_desc
(),
bool
is_diff_dst_reordered
=
false
;
cast_const_to_void
(
batch_mean
->
data
<
T
>
())};
if
(
diff_dst_pd
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
memory
(
diff_dst_pd
);
auto
variance
=
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
diff_dst_memory
);
mkldnn
::
memory
{
batch_norm_bwd_pd
.
variance_primitive_desc
(),
is_diff_dst_reordered
=
true
;
cast_const_to_void
(
batch_variance
->
data
<
T
>
())};
}
auto
diff_dst
=
mkldnn
::
memory
{{
diff_dst_md
,
mkldnn_engine
},
cast_const_to_void
(
diff_y
->
data
<
T
>
())};
const
unsigned
int
ic
=
dims
[
1
];
// create mkldnn memory for input tensors (src/mean/variance)
auto
mean_memory
=
memory
(
batch_norm_bwd_pd
.
mean_primitive_desc
(),
to_void_cast
(
batch_mean_data
));
auto
variance_memory
=
memory
(
batch_norm_bwd_pd
.
variance_primitive_desc
(),
to_void_cast
(
batch_variance_data
));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
()
,
copy_to_weights
(
scale
_data
,
scale_data
+
ic
,
shift_data
,
shift_data
+
ic
,
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
&
scaleshift_data
);
auto
scaleshift_memory
=
mkldnn
::
memory
{
// create mkldnn memory for input tensors (scale/shift)
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
()};
auto
scaleshift_memory
=
memory
(
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
());
// create mkldnn memory for output diff weights (combined scale/shift)
std
::
vector
<
T
>
diff_scaleshift_data
;
std
::
vector
<
T
>
diff_scaleshift_data
;
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
diff_scale
->
data
<
T
>
(),
diff_scale
->
data
<
T
>
()
+
ic
,
diff_shift
->
data
<
T
>
(),
diff_shift
->
data
<
T
>
()
+
ic
,
&
diff_scaleshift_data
);
auto
diff_scaleshift_memory
=
auto
diff_scaleshift_memory
=
mkldnn
::
memory
{
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
memory
(
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
diff_scaleshift_data
.
data
()};
diff_scaleshift_data
.
data
());
auto
diff_src
=
mkldnn
::
memory
{{
diff_src_md
,
mkldnn_engine
},
// here assume diff_src is in the same format of src
static_cast
<
void
*>
(
diff_x
->
data
<
T
>
())};
auto
diff_src_memory
=
memory
(
src_memory
.
get_primitive_desc
(),
diff_x_data
);
run_batch_norm_op
<
bn_bwd_types
::
op_type
>
(
// finally create batch_norm backward primitive
batch_norm_bwd_pd
,
src
,
mean
,
variance
,
diff_dst
,
scaleshift_memory
,
auto
batch_norm_bwd_prim
=
diff_src
,
diff_scaleshift_memory
);
batch_norm_bwd
(
batch_norm_bwd_pd
,
src_memory
,
mean_memory
,
variance_memory
,
diff_dst_memory
,
scaleshift_memory
,
diff_src_memory
,
diff_scaleshift_memory
);
// execute optional reorder and batch_norm backward primitive
std
::
vector
<
primitive
>
pipeline
;
if
(
is_diff_dst_reordered
)
pipeline
.
push_back
(
reorder_diff_dst
);
pipeline
.
push_back
(
batch_norm_bwd_prim
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
// copy back diff sacle/shift to output tensors (diff scale/shift)
diff_scaleshift_data
.
resize
(
scaleshift_size
);
auto
it
=
std
::
begin
(
diff_scaleshift_data
);
auto
it
=
std
::
begin
(
diff_scaleshift_data
);
std
::
copy
(
it
,
std
::
next
(
it
,
ic
),
diff_scale
->
data
<
T
>
()
);
std
::
copy
(
it
,
std
::
next
(
it
,
ic
),
diff_scale
_data
);
std
::
copy
(
std
::
next
(
it
,
ic
),
std
::
end
(
diff_scaleshift_data
),
std
::
copy
(
std
::
next
(
it
,
ic
),
std
::
end
(
diff_scaleshift_data
),
diff_shift
->
data
<
T
>
());
diff_shift_data
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
.
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_KERNEL
(
batch_norm
,
MKLDNN
,
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL
(
batch_norm
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
ops
::
BatchNormMKLDNNOpKernel
<
float
>
);
ops
::
BatchNormMKLDNNOpKernel
<
float
>
);
REGISTER_OP_KERNEL
(
batch_norm_grad
,
MKLDNN
,
paddle
::
platform
::
CPUPlace
,
REGISTER_OP_KERNEL
(
batch_norm_grad
,
MKLDNN
,
::
paddle
::
platform
::
CPUPlace
,
ops
::
BatchNormMKLDNNGradOpKernel
<
float
>
);
ops
::
BatchNormMKLDNNGradOpKernel
<
float
>
);
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
7d564356
...
@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel {
...
@@ -110,19 +110,19 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx
.
Input
<
Tensor
>
(
"Variance"
)
->
type
()),
ctx
.
Input
<
Tensor
>
(
"Variance"
)
->
type
()),
"Variance input should be of float type"
);
"Variance input should be of float type"
);
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
library
_
==
framework
::
LibraryType
::
kPlain
&&
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library
_
=
framework
::
LibraryType
::
kMKLDNN
;
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
#endif
#endif
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
_
);
library
);
}
}
};
};
...
@@ -368,19 +368,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
...
@@ -368,19 +368,21 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
PADDLE_THROW
(
"can't find Y@GRAD"
);
PADDLE_THROW
(
"can't find Y@GRAD"
);
}
}
framework
::
LibraryType
library_
{
framework
::
LibraryType
::
kPlain
};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework
::
DataLayout
layout_
=
framework
::
DataLayout
::
kAnyLayout
;
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
if
(
library
_
==
framework
::
LibraryType
::
kPlain
&&
if
(
library
==
framework
::
LibraryType
::
kPlain
&&
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
platform
::
CanMKLDNNBeUsed
(
ctx
))
{
library
_
=
framework
::
LibraryType
::
kMKLDNN
;
library
=
framework
::
LibraryType
::
kMKLDNN
;
layout
_
=
framework
::
DataLayout
::
kMKLDNN
;
layout
=
framework
::
DataLayout
::
kMKLDNN
;
}
}
#endif
#endif
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
()),
ctx
.
GetPlace
(),
layout
_
,
library_
);
layout
,
library
);
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录