BaiXuePrincess / Paddle
Forked from PaddlePaddle / Paddle
Commit 50d3e6e9, authored Aug 02, 2018 by Krzysztof Binias
Reusing primitives for forward Batch Norm operator
Parent ef7bd03a
Showing 1 changed file with 303 additions and 118 deletions (+303 −118).
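Note: the change replaces per-call construction of MKL-DNN primitives with a cache held in the device context: each primitive and memory object is stored under a string key via SetBlob and looked up with GetBlob on later iterations. A minimal sketch of that acquire-or-create pattern, using simplified stand-in types (BlobMap and AcquireOrCreate are illustrative names, not the real Paddle API):

// Minimal sketch of the reuse pattern this commit applies, with assumed
// simplified interfaces; the real code uses platform::MKLDNNDeviceContext
// and mkldnn primitives. A primitive is built once per unique key and
// fetched from the per-device-context cache on every later call.
#include <map>
#include <memory>
#include <string>

struct BlobMap {  // stand-in for the device-context blob store
  std::map<std::string, std::shared_ptr<void>> blobs;
  std::shared_ptr<void> GetBlob(const std::string &k) const {
    auto it = blobs.find(k);
    return it == blobs.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string &k, std::shared_ptr<void> v) { blobs[k] = v; }
};

template <typename Prim, typename Factory>
std::shared_ptr<Prim> AcquireOrCreate(BlobMap *ctx, const std::string &key,
                                      Factory make) {
  auto p = std::static_pointer_cast<Prim>(ctx->GetBlob(key));
  if (p == nullptr) {  // first call: build and cache
    p = make();
    ctx->SetBlob(key, p);
  }  // later calls: reuse the cached primitive
  return p;
}

The real handler below adds an is_reusing_ flag so that a cache miss after reuse has begun is reported as an error.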
paddle/fluid/operators/batch_norm_mkldnn_op.cc
@@ -37,6 +37,122 @@ struct bn_type_traits {
   using op_prim = typename op_type::primitive_desc;
 };
 
+class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
+ public:
+  BatchNormMKLDNNHandler(
+      std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd,
+      const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine,
+      const std::string &base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {
+    batch_norm_pd_ = batch_norm_pd;
+  }
+
+  std::shared_ptr<memory> AcquireScaleshiftMemoryFromPrimitive(void *ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        batch_norm_pd_->weights_primitive_desc(), ptr, "@scaleshift_mem_p");
+  }
+
+  std::shared_ptr<memory> AcquireMeanMemoryFromPrimitive(void *ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        batch_norm_pd_->mean_primitive_desc(), ptr, "@mean_mem_p");
+  }
+
+  std::shared_ptr<memory> AcquireVarianceMemoryFromPrimitive(void *ptr) {
+    return this->AcquireMemoryFromPrimitive(
+        batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
+  }
+
+  std::shared_ptr<batch_norm_fwd> AcquireTestBatchNormFwd(
+      std::shared_ptr<memory> src_memory,
+      const mkldnn::primitive::at &mean_memory,
+      const mkldnn::primitive::at &variance_memory,
+      std::shared_ptr<memory> scaleshift_memory,
+      std::shared_ptr<memory> dst_memory) {
+    auto prim_key = key_ + "@batch_norm_p";
+    auto batch_norm_p =
+        std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE(
+        (batch_norm_p != nullptr) || (is_reusing_ == false),
+        "Fail to find batch norm primitive for test in device context");
+    if (batch_norm_p == nullptr) {
+      batch_norm_p = std::make_shared<batch_norm_fwd>(
+          *batch_norm_pd_, *src_memory, mean_memory, variance_memory,
+          *scaleshift_memory, *dst_memory);
+      dev_ctx_.SetBlob(prim_key, batch_norm_p);
+    } else {
+      is_reusing_ = true;
+    }
+    return batch_norm_p;
+  }
+
+  std::shared_ptr<batch_norm_fwd> AcquireTrainingBatchNormFwd(
+      std::shared_ptr<memory> src_memory,
+      std::shared_ptr<memory> scaleshift_memory,
+      std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
+      std::shared_ptr<memory> variance_memory) {
+    auto prim_key = key_ + "@batch_norm_p";
+    auto batch_norm_p =
+        std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE(
+        (batch_norm_p != nullptr) || (is_reusing_ == false),
+        "Fail to find batch norm primitive for training in device context");
+    if (batch_norm_p == nullptr) {
+      batch_norm_p = std::make_shared<batch_norm_fwd>(
+          *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
+          *mean_memory, *variance_memory);
+      dev_ctx_.SetBlob(prim_key, batch_norm_p);
+    } else {
+      is_reusing_ = true;
+    }
+    return batch_norm_p;
+  }
+
+  static std::string GetHash(const memory::dims &input_dims, float epsilon,
+                             unsigned flag, bool is_test,
+                             memory::format format,
+                             const std::string &suffix) {
+    auto dims2str = [](const memory::dims &operand_dims) {
+      std::string dstr = "";
+      for (size_t i = 0; i < operand_dims.size(); ++i) {
+        dstr += std::to_string(operand_dims[i]) + "-";
+      }
+      return dstr;
+    };
+    return dims2str(input_dims) + std::to_string(epsilon) +
+           std::to_string(flag) + std::to_string(is_test) +
+           std::to_string(format) + suffix;
+  }
+
+ private:
+  std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
+};
+
+std::string gethash(const memory::dims &input_dims, float epsilon,
+                    unsigned flag, bool is_test, memory::format format) {
+  auto dims2str = [](const memory::dims &operand_dims) {
+    std::string dstr = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      dstr += std::to_string(operand_dims[i]) + "-";
+    }
+    return dstr;
+  };
+  return dims2str(input_dims) + std::to_string(epsilon) +
+         std::to_string(flag) + std::to_string(is_test) +
+         std::to_string(format);
+}
+
+std::shared_ptr<memory> UpdateMemoryData(
+    const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key,
+    void *new_ptr) {
+  auto mem = std::static_pointer_cast<memory>(dev_ctx.GetBlob(key));
+  PADDLE_ENFORCE(
+      mem != nullptr,
+      (std::string("Fail to find memory in device context [key: ") + key +
+       "]")
+          .c_str());
+  mem->set_data_handle(new_ptr);
+  return mem;
+}
+
 template <typename T, typename Container>
 void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
                      Container *c) {
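Note: GetHash above and the free function gethash build the cache key by concatenating the stringified input dims, epsilon, the flags word, is_test, and the numeric value of the memory format (GetHash additionally appends a caller-supplied suffix, e.g. an output variable name). A standalone illustration of the resulting key string, with made-up example values:

// Standalone re-implementation of the dims2str + concatenation scheme
// above, for illustration only; all values are examples.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<int> input_dims = {2, 3, 224, 224};  // example NCHW dims
  float epsilon = 0.001f;
  unsigned flag = 1;  // e.g. the numeric value of use_scale_shift
  bool is_test = true;
  int format = 8;  // numeric value of an mkldnn memory format enum
  std::string suffix = "saved_mean_var_name";

  std::string dstr;
  for (size_t i = 0; i < input_dims.size(); ++i)
    dstr += std::to_string(input_dims[i]) + "-";

  std::string key = dstr + std::to_string(epsilon) + std::to_string(flag) +
                    std::to_string(is_test) + std::to_string(format) + suffix;
  // prints: 2-3-224-224-0.0010001118saved_mean_var_name
  std::cout << key << "\n";
}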
@@ -48,15 +164,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
             std::inserter(*c, std::next(it, std::distance(scale_begin,
                                                           scale_end))));
 }
 
-template <typename Op, typename... Args>
-void run_batch_norm_op(Args &&... args) {
-  Op batch_norm_op{args...};
-
-  std::vector<mkldnn::primitive> pipeline;
-  pipeline.push_back(batch_norm_op);
-
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
-}
-
 }  // namespace
 
 template <typename T>
@@ -110,6 +217,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
 
     const unsigned int ic = scale_tz[0];
 
+    // MKLDNN requires a single piece of memory for scale and shift/bias data
+    const size_t scaleshift_size = 2 * ic;
+    std::vector<T> scaleshift_data;
+    scaleshift_data.reserve(scaleshift_size);
+
+    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
+                    shift->data<T>() + ic, &scaleshift_data);
+
     unsigned flags = mkldnn::use_scale_shift;
     if (is_test) flags |= mkldnn::use_global_stats;
     if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
@@ -118,64 +233,70 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mkldnn::memory::format input_format =
         platform::MKLDNNFormatForSize(src_tz.size(), x->format());
 
-    auto src_memory =
-        memory({{{src_tz}, memory::data_type::f32, input_format},
-                mkldnn_engine},
-               to_void_cast(x_data));
+    // keys for backward pass
+    const std::string key = BatchNormMKLDNNHandler::GetHash(
+        src_tz, epsilon, flags, is_test, input_format,
+        ctx.op().Output("SavedMean"));
+    const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
+
+    auto user_src_md = platform::MKLDNNMemDesc(
+        {src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
 
     // create primitive descriptor for batch norm forward
     using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
-    auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
-        propagation, src_memory.get_primitive_desc().desc(), epsilon, flags};
-    std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_fwd_pd =
-        std::shared_ptr<batch_norm_fwd::primitive_desc>(
-            new batch_norm_fwd::primitive_desc(batch_norm_fwd_desc,
-                                               mkldnn_engine));
-
-    // Save the pd to be used in backward pass
-    const std::string key = ctx.op().Output("SavedMean");
-    const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
+    auto batch_norm_fwd_desc =
+        bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags};
+    auto batch_norm_fwd_pd = std::make_shared<batch_norm_fwd::primitive_desc>(
+        batch_norm_fwd_desc, mkldnn_engine);
+
+    // Save the pd to be reused in the backward pass
     dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
 
-    // MKLDNN requires a single piece of memory for scale and shift/bias data
-    const size_t scaleshift_size = 2 * ic;
-    std::vector<T> scaleshift_data;
-    scaleshift_data.reserve(scaleshift_size);
-
-    copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
-                    shift->data<T>() + ic, &scaleshift_data);
+    BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
+                                   key);
+
+    auto src_memory =
+        handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
 
     // create mkldnn memory for weights (scale/shift)
-    auto scaleshift_memory =
-        memory(batch_norm_fwd_pd->weights_primitive_desc(),
-               scaleshift_data.data());
+    auto scaleshift_memory =
+        handler.AcquireScaleshiftMemoryFromPrimitive(scaleshift_data.data());
 
     // create mkldnn memory for output y tensor
-    auto dst_memory = memory(batch_norm_fwd_pd->dst_primitive_desc(), y_data);
+    auto dst_memory = handler.AcquireDstMemory(
+        batch_norm_fwd_pd->dst_primitive_desc().desc(), y_data);
+
+    std::shared_ptr<batch_norm_fwd> batch_norm_p;
 
     if (is_test) {
       // create mkldnn memory for stats (as input)
-      auto mean_memory = memory(batch_norm_fwd_pd->mean_primitive_desc(),
-                                to_void_cast(mean_data));
-      auto variance_memory =
-          memory(batch_norm_fwd_pd->variance_primitive_desc(),
-                 to_void_cast(variance_data));
-
-      run_batch_norm_op<typename bn_fwd_types::op_type>(
-          *batch_norm_fwd_pd, src_memory,
-          (const mkldnn::primitive::at &)mean_memory,
-          (const mkldnn::primitive::at &)variance_memory,
-          scaleshift_memory, dst_memory);
+      std::shared_ptr<memory> mean_memory =
+          handler.AcquireMeanMemoryFromPrimitive(to_void_cast(mean_data));
+      std::shared_ptr<memory> variance_memory =
+          handler.AcquireVarianceMemoryFromPrimitive(
+              to_void_cast(variance_data));
+
+      batch_norm_p = handler.AcquireTestBatchNormFwd(
+          src_memory, (const mkldnn::primitive::at &)*mean_memory,
+          (const mkldnn::primitive::at &)*variance_memory, scaleshift_memory,
+          dst_memory);
     } else {
       // create mkldnn memory for stats (as output)
-      auto mean_memory =
-          memory(batch_norm_fwd_pd->mean_primitive_desc(), batch_mean_data);
-      auto variance_memory = memory(
-          batch_norm_fwd_pd->variance_primitive_desc(), batch_variance_data);
-
-      run_batch_norm_op<bn_fwd_types::op_type>(*batch_norm_fwd_pd, src_memory,
-                                               scaleshift_memory, dst_memory,
-                                               mean_memory, variance_memory);
+      std::shared_ptr<memory> mean_memory =
+          handler.AcquireMeanMemoryFromPrimitive(batch_mean_data);
+      std::shared_ptr<memory> variance_memory =
+          handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data);
+
+      batch_norm_p = handler.AcquireTrainingBatchNormFwd(
+          src_memory, scaleshift_memory, dst_memory, mean_memory,
+          variance_memory);
     }
 
+    y->set_layout(DataLayout::kMKLDNN);
+    y->set_format(platform::GetMKLDNNFormat(*dst_memory));
+
+    std::vector<mkldnn::primitive> pipeline;
+    pipeline.push_back(*batch_norm_p);
+    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+
     if (!is_test) {
       // mkldnn only compute stats for current batch
       // so we need compute momentum stats via Eigen lib
@@ -192,10 +313,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       running_variance_e =
           variance_e * momentum + batch_variance_e * one_minus_momentum;
     }
-
-    y->set_layout(DataLayout::kMKLDNN);
-    y->set_format(
-        (memory::format)dst_memory.get_primitive_desc().desc().data.format);
   }
 };
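Note: in training mode MKL-DNN writes only the current batch's mean and variance, so the kernel folds them into the running estimates with the exponential moving average shown in this hunk, running = running * momentum + batch * (1 - momentum). A plain-C++ numeric sketch of that update (the kernel itself does this with Eigen expressions over the tensors):

// Exponential-moving-average update for the running statistics, matching
// running_variance_e = variance_e * momentum +
//                      batch_variance_e * one_minus_momentum
// from the kernel. Values below are made-up examples.
#include <cstdio>

int main() {
  float momentum = 0.9f;
  float running_mean = 0.5f, batch_mean = 0.7f;
  float running_var = 1.0f, batch_var = 0.8f;

  running_mean = running_mean * momentum + batch_mean * (1.0f - momentum);
  running_var = running_var * momentum + batch_var * (1.0f - momentum);

  std::printf("mean=%.2f var=%.2f\n", running_mean, running_var);
  // prints: mean=0.52 var=0.98
}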
@@ -242,61 +359,47 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const unsigned int ic = scale_tz[0];
 
-    // Retrieve bn_fwd_pd from device context
-    const std::string key = ctx.op().Input("SavedMean");
-    const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
-    auto batch_norm_fwd_pd =
-        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
-            dev_ctx.GetBlob(key_batch_norm_fwd_pd));
-    PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
-                   "Fail to find batch_norm_fwd_pd in device context");
-
     using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
 
-    // create mkldnn memory from input diff_y tensor
     mkldnn::memory::format dst_format =
         platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
-    auto user_diff_dst_memory =
-        memory({{{diff_dst_tz}, memory::data_type::f32, dst_format},
-                mkldnn_engine},
-               to_void_cast(diff_y_data));
 
-    // create mkldnn memory from input x tensor
     mkldnn::memory::format input_format =
         platform::MKLDNNFormatForSize(src_tz.size(), x->format());
-    auto src_memory =
-        memory({{{src_tz}, memory::data_type::f32, input_format},
-                mkldnn_engine},
-               to_void_cast(x_data));
 
+    unsigned flags = mkldnn::use_scale_shift;
+
+    // keys from forward pass
+    const std::string key = BatchNormMKLDNNHandler::GetHash(
+        src_tz, epsilon, flags, false, input_format,
+        ctx.op().Input("SavedMean"));
+    const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
 
-    // for diff_dst, try to use same format as dst in forward pass
-    auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
-    auto diff_dst_md = diff_dst_pd.desc();
+    // keys for primitives reuse
+    const std::string key_with_hash =
+        key + gethash(src_tz, epsilon, flags, false, input_format);
+    const std::string key_batch_norm_bwd_p =
+        key_with_hash + "@batch_norm_bwd_p";
+    const std::string key_batch_norm_src_mem_p =
+        key_with_hash + "@batch_norm_bwd_src_mem_p";
+    const std::string key_batch_norm_mean_mem_p =
+        key_with_hash + "@batch_norm_bwd_mean_mem_p";
+    const std::string key_batch_norm_variance_mem_p =
+        key_with_hash + "@batch_norm_bwd_variance_mem_p";
+    const std::string key_batch_norm_scaleshift_mem_p =
+        key_with_hash + "@batch_norm_bwd_scaleshift_mem_p";
+    const std::string key_batch_norm_diff_scaleshift_mem_p =
+        key_with_hash + "@batch_norm_bwd_diff_scaleshift_mem_p";
+    const std::string key_batch_norm_diff_src_mem_p =
+        key_with_hash + "@batch_norm_bwd_diff_src_mem_p";
+    const std::string key_batch_norm_diff_dst_mem_p =
+        key_with_hash + "@batch_norm_bwd_diff_dst_mem_p";
 
-    // create primitive descriptor for batch norm backward
-    unsigned flags = mkldnn::use_scale_shift;
-    auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
-        mkldnn::prop_kind::backward, diff_dst_md,
-        src_memory.get_primitive_desc().desc(), epsilon, flags};
-    auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
-        batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
-
-    // reorder user_diff_dst if it's not in preferred format
-    auto diff_dst_memory = user_diff_dst_memory;
     primitive reorder_diff_dst;
     bool is_diff_dst_reordered = false;
-    if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
-      diff_dst_memory = memory(diff_dst_pd);
-      reorder_diff_dst = reorder(user_diff_dst_memory, diff_dst_memory);
-      is_diff_dst_reordered = true;
-    }
 
-    // create mkldnn memory for input tensors (src/mean/variance)
-    auto mean_memory = memory(batch_norm_bwd_pd.mean_primitive_desc(),
-                              to_void_cast(batch_mean_data));
-    auto variance_memory =
-        memory(batch_norm_bwd_pd.variance_primitive_desc(),
-               to_void_cast(batch_variance_data));
+    auto user_diff_dst_memory =
+        memory({{{diff_dst_tz}, memory::data_type::f32, dst_format},
+                mkldnn_engine},
+               to_void_cast(diff_y_data));
 
     // MKLDNN requires a single piece of memory for scale and shift/bias data
     const size_t scaleshift_size = 2 * ic;
@@ -306,30 +409,118 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     copy_to_weights(scale_data, scale_data + ic, shift_data, shift_data + ic,
                     &scaleshift_data);
 
-    // create mkldnn memory for input tensors (scale/shift)
-    auto scaleshift_memory =
-        memory(batch_norm_bwd_pd.weights_primitive_desc(),
-               scaleshift_data.data());
-
-    // create mkldnn memory for output diff weights (combined scale/shift)
     std::vector<T> diff_scaleshift_data;
     diff_scaleshift_data.reserve(scaleshift_size);
-    auto diff_scaleshift_memory =
-        memory(batch_norm_bwd_pd.diff_weights_primitive_desc(),
-               diff_scaleshift_data.data());
-
-    // here assume diff_src is in the same format of src
-    auto diff_src_memory =
-        memory(src_memory.get_primitive_desc(), diff_x_data);
 
+    auto batch_norm_fwd_pd =
+        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+            dev_ctx.GetBlob(key_batch_norm_fwd_pd));
+    PADDLE_ENFORCE(batch_norm_fwd_pd != nullptr,
+                   "Fail to find batch_norm_fwd_pd in device context");
-    // finally create batch_norm backward primitive
-    auto batch_norm_bwd_prim =
-        batch_norm_bwd(batch_norm_bwd_pd, src_memory, mean_memory,
-                       variance_memory, diff_dst_memory, scaleshift_memory,
-                       diff_src_memory, diff_scaleshift_memory);
+
+    auto batch_norm_bwd_p = std::static_pointer_cast<batch_norm_bwd>(
+        dev_ctx.GetBlob(key_batch_norm_bwd_p));
+
+    if (batch_norm_bwd_p == nullptr) {
+      auto src_memory = std::shared_ptr<memory>(new memory(
+          {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
+          to_void_cast(x_data)));
+
+      // for diff_dst, try to use same format as dst in forward pass
+      auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
+      auto diff_dst_md = diff_dst_pd.desc();
+
+      // create primitive descriptor for batch norm backward
+      auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
+          mkldnn::prop_kind::backward, diff_dst_md,
+          src_memory->get_primitive_desc().desc(), epsilon, flags};
+      auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
+          batch_norm_bwd_desc, mkldnn_engine, *batch_norm_fwd_pd};
+
+      // reorder user_diff_dst if it's not in preferred format
+      auto diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
+      if (diff_dst_pd != user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory = std::make_shared<memory>(diff_dst_pd);
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
+
+      // create mkldnn memory for input tensors (src/mean/variance)
+      auto mean_memory =
+          std::make_shared<memory>(batch_norm_bwd_pd.mean_primitive_desc(),
+                                   to_void_cast(batch_mean_data));
+      auto variance_memory = std::make_shared<memory>(
+          batch_norm_bwd_pd.variance_primitive_desc(),
+          to_void_cast(batch_variance_data));
+
+      // create mkldnn memory for input tensors (scale/shift)
+      auto scaleshift_memory = std::make_shared<memory>(
+          batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data());
+
+      // create mkldnn memory for output diff weights (combined scale/shift)
+      auto diff_scaleshift_memory = std::make_shared<memory>(
+          batch_norm_bwd_pd.diff_weights_primitive_desc(),
+          diff_scaleshift_data.data());
+
+      // here assume diff_src is in the same format of src
+      auto diff_src_memory = std::make_shared<memory>(
+          src_memory->get_primitive_desc(), diff_x_data);
+
+      // finally create batch_norm backward primitive
+      batch_norm_bwd_p = std::make_shared<batch_norm_bwd>(
+          batch_norm_bwd_pd, *src_memory, *mean_memory, *variance_memory,
+          *diff_dst_memory, *scaleshift_memory, *diff_src_memory,
+          *diff_scaleshift_memory);
+
+      dev_ctx.SetBlob(key_batch_norm_bwd_p, batch_norm_bwd_p);
+      dev_ctx.SetBlob(key_batch_norm_src_mem_p, src_memory);
+      dev_ctx.SetBlob(key_batch_norm_mean_mem_p, mean_memory);
+      dev_ctx.SetBlob(key_batch_norm_variance_mem_p, variance_memory);
+      dev_ctx.SetBlob(key_batch_norm_scaleshift_mem_p, scaleshift_memory);
+      dev_ctx.SetBlob(key_batch_norm_diff_scaleshift_mem_p,
+                      diff_scaleshift_memory);
+      dev_ctx.SetBlob(key_batch_norm_diff_src_mem_p, diff_src_memory);
+      dev_ctx.SetBlob(key_batch_norm_diff_dst_mem_p, diff_dst_memory);
+
+      // set layout/format of output tensors
+      diff_x->set_layout(DataLayout::kMKLDNN);
+      diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
+                             .desc()
+                             .data.format);
+    } else {
+      // primitives already exist
+      UpdateMemoryData(dev_ctx, key_batch_norm_src_mem_p,
+                       to_void_cast(x_data));
+      UpdateMemoryData(dev_ctx, key_batch_norm_mean_mem_p,
+                       to_void_cast(batch_mean_data));
+      UpdateMemoryData(dev_ctx, key_batch_norm_variance_mem_p,
+                       to_void_cast(batch_variance_data));
+      UpdateMemoryData(dev_ctx, key_batch_norm_scaleshift_mem_p,
+                       scaleshift_data.data());
+      UpdateMemoryData(dev_ctx, key_batch_norm_diff_scaleshift_mem_p,
+                       diff_scaleshift_data.data());
+      auto diff_src_memory = UpdateMemoryData(
+          dev_ctx, key_batch_norm_diff_src_mem_p, to_void_cast(diff_x_data));
+      auto diff_dst_memory = UpdateMemoryData(
+          dev_ctx, key_batch_norm_diff_dst_mem_p, to_void_cast(diff_y_data));
+
+      // reorder user_diff_dst if it's not in preferred format
+      if (diff_dst_memory->get_primitive_desc() !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
+
+      // set layout/format of output tensors
+      diff_x->set_layout(DataLayout::kMKLDNN);
+      diff_x->set_format((memory::format)diff_src_memory->get_primitive_desc()
+                             .desc()
+                             .data.format);
+    }
 
     // execute optional reorder and batch_norm backward primitive
     std::vector<primitive> pipeline;
     if (is_diff_dst_reordered) pipeline.push_back(reorder_diff_dst);
-    pipeline.push_back(batch_norm_bwd_prim);
+    pipeline.push_back(*batch_norm_bwd_p);
     stream(stream::kind::eager).submit(pipeline).wait();
 
     // copy back diff scale/shift to output tensors (diff scale/shift)
@@ -338,12 +529,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::copy(it, std::next(it, ic), diff_scale_data);
     std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
               diff_shift_data);
-
-    // set layout/format of output tensors
-    diff_x->set_layout(DataLayout::kMKLDNN);
-    diff_x->set_format(
-        (memory::format)diff_src_memory.get_primitive_desc().desc().data
-            .format);
   }
 };
 
 }  // namespace operators
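Note: on a cache hit the gradient kernel rebuilds nothing; UpdateMemoryData re-points each cached mkldnn::memory at the current iteration's buffers through set_data_handle, and the cached backward primitive is simply resubmitted. A minimal sketch of that idea with a stand-in type (CachedMemory is illustrative, not the real mkldnn::memory class):

// Sketch of the cache-hit path: a cached memory object is retargeted at
// this iteration's buffer instead of being reconstructed. Simplified
// stand-in type; the real code calls mkldnn::memory::set_data_handle().
#include <cassert>

struct CachedMemory {
  void *handle = nullptr;
  void set_data_handle(void *p) { handle = p; }
};

void UpdateMemoryData(CachedMemory *mem, void *new_ptr) {
  assert(mem != nullptr);         // mirrors the PADDLE_ENFORCE lookup check
  mem->set_data_handle(new_ptr);  // no allocation, no descriptor rebuild
}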