Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
4b65af77
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4b65af77
编写于
9月 26, 2019
作者:
A
Adam
提交者:
Tao Luo
9月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN BatchNorm operator refactor (#20012)
test=develop
上级
bda7eab7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
155 addition
and
362 deletion
+155
-362
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+155
-362
未找到文件。
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
浏览文件 @
4b65af77
...
...
@@ -19,136 +19,103 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
using
batch_norm_bwd
=
mkldnn
::
batch_normalization_backward
;
using
batch_norm_fwd
=
mkldnn
::
batch_normalization_forward
;
using
mkldnn
::
memory
;
using
mkldnn
::
primitive
;
using
mkldnn
::
reorder
;
using
mkldnn
::
stream
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNMemDesc
;
using
platform
::
to_void_cast
;
namespace
{
template
<
typename
T
>
struct
bn_type_traits
{
using
op_type
=
T
;
using
op_desc
=
typename
op_type
::
desc
;
using
op_prim
=
typename
op_type
::
primitive_desc
;
};
class
BatchNormMKLDNNHandler
:
public
platform
::
MKLDNNHandler
{
class
BatchNormMKLDNNHandler
:
public
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
{
public:
BatchNormMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
)
{}
BatchNormMKLDNNHandler
(
const
std
::
vector
<
int
>
&
dims
,
const
float
&
epsilon
,
const
unsigned
&
flags
,
const
bool
&
global_stats
,
const
MKLDNNMemoryFormat
fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
epsilon
,
flags
,
global_stats
,
fmt
,
uniq_name
))
{
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
global_stats
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
:
mkldnn
::
prop_kind
::
forward_training
,
md
,
epsilon
,
flags
);
}
BatchNormMKLDNNHandler
(
const
std
::
vector
<
int
>
&
dims
,
const
float
&
epsilon
,
const
unsigned
&
flags
,
const
MKLDNNMemoryFormat
diff_fmt
,
const
MKLDNNMemoryFormat
src_fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
epsilon
,
flags
,
false
,
src_fmt
,
uniq_name
))
{
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_fmt
);
this
->
AcquireBackwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_md
,
epsilon
,
flags
);
}
std
::
shared_ptr
<
m
emory
>
AcquireScaleshiftMemoryFromPrimitive
(
void
*
ptr
)
{
std
::
shared_ptr
<
m
kldnn
::
memory
>
AcquireScaleShiftMemory
(
T
*
scaleshift_data
)
{
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
weights_primitive_desc
(),
ptr
,
"@scaleshift_mem_p"
);
this
->
fwd_pd_
->
weights_primitive_desc
(),
scaleshift_data
,
"@scaleshift_mem_p"
);
}
std
::
shared_ptr
<
memory
>
AcquireMeanMemoryFromPrimitive
(
void
*
ptr
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffScaleShiftMemory
(
T
*
diff_scaleshift_data
)
{
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
mean_primitive_desc
(),
ptr
,
"@mean_mem_p"
);
this
->
bwd_pd_
->
diff_weights_primitive_desc
(),
diff_scaleshift_data
,
"@diff_scaleshift_mem_p"
);
}
std
::
shared_ptr
<
memory
>
AcquireVarianceMemoryFromPrimitive
(
void
*
ptr
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMeanMemory
(
const
framework
::
Tensor
*
mean
)
{
const
T
*
mean_data
=
mean
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
variance_primitive_desc
(),
ptr
,
"@variance_mem_p"
);
this
->
fwd_pd_
->
mean_primitive_desc
(),
to_void_cast
<
T
>
(
mean_data
),
"@mean_mem_p"
);
}
template
<
typename
T
>
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromPrimitive
(
framework
::
Tensor
*
output
,
platform
::
Place
place
)
{
T
*
ptr
=
output
->
mutable_data
<
T
>
(
place
,
batch_norm_pd_
->
dst_primitive_desc
().
get_size
());
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMeanMemory
(
framework
::
Tensor
*
mean
)
{
T
*
mean_data
=
mean
->
mutable_data
<
T
>
(
this
->
place_
,
this
->
fwd_pd_
->
mean_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst
_mem_p"
);
this
->
fwd_pd_
->
mean_primitive_desc
(),
mean_data
,
"@mean
_mem_p"
);
}
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
AcquireBatchNormPrimitiveDescriptor
(
const
batch_norm_fwd
::
desc
&
bn_fwd_desc
,
const
mkldnn
::
engine
&
engine
)
{
// BatchNorm PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const
std
::
string
key_batch_norm_fwd_pd
=
key_common_
+
"@bn_fwd_pd"
;
batch_norm_pd_
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_batch_norm_fwd_pd
));
if
(
batch_norm_pd_
==
nullptr
)
{
static
std
::
mutex
acquire_barrier
;
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
acquire_barrier
);
batch_norm_pd_
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_batch_norm_fwd_pd
));
if
(
batch_norm_pd_
==
nullptr
)
{
batch_norm_pd_
.
reset
(
new
batch_norm_fwd
::
primitive_desc
(
bn_fwd_desc
,
engine
));
dev_ctx_
.
SetBlob
(
key_batch_norm_fwd_pd
,
batch_norm_pd_
);
}
}
return
batch_norm_pd_
;
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireVarianceMemory
(
const
framework
::
Tensor
*
variance
)
{
const
T
*
variance_data
=
variance
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
variance_primitive_desc
(),
to_void_cast
<
T
>
(
variance_data
),
"@variance_mem_p"
);
}
std
::
shared_ptr
<
batch_norm_fwd
>
AcquireTestTrainingBatchNormFwd
(
std
::
shared_ptr
<
memory
>
src_memory
,
std
::
shared_ptr
<
memory
>
scaleshift_memory
,
std
::
shared_ptr
<
memory
>
dst_memory
,
std
::
shared_ptr
<
memory
>
mean_memory
,
std
::
shared_ptr
<
memory
>
variance_memory
,
bool
is_test
)
{
auto
prim_key
=
key_
+
"@batch_norm_p"
;
auto
batch_norm_p
=
std
::
static_pointer_cast
<
batch_norm_fwd
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
batch_norm_p
==
nullptr
)
{
if
(
is_test
)
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
}
else
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
*
variance_memory
);
}
dev_ctx_
.
SetBlob
(
prim_key
,
batch_norm_p
);
}
return
batch_norm_p
;
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireVarianceMemory
(
framework
::
Tensor
*
variance
)
{
T
*
variance_data
=
variance
->
mutable_data
<
T
>
(
this
->
place_
,
this
->
fwd_pd_
->
variance_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
this
->
fwd_pd_
->
variance_primitive_desc
(),
variance_data
,
"@variance_mem_p"
);
}
private:
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_pd_
;
};
std
::
shared_ptr
<
memory
>
UpdateMemoryData
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
std
::
string
&
key
,
void
*
new_ptr
)
{
auto
mem
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key
));
PADDLE_ENFORCE
(
mem
!=
nullptr
,
(
std
::
string
(
"Fail to find memory in device context [key: "
)
+
key
+
"]"
)
.
c_str
());
mem
->
set_data_handle
(
new_ptr
);
return
mem
;
}
template
<
typename
T
,
typename
Container
>
void
copy_to_weights
(
T
scale_begin
,
T
scale_end
,
T
shift_begin
,
T
shift_end
,
Container
*
c
)
{
auto
it
=
std
::
begin
(
*
c
);
std
::
copy
(
scale_begin
,
scale_end
,
std
::
inserter
(
*
c
,
it
));
std
::
copy
(
shift_begin
,
shift_end
,
std
::
inserter
(
*
c
,
std
::
next
(
it
,
std
::
distance
(
scale_begin
,
scale_end
))));
}
}
// namespace
template
<
typename
T
>
class
BatchNormMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -158,14 +125,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
bool
use_global_stats
=
ctx
.
Attr
<
bool
>
(
"use_global_stats"
);
const
bool
fuse_with_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_with_relu"
);
bool
global_stats
=
is_test
||
use_global_stats
;
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
bool
global_stats
=
is_test
||
use_global_stats
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
mean_out
=
ctx
.
Output
<
Tensor
>
(
"MeanOut"
);
...
...
@@ -173,102 +140,61 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
*
batch_mean
=
ctx
.
Output
<
Tensor
>
(
"SavedMean"
);
auto
*
batch_variance
=
ctx
.
Output
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
x
->
layout
(),
DataLayout
::
kMKLDNN
,
"Wrong layout set for X tensor"
);
PADDLE_ENFORCE_NE
(
x
->
format
(),
MKLDNNMemoryFormat
::
format_undef
,
"Wrong format set for X tensor"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
mean_data
=
mean
->
data
<
T
>
();
const
T
*
variance_data
=
variance
->
data
<
T
>
();
T
*
mean_out_data
=
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
variance_out_data
=
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batch_mean_data
=
nullptr
;
T
*
batch_variance_data
=
nullptr
;
if
(
!
global_stats
)
{
batch_mean_data
=
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_variance_data
=
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
auto
propagation
=
global_stats
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
:
mkldnn
::
prop_kind
::
forward_training
;
auto
src_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
scale_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
scale
->
dims
());
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
const
unsigned
int
ic
=
scale_tz
[
0
];
const
unsigned
int
C
=
scale_tz
[
0
];
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
unsigned
flags
=
mkldnn
::
use_scale_shift
;
if
(
global_stats
)
flags
|=
mkldnn
::
use_global_stats
;
if
(
fuse_with_relu
)
flags
|=
mkldnn
::
fuse_bn_relu
;
// create mkldnn memory from input x tensor
MKLDNNMemoryFormat
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
// keys for backward pass
const
std
::
string
key
=
platform
::
CreateKey
(
src_tz
,
epsilon
,
flags
,
global_stats
,
input_format
,
ctx
.
op
().
Output
(
"SavedMean"
));
BatchNormMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
propagation
,
user_src_md
,
epsilon
,
flags
};
auto
batch_norm_fwd_pd
=
handler
.
AcquireBatchNormPrimitiveDescriptor
(
batch_norm_fwd_desc
,
mkldnn_engine
);
auto
src_memory
=
handler
.
AcquireSrcMemory
(
user_src_md
,
to_void_cast
(
x_data
));
// crate mkldnn memory for weights(scale/shift)
std
::
vector
<
T
>
scaleshift_data
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
C
);
scaleshift_data
.
reserve
(
2
*
C
);
scaleshift_data
.
insert
(
scaleshift_data
.
end
(),
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
C
);
// Flags are added by bitwise OR operation
unsigned
flags
=
mkldnn
::
use_scale_shift
;
// 001
if
(
global_stats
)
flags
|=
mkldnn
::
use_global_stats
;
// 010
if
(
fuse_with_relu
&&
is_test
)
flags
|=
mkldnn
::
fuse_bn_relu
;
// 100
BatchNormMKLDNNHandler
<
T
>
handler
(
src_tz
,
epsilon
,
flags
,
global_stats
,
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
()),
dev_ctx
,
ctx
.
GetPlace
(),
ctx
.
op
().
Output
(
"SavedMean"
));
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
auto
scaleshift_memory
=
handler
.
AcquireScaleshiftMemoryFromPrimitive
(
scaleshift_data
.
data
());
// create mkldnn memory for output y tensor
auto
dst_memory
=
handler
.
AcquireDstMemoryFromPrimitive
<
T
>
(
y
,
ctx
.
GetPlace
());
handler
.
AcquireScaleShiftMemory
(
scaleshift_data
.
data
());
auto
dst_memory
=
handler
.
AcquireDstMemory
(
y
);
std
::
shared_ptr
<
batch_norm_fw
d
>
batch_norm_p
;
std
::
shared_ptr
<
mkldnn
::
batch_normalization_forwar
d
>
batch_norm_p
;
if
(
global_stats
)
{
// create mkldnn memory for stats (as input)
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemoryFromPrimitive
(
to_void_cast
(
mean_data
));
// mean and variance are taken from input Tensor
const
auto
*
mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemory
(
mean
);
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemoryFromPrimitive
(
to_void_cast
(
variance_data
));
handler
.
AcquireVarianceMemory
(
variance
);
batch_norm_p
=
handler
.
AcquireTestTrainingBatchNormFwd
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
variance_memory
,
true
);
batch_norm_p
=
handler
.
AcquireForwardPrimitive
(
*
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
}
else
{
//
create mkldnn memory for stats (as output)
//
mean and variance are calculated and saved in output Tensor
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemory
FromPrimitive
(
batch_mean_data
);
handler
.
AcquireMeanMemory
(
batch_mean
);
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemory
FromPrimitive
(
batch_variance_data
);
handler
.
AcquireVarianceMemory
(
batch_variance
);
batch_norm_p
=
handler
.
Acquire
TestTrainingBatchNormFwd
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
variance_memory
,
false
);
batch_norm_p
=
handler
.
Acquire
ForwardPrimitive
(
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
*
variance_memory
);
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
...
...
@@ -281,18 +207,20 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
!
global_stats
)
{
// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
EigenVectorArrayMap
<
T
>
batch_mean_e
(
batch_mean_data
,
ic
);
EigenVectorArrayMap
<
T
>
batch_variance_e
(
batch_variance_data
,
ic
);
ConstEigenVectorArrayMap
<
T
>
mean_e
(
mean_data
,
ic
);
ConstEigenVectorArrayMap
<
T
>
variance_e
{
variance_data
,
ic
};
EigenVectorArrayMap
<
T
>
running_mean_e
(
mean_out_data
,
ic
);
EigenVectorArrayMap
<
T
>
running_variance_e
(
variance_out_data
,
ic
);
auto
one_minus_momentum
=
1.
-
momentum
;
running_mean_e
=
mean_e
*
momentum
+
batch_mean_e
*
one_minus_momentum
;
EigenVectorArrayMap
<
T
>
batch_mean_e
(
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
EigenVectorArrayMap
<
T
>
batch_variance_e
(
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
EigenVectorArrayMap
<
T
>
running_mean_e
(
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
EigenVectorArrayMap
<
T
>
running_variance_e
(
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
running_mean_e
=
running_mean_e
*
momentum
+
batch_mean_e
*
(
1.
-
momentum
);
running_variance_e
=
variance_e
*
momentum
+
batch_variance_e
*
one_minus_momentum
;
running_variance_e
*
momentum
+
batch_variance_e
*
(
1.
-
momentum
)
;
}
}
};
...
...
@@ -311,7 +239,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
batch_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
batch_variance
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
diff_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
diff_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
...
...
@@ -322,27 +249,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_NE
(
diff_y
->
format
(),
MKLDNNMemoryFormat
::
format_undef
,
"Wrong format set for Input diff_y tensor"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
diff_y_data
=
diff_y
->
data
<
T
>
();
const
T
*
batch_mean_data
=
batch_mean
->
data
<
T
>
();
const
T
*
batch_variance_data
=
batch_variance
->
data
<
T
>
();
const
T
*
scale_data
=
scale
->
data
<
T
>
();
const
T
*
shift_data
=
shift
->
data
<
T
>
();
T
*
diff_x_data
=
diff_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_scale_data
=
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_shift_data
=
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
src_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
diff_src_tz
=
src_tz
;
auto
dst_tz
=
src_tz
;
auto
diff_dst_tz
=
dst_tz
;
auto
scale_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
scale
->
dims
());
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
const
unsigned
int
ic
=
scale_tz
[
0
];
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
const
unsigned
int
C
=
scale_tz
[
0
];
MKLDNNMemoryFormat
dst_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
...
...
@@ -350,170 +261,52 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
MKLDNNMemoryFormat
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
unsigned
flags
=
mkldnn
::
use_scale_shift
;
// keys from forward pass
const
std
::
string
key
=
platform
::
CreateKey
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
,
ctx
.
op
().
Input
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
// keys for primitives reuse
const
std
::
string
key_with_hash
=
key
+
platform
::
CreateKey
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
const
std
::
string
key_batch_norm_bwd_p
=
key_with_hash
+
"@batch_norm_bwd_p"
;
const
std
::
string
key_batch_norm_src_mem_p
=
key_with_hash
+
"@batch_norm_bwd_src_mem_p"
;
const
std
::
string
key_batch_norm_mean_mem_p
=
key_with_hash
+
"@batch_norm_bwd_mean_mem_p"
;
const
std
::
string
key_batch_norm_variance_mem_p
=
key_with_hash
+
"@batch_norm_bwd_variance_mem_p"
;
const
std
::
string
key_batch_norm_scaleshift_mem_p
=
key_with_hash
+
"@batch_norm_bwd_scaleshift_mem_p"
;
const
std
::
string
key_batch_norm_diff_scaleshift_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_scaleshift_mem_p"
;
const
std
::
string
key_batch_norm_diff_src_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_src_mem_p"
;
const
std
::
string
key_batch_norm_diff_dst_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_dst_mem_p"
;
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
BatchNormMKLDNNHandler
<
T
>
handler
(
src_tz
,
epsilon
,
mkldnn
::
use_scale_shift
,
dst_format
,
input_format
,
dev_ctx
,
ctx
.
GetPlace
(),
ctx
.
op
().
Input
(
"SavedMean"
));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
const
size_t
scaleshift_size
=
2
*
C
;
std
::
vector
<
T
>
scaleshift_data
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
C
);
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale_data
,
scale_data
+
ic
,
shift_data
,
shift_data
+
ic
,
&
scaleshift_data
);
scaleshift_data
.
insert
(
scaleshift_data
.
end
(),
shift
->
data
<
T
>
()
,
shift
->
data
<
T
>
()
+
C
);
std
::
vector
<
T
>
diff_scaleshift_data
;
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
auto
batch_norm_fwd_pd
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_batch_norm_fwd_pd
));
PADDLE_ENFORCE
(
batch_norm_fwd_pd
!=
nullptr
,
"Fail to find batch_norm_fwd_pd in device context"
);
auto
batch_norm_bwd_p
=
std
::
static_pointer_cast
<
batch_norm_bwd
>
(
dev_ctx
.
GetBlob
(
key_batch_norm_bwd_p
));
if
(
batch_norm_bwd_p
==
nullptr
)
{
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
to_void_cast
(
x_data
)));
// for diff_dst, try to use same format as dst in forward pass
auto
diff_dst_pd
=
batch_norm_fwd_pd
.
get
()
->
dst_primitive_desc
();
auto
diff_dst_md
=
diff_dst_pd
.
desc
();
// create primitive descriptor for batch norm backward
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_memory
->
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
batch_norm_bwd_desc
,
mkldnn_engine
,
*
batch_norm_fwd_pd
};
// reorder user_diff_dst if it's not in preferred format
auto
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
user_diff_dst_memory
);
if
(
diff_dst_pd
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
diff_dst_pd
);
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto
mean_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
mean_primitive_desc
(),
to_void_cast
(
batch_mean_data
));
auto
variance_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
variance_primitive_desc
(),
to_void_cast
(
batch_variance_data
));
// create mkldnn memory for input tensors (scale/shift)
auto
scaleshift_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
());
// create mkldnn memory for output diff weights (combined scale/shift)
auto
diff_scaleshift_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
diff_scaleshift_data
.
data
());
// here assume diff_src is in the same format of src
auto
diff_src_memory
=
std
::
make_shared
<
memory
>
(
src_memory
->
get_primitive_desc
(),
diff_x_data
);
// finally create batch_norm backward primitive
batch_norm_bwd_p
=
std
::
make_shared
<
batch_norm_bwd
>
(
batch_norm_bwd_pd
,
*
src_memory
,
*
mean_memory
,
*
variance_memory
,
*
diff_dst_memory
,
*
scaleshift_memory
,
*
diff_src_memory
,
*
diff_scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_bwd_p
,
batch_norm_bwd_p
);
dev_ctx
.
SetBlob
(
key_batch_norm_src_mem_p
,
src_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_mean_mem_p
,
mean_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_variance_mem_p
,
variance_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_scaleshift_mem_p
,
scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_scaleshift_mem_p
,
diff_scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_src_mem_p
,
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_dst_mem_p
,
diff_dst_memory
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
(
MKLDNNMemoryFormat
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
else
{
// primitives already exist
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_src_mem_p
,
to_void_cast
(
x_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_mean_mem_p
,
to_void_cast
(
batch_mean_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_variance_mem_p
,
to_void_cast
(
batch_variance_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_scaleshift_mem_p
,
scaleshift_data
.
data
());
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_scaleshift_mem_p
,
diff_scaleshift_data
.
data
());
auto
diff_src_memory
=
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_src_mem_p
,
to_void_cast
(
diff_x_data
));
auto
diff_dst_memory
=
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_dst_mem_p
,
to_void_cast
(
diff_y_data
));
// reorder user_diff_dst if it's not in preferred format
if
(
diff_dst_memory
->
get_primitive_desc
()
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
(
MKLDNNMemoryFormat
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
auto
mean_memory
=
handler
.
AcquireMeanMemory
(
batch_mean
);
auto
variance_memory
=
handler
.
AcquireVarianceMemory
(
batch_variance
);
auto
diff_dst_memory
=
handler
.
AcquireDiffDstMemory
(
diff_y
);
auto
scaleshift_memory
=
handler
.
AcquireScaleShiftMemory
(
scaleshift_data
.
data
());
auto
diff_src_memory
=
handler
.
AcquireDiffSrcMemory
(
diff_x
);
auto
diff_scaleshift_memory
=
handler
.
AcquireDiffScaleShiftMemory
(
diff_scaleshift_data
.
data
());
// finally create batch_norm backward primitive
auto
batch_norm_bwd_p
=
handler
.
AcquireBackwardPrimitive
(
*
src_memory
,
*
mean_memory
,
*
variance_memory
,
*
diff_dst_memory
,
*
scaleshift_memory
,
*
diff_src_memory
,
*
diff_scaleshift_memory
);
// execute optional reorder and batch_norm backward primitive
std
::
vector
<
primitive
>
pipeline
;
if
(
is_diff_dst_reordered
)
pipeline
.
push_back
(
reorder_diff_dst
);
pipeline
.
push_back
(
*
batch_norm_bwd_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
T
*
diff_scale_data
=
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_shift_data
=
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// copy back diff sacle/shift to output tensors (diff scale/shift)
diff_scaleshift_data
.
resize
(
scaleshift_size
);
auto
it
=
std
::
begin
(
diff_scaleshift_data
);
std
::
copy
(
it
,
std
::
next
(
it
,
ic
),
diff_scale_data
);
std
::
copy
(
std
::
next
(
it
,
ic
),
std
::
end
(
diff_scaleshift_data
),
std
::
copy
(
it
,
std
::
next
(
it
,
C
),
diff_scale_data
);
std
::
copy
(
std
::
next
(
it
,
C
),
std
::
end
(
diff_scaleshift_data
),
diff_shift_data
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
diff_src_memory
));
}
};
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录