Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4b65af77
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4b65af77
编写于
9月 26, 2019
作者:
A
Adam
提交者:
Tao Luo
9月 26, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
MKLDNN BatchNorm operator refactor (#20012)
test=develop
上级
bda7eab7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
155 addition
and
362 deletion
+155
-362
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
+155
-362
未找到文件。
paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc
浏览文件 @
4b65af77
...
@@ -19,136 +19,103 @@ limitations under the License. */
...
@@ -19,136 +19,103 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
batch_norm_bwd
=
mkldnn
::
batch_normalization_backward
;
using
batch_norm_fwd
=
mkldnn
::
batch_normalization_forward
;
using
mkldnn
::
memory
;
using
mkldnn
::
memory
;
using
mkldnn
::
primitive
;
using
mkldnn
::
primitive
;
using
mkldnn
::
reorder
;
using
mkldnn
::
reorder
;
using
mkldnn
::
stream
;
using
mkldnn
::
stream
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNDeviceContext
;
using
paddle
::
platform
::
MKLDNNMemDesc
;
using
platform
::
to_void_cast
;
using
platform
::
to_void_cast
;
namespace
{
template
<
typename
T
>
template
<
typename
T
>
struct
bn_type_traits
{
class
BatchNormMKLDNNHandler
using
op_type
=
T
;
:
public
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
using
op_desc
=
typename
op_type
::
desc
;
mkldnn
::
batch_normalization_backward
>
{
using
op_prim
=
typename
op_type
::
primitive_desc
;
};
class
BatchNormMKLDNNHandler
:
public
platform
::
MKLDNNHandler
{
public:
public:
BatchNormMKLDNNHandler
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
BatchNormMKLDNNHandler
(
const
std
::
vector
<
int
>
&
dims
,
const
float
&
epsilon
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
const
unsigned
&
flags
,
const
bool
&
global_stats
,
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
)
{}
const
MKLDNNMemoryFormat
fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
epsilon
,
flags
,
global_stats
,
fmt
,
uniq_name
))
{
auto
md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
fmt
);
this
->
AcquireForwardPrimitiveDescriptor
(
global_stats
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
:
mkldnn
::
prop_kind
::
forward_training
,
md
,
epsilon
,
flags
);
}
BatchNormMKLDNNHandler
(
const
std
::
vector
<
int
>
&
dims
,
const
float
&
epsilon
,
const
unsigned
&
flags
,
const
MKLDNNMemoryFormat
diff_fmt
,
const
MKLDNNMemoryFormat
src_fmt
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
platform
::
Place
cpu_place
,
const
std
::
string
&
uniq_name
)
:
platform
::
MKLDNNHandlerT
<
T
,
mkldnn
::
batch_normalization_forward
,
mkldnn
::
batch_normalization_backward
>
(
dev_ctx
,
dev_ctx
.
GetEngine
(),
cpu_place
,
platform
::
CreateKey
(
dims
,
epsilon
,
flags
,
false
,
src_fmt
,
uniq_name
))
{
auto
diff_dst_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
diff_fmt
);
auto
src_md
=
mkldnn
::
memory
::
desc
(
dims
,
platform
::
MKLDNNGetDataType
<
T
>
(),
src_fmt
);
this
->
AcquireBackwardPrimitiveDescriptor
(
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_md
,
epsilon
,
flags
);
}
std
::
shared_ptr
<
m
emory
>
AcquireScaleshiftMemoryFromPrimitive
(
void
*
ptr
)
{
std
::
shared_ptr
<
m
kldnn
::
memory
>
AcquireScaleShiftMemory
(
T
*
scaleshift_data
)
{
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
weights_primitive_desc
(),
ptr
,
"@scaleshift_mem_p"
);
this
->
fwd_pd_
->
weights_primitive_desc
(),
scaleshift_data
,
"@scaleshift_mem_p"
);
}
}
std
::
shared_ptr
<
memory
>
AcquireMeanMemoryFromPrimitive
(
void
*
ptr
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDiffScaleShiftMemory
(
T
*
diff_scaleshift_data
)
{
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
mean_primitive_desc
(),
ptr
,
"@mean_mem_p"
);
this
->
bwd_pd_
->
diff_weights_primitive_desc
(),
diff_scaleshift_data
,
"@diff_scaleshift_mem_p"
);
}
}
std
::
shared_ptr
<
memory
>
AcquireVarianceMemoryFromPrimitive
(
void
*
ptr
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMeanMemory
(
const
framework
::
Tensor
*
mean
)
{
const
T
*
mean_data
=
mean
->
data
<
T
>
();
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
variance_primitive_desc
(),
ptr
,
"@variance_mem_p"
);
this
->
fwd_pd_
->
mean_primitive_desc
(),
to_void_cast
<
T
>
(
mean_data
),
"@mean_mem_p"
);
}
}
template
<
typename
T
>
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireMeanMemory
(
framework
::
Tensor
*
mean
)
{
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireDstMemoryFromPrimitive
(
T
*
mean_data
=
mean
->
mutable_data
<
T
>
(
framework
::
Tensor
*
output
,
platform
::
Place
place
)
{
this
->
place_
,
this
->
fwd_pd_
->
mean_primitive_desc
().
get_size
());
T
*
ptr
=
output
->
mutable_data
<
T
>
(
place
,
batch_norm_pd_
->
dst_primitive_desc
().
get_size
());
return
this
->
AcquireMemoryFromPrimitive
(
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
dst_primitive_desc
(),
ptr
,
"@dst
_mem_p"
);
this
->
fwd_pd_
->
mean_primitive_desc
(),
mean_data
,
"@mean
_mem_p"
);
}
}
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireVarianceMemory
(
AcquireBatchNormPrimitiveDescriptor
(
const
batch_norm_fwd
::
desc
&
bn_fwd_desc
,
const
framework
::
Tensor
*
variance
)
{
const
mkldnn
::
engine
&
engine
)
{
const
T
*
variance_data
=
variance
->
data
<
T
>
();
// BatchNorm PD has to be passed to Grad op that
return
this
->
AcquireMemoryFromPrimitive
(
// may be executed by diffrent thread, hence
this
->
fwd_pd_
->
variance_primitive_desc
(),
// for that one we use key that does not contain TID
to_void_cast
<
T
>
(
variance_data
),
"@variance_mem_p"
);
const
std
::
string
key_batch_norm_fwd_pd
=
key_common_
+
"@bn_fwd_pd"
;
batch_norm_pd_
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_batch_norm_fwd_pd
));
if
(
batch_norm_pd_
==
nullptr
)
{
static
std
::
mutex
acquire_barrier
;
std
::
lock_guard
<
std
::
mutex
>
block_threads_until_finish_this_job
(
acquire_barrier
);
batch_norm_pd_
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx_
.
GetBlob
(
key_batch_norm_fwd_pd
));
if
(
batch_norm_pd_
==
nullptr
)
{
batch_norm_pd_
.
reset
(
new
batch_norm_fwd
::
primitive_desc
(
bn_fwd_desc
,
engine
));
dev_ctx_
.
SetBlob
(
key_batch_norm_fwd_pd
,
batch_norm_pd_
);
}
}
return
batch_norm_pd_
;
}
}
std
::
shared_ptr
<
batch_norm_fwd
>
AcquireTestTrainingBatchNormFwd
(
std
::
shared_ptr
<
mkldnn
::
memory
>
AcquireVarianceMemory
(
std
::
shared_ptr
<
memory
>
src_memory
,
framework
::
Tensor
*
variance
)
{
std
::
shared_ptr
<
memory
>
scaleshift_memory
,
T
*
variance_data
=
variance
->
mutable_data
<
T
>
(
std
::
shared_ptr
<
memory
>
dst_memory
,
std
::
shared_ptr
<
memory
>
mean_memory
,
this
->
place_
,
this
->
fwd_pd_
->
variance_primitive_desc
().
get_size
());
std
::
shared_ptr
<
memory
>
variance_memory
,
bool
is_test
)
{
return
this
->
AcquireMemoryFromPrimitive
(
auto
prim_key
=
key_
+
"@batch_norm_p"
;
this
->
fwd_pd_
->
variance_primitive_desc
(),
variance_data
,
auto
batch_norm_p
=
"@variance_mem_p"
);
std
::
static_pointer_cast
<
batch_norm_fwd
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
if
(
batch_norm_p
==
nullptr
)
{
if
(
is_test
)
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
}
else
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
*
variance_memory
);
}
dev_ctx_
.
SetBlob
(
prim_key
,
batch_norm_p
);
}
return
batch_norm_p
;
}
}
private:
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_pd_
;
};
};
std
::
shared_ptr
<
memory
>
UpdateMemoryData
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
std
::
string
&
key
,
void
*
new_ptr
)
{
auto
mem
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key
));
PADDLE_ENFORCE
(
mem
!=
nullptr
,
(
std
::
string
(
"Fail to find memory in device context [key: "
)
+
key
+
"]"
)
.
c_str
());
mem
->
set_data_handle
(
new_ptr
);
return
mem
;
}
template
<
typename
T
,
typename
Container
>
void
copy_to_weights
(
T
scale_begin
,
T
scale_end
,
T
shift_begin
,
T
shift_end
,
Container
*
c
)
{
auto
it
=
std
::
begin
(
*
c
);
std
::
copy
(
scale_begin
,
scale_end
,
std
::
inserter
(
*
c
,
it
));
std
::
copy
(
shift_begin
,
shift_end
,
std
::
inserter
(
*
c
,
std
::
next
(
it
,
std
::
distance
(
scale_begin
,
scale_end
))));
}
}
// namespace
template
<
typename
T
>
template
<
typename
T
>
class
BatchNormMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
class
BatchNormMKLDNNOpKernel
:
public
paddle
::
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -158,14 +125,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -158,14 +125,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
bool
is_test
=
ctx
.
Attr
<
bool
>
(
"is_test"
);
const
bool
use_global_stats
=
ctx
.
Attr
<
bool
>
(
"use_global_stats"
);
const
bool
use_global_stats
=
ctx
.
Attr
<
bool
>
(
"use_global_stats"
);
const
bool
fuse_with_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_with_relu"
);
const
bool
fuse_with_relu
=
ctx
.
Attr
<
bool
>
(
"fuse_with_relu"
);
bool
global_stats
=
is_test
||
use_global_stats
;
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
bool
global_stats
=
is_test
||
use_global_stats
;
const
auto
*
mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
const
auto
*
variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
MKLDNNDeviceContext
>();
auto
mkldnn_engine
=
dev_ctx
.
GetEngine
();
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
mean_out
=
ctx
.
Output
<
Tensor
>
(
"MeanOut"
);
auto
*
mean_out
=
ctx
.
Output
<
Tensor
>
(
"MeanOut"
);
...
@@ -173,102 +140,61 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -173,102 +140,61 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto
*
batch_mean
=
ctx
.
Output
<
Tensor
>
(
"SavedMean"
);
auto
*
batch_mean
=
ctx
.
Output
<
Tensor
>
(
"SavedMean"
);
auto
*
batch_variance
=
ctx
.
Output
<
Tensor
>
(
"SavedVariance"
);
auto
*
batch_variance
=
ctx
.
Output
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
x
->
layout
(),
DataLayout
::
kMKLDNN
,
PADDLE_ENFORCE_EQ
(
x
->
layout
(),
DataLayout
::
kMKLDNN
,
"Wrong layout set for X tensor"
);
"Wrong layout set for X tensor"
);
PADDLE_ENFORCE_NE
(
x
->
format
(),
MKLDNNMemoryFormat
::
format_undef
,
PADDLE_ENFORCE_NE
(
x
->
format
(),
MKLDNNMemoryFormat
::
format_undef
,
"Wrong format set for X tensor"
);
"Wrong format set for X tensor"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
mean_data
=
mean
->
data
<
T
>
();
const
T
*
variance_data
=
variance
->
data
<
T
>
();
T
*
mean_out_data
=
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
variance_out_data
=
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batch_mean_data
=
nullptr
;
T
*
batch_variance_data
=
nullptr
;
if
(
!
global_stats
)
{
batch_mean_data
=
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
batch_variance_data
=
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
}
auto
propagation
=
global_stats
==
true
?
mkldnn
::
prop_kind
::
forward_scoring
:
mkldnn
::
prop_kind
::
forward_training
;
auto
src_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
src_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
scale_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
scale
->
dims
());
auto
scale_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
scale
->
dims
());
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
const
unsigned
int
ic
=
scale_tz
[
0
];
const
unsigned
int
C
=
scale_tz
[
0
];
// MKLDNN requires a single piece of memory for scale and shift/bias data
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
C
);
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
2
*
C
);
scaleshift_data
.
reserve
(
scaleshift_size
);
scaleshift_data
.
insert
(
scaleshift_data
.
end
(),
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
C
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
// Flags are added by bitwise OR operation
unsigned
flags
=
mkldnn
::
use_scale_shift
;
// 001
unsigned
flags
=
mkldnn
::
use_scale_shift
;
if
(
global_stats
)
flags
|=
mkldnn
::
use_global_stats
;
// 010
if
(
global_stats
)
flags
|=
mkldnn
::
use_global_stats
;
if
(
fuse_with_relu
&&
is_test
)
flags
|=
mkldnn
::
fuse_bn_relu
;
// 100
if
(
fuse_with_relu
)
flags
|=
mkldnn
::
fuse_bn_relu
;
BatchNormMKLDNNHandler
<
T
>
handler
(
// create mkldnn memory from input x tensor
src_tz
,
epsilon
,
flags
,
global_stats
,
MKLDNNMemoryFormat
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
()),
dev_ctx
,
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
ctx
.
GetPlace
(),
ctx
.
op
().
Output
(
"SavedMean"
));
// keys for backward pass
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
const
std
::
string
key
=
platform
::
CreateKey
(
src_tz
,
epsilon
,
flags
,
global_stats
,
input_format
,
ctx
.
op
().
Output
(
"SavedMean"
));
BatchNormMKLDNNHandler
handler
(
dev_ctx
,
mkldnn_engine
,
key
);
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
propagation
,
user_src_md
,
epsilon
,
flags
};
auto
batch_norm_fwd_pd
=
handler
.
AcquireBatchNormPrimitiveDescriptor
(
batch_norm_fwd_desc
,
mkldnn_engine
);
auto
src_memory
=
handler
.
AcquireSrcMemory
(
user_src_md
,
to_void_cast
(
x_data
));
// crate mkldnn memory for weights(scale/shift)
auto
scaleshift_memory
=
auto
scaleshift_memory
=
handler
.
AcquireScaleshiftMemoryFromPrimitive
(
scaleshift_data
.
data
());
handler
.
AcquireScaleShiftMemory
(
scaleshift_data
.
data
());
auto
dst_memory
=
handler
.
AcquireDstMemory
(
y
);
// create mkldnn memory for output y tensor
auto
dst_memory
=
handler
.
AcquireDstMemoryFromPrimitive
<
T
>
(
y
,
ctx
.
GetPlace
());
std
::
shared_ptr
<
batch_norm_fw
d
>
batch_norm_p
;
std
::
shared_ptr
<
mkldnn
::
batch_normalization_forwar
d
>
batch_norm_p
;
if
(
global_stats
)
{
if
(
global_stats
)
{
// create mkldnn memory for stats (as input)
// mean and variance are taken from input Tensor
std
::
shared_ptr
<
memory
>
mean_memory
=
const
auto
*
mean
=
ctx
.
Input
<
Tensor
>
(
"Mean"
);
handler
.
AcquireMeanMemoryFromPrimitive
(
to_void_cast
(
mean_data
));
const
auto
*
variance
=
ctx
.
Input
<
Tensor
>
(
"Variance"
);
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemory
(
mean
);
std
::
shared_ptr
<
memory
>
variance_memory
=
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemoryFromPrimitive
(
handler
.
AcquireVarianceMemory
(
variance
);
to_void_cast
(
variance_data
));
batch_norm_p
=
handler
.
AcquireTestTrainingBatchNormFwd
(
batch_norm_p
=
handler
.
AcquireForwardPrimitive
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
*
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
variance_memory
,
true
);
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
}
else
{
}
else
{
//
create mkldnn memory for stats (as output)
//
mean and variance are calculated and saved in output Tensor
std
::
shared_ptr
<
memory
>
mean_memory
=
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemory
FromPrimitive
(
batch_mean_data
);
handler
.
AcquireMeanMemory
(
batch_mean
);
std
::
shared_ptr
<
memory
>
variance_memory
=
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemory
FromPrimitive
(
batch_variance_data
);
handler
.
AcquireVarianceMemory
(
batch_variance
);
batch_norm_p
=
handler
.
Acquire
TestTrainingBatchNormFwd
(
batch_norm_p
=
handler
.
Acquire
ForwardPrimitive
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
variance_memory
,
false
);
*
variance_memory
);
}
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
...
@@ -281,18 +207,20 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -281,18 +207,20 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if
(
!
global_stats
)
{
if
(
!
global_stats
)
{
// mkldnn only compute stats for current batch
// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
// so we need compute momentum stats via Eigen lib
EigenVectorArrayMap
<
T
>
batch_mean_e
(
batch_mean_data
,
ic
);
EigenVectorArrayMap
<
T
>
batch_mean_e
(
EigenVectorArrayMap
<
T
>
batch_variance_e
(
batch_variance_data
,
ic
);
batch_mean
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
ConstEigenVectorArrayMap
<
T
>
mean_e
(
mean_data
,
ic
);
EigenVectorArrayMap
<
T
>
batch_variance_e
(
ConstEigenVectorArrayMap
<
T
>
variance_e
{
variance_data
,
ic
};
batch_variance
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
EigenVectorArrayMap
<
T
>
running_mean_e
(
mean_out_data
,
ic
);
EigenVectorArrayMap
<
T
>
running_mean_e
(
EigenVectorArrayMap
<
T
>
running_variance_e
(
variance_out_data
,
ic
);
mean_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
EigenVectorArrayMap
<
T
>
running_variance_e
(
auto
one_minus_momentum
=
1.
-
momentum
;
variance_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
C
);
running_mean_e
=
mean_e
*
momentum
+
batch_mean_e
*
one_minus_momentum
;
running_mean_e
=
running_mean_e
*
momentum
+
batch_mean_e
*
(
1.
-
momentum
);
running_variance_e
=
running_variance_e
=
variance_e
*
momentum
+
batch_variance_e
*
one_minus_momentum
;
running_variance_e
*
momentum
+
batch_variance_e
*
(
1.
-
momentum
)
;
}
}
}
}
};
};
...
@@ -311,7 +239,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -311,7 +239,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
shift
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
batch_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
batch_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
batch_variance
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
batch_variance
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
const
auto
*
diff_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
diff_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
diff_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
diff_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
diff_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
...
@@ -322,27 +249,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -322,27 +249,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_NE
(
diff_y
->
format
(),
MKLDNNMemoryFormat
::
format_undef
,
PADDLE_ENFORCE_NE
(
diff_y
->
format
(),
MKLDNNMemoryFormat
::
format_undef
,
"Wrong format set for Input diff_y tensor"
);
"Wrong format set for Input diff_y tensor"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
diff_y_data
=
diff_y
->
data
<
T
>
();
const
T
*
batch_mean_data
=
batch_mean
->
data
<
T
>
();
const
T
*
batch_variance_data
=
batch_variance
->
data
<
T
>
();
const
T
*
scale_data
=
scale
->
data
<
T
>
();
const
T
*
shift_data
=
shift
->
data
<
T
>
();
T
*
diff_x_data
=
diff_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_scale_data
=
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_shift_data
=
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
src_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
src_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
x
->
dims
());
auto
diff_src_tz
=
src_tz
;
auto
dst_tz
=
src_tz
;
auto
diff_dst_tz
=
dst_tz
;
auto
scale_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
scale
->
dims
());
auto
scale_tz
=
paddle
::
framework
::
vectorize
<
int
>
(
scale
->
dims
());
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
const
unsigned
int
ic
=
scale_tz
[
0
];
const
unsigned
int
C
=
scale_tz
[
0
];
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
MKLDNNMemoryFormat
dst_format
=
MKLDNNMemoryFormat
dst_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
...
@@ -350,170 +261,52 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
...
@@ -350,170 +261,52 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
MKLDNNMemoryFormat
input_format
=
MKLDNNMemoryFormat
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
unsigned
flags
=
mkldnn
::
use_scale_shift
;
BatchNormMKLDNNHandler
<
T
>
handler
(
src_tz
,
epsilon
,
mkldnn
::
use_scale_shift
,
dst_format
,
input_format
,
// keys from forward pass
dev_ctx
,
ctx
.
GetPlace
(),
ctx
.
op
().
Input
(
"SavedMean"
));
const
std
::
string
key
=
platform
::
CreateKey
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
,
ctx
.
op
().
Input
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
// keys for primitives reuse
const
std
::
string
key_with_hash
=
key
+
platform
::
CreateKey
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
const
std
::
string
key_batch_norm_bwd_p
=
key_with_hash
+
"@batch_norm_bwd_p"
;
const
std
::
string
key_batch_norm_src_mem_p
=
key_with_hash
+
"@batch_norm_bwd_src_mem_p"
;
const
std
::
string
key_batch_norm_mean_mem_p
=
key_with_hash
+
"@batch_norm_bwd_mean_mem_p"
;
const
std
::
string
key_batch_norm_variance_mem_p
=
key_with_hash
+
"@batch_norm_bwd_variance_mem_p"
;
const
std
::
string
key_batch_norm_scaleshift_mem_p
=
key_with_hash
+
"@batch_norm_bwd_scaleshift_mem_p"
;
const
std
::
string
key_batch_norm_diff_scaleshift_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_scaleshift_mem_p"
;
const
std
::
string
key_batch_norm_diff_src_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_src_mem_p"
;
const
std
::
string
key_batch_norm_diff_dst_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_dst_mem_p"
;
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
// MKLDNN requires a single piece of memory for scale and shift/bias data
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
const
size_t
scaleshift_size
=
2
*
C
;
std
::
vector
<
T
>
scaleshift_data
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
C
);
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale_data
,
scale_data
+
ic
,
shift_data
,
shift_data
+
ic
,
scaleshift_data
.
insert
(
scaleshift_data
.
end
(),
shift
->
data
<
T
>
()
,
&
scaleshift_data
);
shift
->
data
<
T
>
()
+
C
);
std
::
vector
<
T
>
diff_scaleshift_data
;
std
::
vector
<
T
>
diff_scaleshift_data
;
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
auto
batch_norm_fwd_pd
=
auto
src_memory
=
handler
.
AcquireSrcMemory
(
x
);
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
auto
mean_memory
=
handler
.
AcquireMeanMemory
(
batch_mean
);
dev_ctx
.
GetBlob
(
key_batch_norm_fwd_pd
));
auto
variance_memory
=
handler
.
AcquireVarianceMemory
(
batch_variance
);
PADDLE_ENFORCE
(
batch_norm_fwd_pd
!=
nullptr
,
auto
diff_dst_memory
=
handler
.
AcquireDiffDstMemory
(
diff_y
);
"Fail to find batch_norm_fwd_pd in device context"
);
auto
scaleshift_memory
=
handler
.
AcquireScaleShiftMemory
(
scaleshift_data
.
data
());
auto
batch_norm_bwd_p
=
std
::
static_pointer_cast
<
batch_norm_bwd
>
(
auto
diff_src_memory
=
handler
.
AcquireDiffSrcMemory
(
diff_x
);
dev_ctx
.
GetBlob
(
key_batch_norm_bwd_p
));
auto
diff_scaleshift_memory
=
handler
.
AcquireDiffScaleShiftMemory
(
diff_scaleshift_data
.
data
());
if
(
batch_norm_bwd_p
==
nullptr
)
{
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
// finally create batch_norm backward primitive
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
auto
batch_norm_bwd_p
=
handler
.
AcquireBackwardPrimitive
(
to_void_cast
(
x_data
)));
*
src_memory
,
*
mean_memory
,
*
variance_memory
,
*
diff_dst_memory
,
*
scaleshift_memory
,
*
diff_src_memory
,
*
diff_scaleshift_memory
);
// for diff_dst, try to use same format as dst in forward pass
auto
diff_dst_pd
=
batch_norm_fwd_pd
.
get
()
->
dst_primitive_desc
();
auto
diff_dst_md
=
diff_dst_pd
.
desc
();
// create primitive descriptor for batch norm backward
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_memory
->
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
batch_norm_bwd_desc
,
mkldnn_engine
,
*
batch_norm_fwd_pd
};
// reorder user_diff_dst if it's not in preferred format
auto
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
user_diff_dst_memory
);
if
(
diff_dst_pd
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
diff_dst_pd
);
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto
mean_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
mean_primitive_desc
(),
to_void_cast
(
batch_mean_data
));
auto
variance_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
variance_primitive_desc
(),
to_void_cast
(
batch_variance_data
));
// create mkldnn memory for input tensors (scale/shift)
auto
scaleshift_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
());
// create mkldnn memory for output diff weights (combined scale/shift)
auto
diff_scaleshift_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
diff_scaleshift_data
.
data
());
// here assume diff_src is in the same format of src
auto
diff_src_memory
=
std
::
make_shared
<
memory
>
(
src_memory
->
get_primitive_desc
(),
diff_x_data
);
// finally create batch_norm backward primitive
batch_norm_bwd_p
=
std
::
make_shared
<
batch_norm_bwd
>
(
batch_norm_bwd_pd
,
*
src_memory
,
*
mean_memory
,
*
variance_memory
,
*
diff_dst_memory
,
*
scaleshift_memory
,
*
diff_src_memory
,
*
diff_scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_bwd_p
,
batch_norm_bwd_p
);
dev_ctx
.
SetBlob
(
key_batch_norm_src_mem_p
,
src_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_mean_mem_p
,
mean_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_variance_mem_p
,
variance_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_scaleshift_mem_p
,
scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_scaleshift_mem_p
,
diff_scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_src_mem_p
,
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_dst_mem_p
,
diff_dst_memory
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
(
MKLDNNMemoryFormat
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
else
{
// primitives already exist
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_src_mem_p
,
to_void_cast
(
x_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_mean_mem_p
,
to_void_cast
(
batch_mean_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_variance_mem_p
,
to_void_cast
(
batch_variance_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_scaleshift_mem_p
,
scaleshift_data
.
data
());
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_scaleshift_mem_p
,
diff_scaleshift_data
.
data
());
auto
diff_src_memory
=
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_src_mem_p
,
to_void_cast
(
diff_x_data
));
auto
diff_dst_memory
=
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_dst_mem_p
,
to_void_cast
(
diff_y_data
));
// reorder user_diff_dst if it's not in preferred format
if
(
diff_dst_memory
->
get_primitive_desc
()
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
(
MKLDNNMemoryFormat
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
// execute optional reorder and batch_norm backward primitive
std
::
vector
<
primitive
>
pipeline
;
std
::
vector
<
primitive
>
pipeline
;
if
(
is_diff_dst_reordered
)
pipeline
.
push_back
(
reorder_diff_dst
);
pipeline
.
push_back
(
*
batch_norm_bwd_p
);
pipeline
.
push_back
(
*
batch_norm_bwd_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
T
*
diff_scale_data
=
diff_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
diff_shift_data
=
diff_shift
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// copy back diff sacle/shift to output tensors (diff scale/shift)
// copy back diff sacle/shift to output tensors (diff scale/shift)
diff_scaleshift_data
.
resize
(
scaleshift_size
);
diff_scaleshift_data
.
resize
(
scaleshift_size
);
auto
it
=
std
::
begin
(
diff_scaleshift_data
);
auto
it
=
std
::
begin
(
diff_scaleshift_data
);
std
::
copy
(
it
,
std
::
next
(
it
,
ic
),
diff_scale_data
);
std
::
copy
(
it
,
std
::
next
(
it
,
C
),
diff_scale_data
);
std
::
copy
(
std
::
next
(
it
,
ic
),
std
::
end
(
diff_scaleshift_data
),
std
::
copy
(
std
::
next
(
it
,
C
),
std
::
end
(
diff_scaleshift_data
),
diff_shift_data
);
diff_shift_data
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
diff_src_memory
));
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录