Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
26cac36b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
接近 2 年 前同步成功
通知
706
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
26cac36b
编写于
8月 28, 2018
作者:
T
Tao Luo
提交者:
GitHub
8月 28, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12515 from kbinias/kbinias/bnorm-fwd-reuse
Reusing primitives for forward Batch Norm operator
上级
1d9cb932
fb4b4f8d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
277 addition
and
119 deletion
+277
-119
paddle/fluid/operators/batch_norm_mkldnn_op.cc
paddle/fluid/operators/batch_norm_mkldnn_op.cc
+277
-119
未找到文件。
paddle/fluid/operators/batch_norm_mkldnn_op.cc
浏览文件 @
26cac36b
...
...
@@ -37,6 +37,95 @@ struct bn_type_traits {
using
op_prim
=
typename
op_type
::
primitive_desc
;
};
class
BatchNormMKLDNNHandler
:
public
platform
::
MKLDNNHandler
{
public:
BatchNormMKLDNNHandler
(
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_pd
,
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
mkldnn
::
engine
engine
,
const
std
::
string
&
base_key
)
:
platform
::
MKLDNNHandler
(
dev_ctx
,
engine
,
base_key
)
{
batch_norm_pd_
=
batch_norm_pd
;
}
std
::
shared_ptr
<
memory
>
AcquireScaleshiftMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
weights_primitive_desc
(),
ptr
,
"@scaleshift_mem_p"
);
}
std
::
shared_ptr
<
memory
>
AcquireMeanMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
mean_primitive_desc
(),
ptr
,
"@mean_mem_p"
);
}
std
::
shared_ptr
<
memory
>
AcquireVarianceMemoryFromPrimitive
(
void
*
ptr
)
{
return
this
->
AcquireMemoryFromPrimitive
(
batch_norm_pd_
->
variance_primitive_desc
(),
ptr
,
"@variance_mem_p"
);
}
std
::
shared_ptr
<
batch_norm_fwd
>
AcquireTestTrainingBatchNormFwd
(
std
::
shared_ptr
<
memory
>
src_memory
,
std
::
shared_ptr
<
memory
>
scaleshift_memory
,
std
::
shared_ptr
<
memory
>
dst_memory
,
std
::
shared_ptr
<
memory
>
mean_memory
,
std
::
shared_ptr
<
memory
>
variance_memory
,
bool
is_test
)
{
auto
prim_key
=
key_
+
"@batch_norm_p"
;
auto
batch_norm_p
=
std
::
static_pointer_cast
<
batch_norm_fwd
>
(
dev_ctx_
.
GetBlob
(
prim_key
));
PADDLE_ENFORCE
((
batch_norm_p
!=
nullptr
)
||
!
is_reusing_
,
"Fail to find batch norm primitive in device context"
);
if
(
batch_norm_p
==
nullptr
)
{
if
(
is_test
)
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
*
variance_memory
,
*
scaleshift_memory
,
*
dst_memory
);
}
else
{
batch_norm_p
=
std
::
make_shared
<
batch_norm_fwd
>
(
*
batch_norm_pd_
,
*
src_memory
,
*
scaleshift_memory
,
*
dst_memory
,
*
mean_memory
,
*
variance_memory
);
}
dev_ctx_
.
SetBlob
(
prim_key
,
batch_norm_p
);
}
else
{
is_reusing_
=
true
;
}
return
batch_norm_p
;
}
static
std
::
string
GetHash
(
const
memory
::
dims
&
input_dims
,
float
epsilon
,
unsigned
flag
,
bool
is_test
,
memory
::
format
format
,
const
std
::
string
&
suffix
=
""
)
{
auto
dims2str
=
[](
const
memory
::
dims
&
operand_dims
)
{
std
::
string
dstr
=
""
;
for
(
size_t
i
=
0
;
i
<
operand_dims
.
size
();
++
i
)
{
dstr
+=
std
::
to_string
(
operand_dims
[
i
])
+
"-"
;
}
return
dstr
;
};
return
dims2str
(
input_dims
)
+
std
::
to_string
(
epsilon
)
+
std
::
to_string
(
flag
)
+
std
::
to_string
(
is_test
)
+
std
::
to_string
(
format
)
+
suffix
;
}
private:
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_pd_
;
};
std
::
shared_ptr
<
memory
>
UpdateMemoryData
(
const
platform
::
MKLDNNDeviceContext
&
dev_ctx
,
const
std
::
string
&
key
,
void
*
new_ptr
)
{
auto
mem
=
std
::
static_pointer_cast
<
memory
>
(
dev_ctx
.
GetBlob
(
key
));
PADDLE_ENFORCE
(
mem
!=
nullptr
,
(
std
::
string
(
"Fail to find memory in device context [key: "
)
+
key
+
"]"
)
.
c_str
());
mem
->
set_data_handle
(
new_ptr
);
return
mem
;
}
template
<
typename
T
,
typename
Container
>
void
copy_to_weights
(
T
scale_begin
,
T
scale_end
,
T
shift_begin
,
T
shift_end
,
Container
*
c
)
{
...
...
@@ -48,15 +137,6 @@ void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
std
::
inserter
(
*
c
,
std
::
next
(
it
,
std
::
distance
(
scale_begin
,
scale_end
))));
}
template
<
typename
Op
,
typename
...
Args
>
void
run_batch_norm_op
(
Args
&&
...
args
)
{
Op
batch_norm_op
{
args
...};
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
pipeline
.
push_back
(
batch_norm_op
);
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
}
}
// namespace
template
<
typename
T
>
...
...
@@ -110,6 +190,14 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE
(
scale_tz
.
size
()
==
1
,
"Dims of scale tensor is NOT 1"
);
const
unsigned
int
ic
=
scale_tz
[
0
];
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
unsigned
flags
=
mkldnn
::
use_scale_shift
;
if
(
is_test
)
flags
|=
mkldnn
::
use_global_stats
;
if
(
fuse_with_relu
)
flags
|=
mkldnn
::
fuse_bn_relu
;
...
...
@@ -118,64 +206,69 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
auto
src_memory
=
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
to_void_cast
(
x_data
));
// keys for backward pass
const
std
::
string
key
=
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
is_test
,
input_format
,
ctx
.
op
().
Output
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
auto
user_src_md
=
platform
::
MKLDNNMemDesc
(
{
src_tz
},
platform
::
MKLDNNGetDataType
<
T
>
(),
input_format
);
// create primitive descriptor for batch norm forward
using
bn_fwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_forward
>
;
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
propagation
,
src_memory
.
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
batch_norm_fwd_pd
=
std
::
shared_ptr
<
batch_norm_fwd
::
primitive_desc
>
(
new
batch_norm_fwd
::
primitive_desc
(
batch_norm_fwd_desc
,
mkldnn_engine
));
// Save the pd to be used in backward pass
const
std
::
string
key
=
ctx
.
op
().
Output
(
"SavedMean"
);
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
auto
batch_norm_fwd_desc
=
bn_fwd_types
::
op_desc
{
propagation
,
user_src_md
,
epsilon
,
flags
};
auto
batch_norm_fwd_pd
=
std
::
make_shared
<
batch_norm_fwd
::
primitive_desc
>
(
batch_norm_fwd_desc
,
mkldnn_engine
);
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx
.
SetBlob
(
key_batch_norm_fwd_pd
,
batch_norm_fwd_pd
);
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
std
::
vector
<
T
>
scaleshift_data
;
scaleshift_data
.
reserve
(
scaleshift_size
);
BatchNormMKLDNNHandler
handler
(
batch_norm_fwd_pd
,
dev_ctx
,
mkldnn_engine
,
key
);
copy_to_weights
(
scale
->
data
<
T
>
(),
scale
->
data
<
T
>
()
+
ic
,
shift
->
data
<
T
>
(),
shift
->
data
<
T
>
()
+
ic
,
&
scaleshift_data
);
auto
src_memory
=
handler
.
AcquireSrcMemory
(
user_src_md
,
to_void_cast
(
x_data
)
);
// crate mkldnn memory for weights(scale/shift)
auto
scaleshift_memory
=
memory
(
batch_norm_fwd_pd
->
weights_primitive_desc
(),
scaleshift_data
.
data
());
auto
scaleshift_memory
=
handler
.
AcquireScaleshiftMemoryFromPrimitive
(
scaleshift_data
.
data
());
// create mkldnn memory for output y tensor
auto
dst_memory
=
memory
(
batch_norm_fwd_pd
->
dst_primitive_desc
(),
y_data
);
auto
dst_memory
=
handler
.
AcquireDstMemory
(
batch_norm_fwd_pd
->
dst_primitive_desc
().
desc
(),
y_data
);
std
::
shared_ptr
<
batch_norm_fwd
>
batch_norm_p
;
if
(
is_test
)
{
// create mkldnn memory for stats (as input)
auto
mean_memory
=
memory
(
batch_norm_fwd_pd
->
mean_primitive_desc
(),
to_void_cast
(
mean_data
));
auto
variance_memory
=
memory
(
batch_norm_fwd_pd
->
variance_primitive_desc
(),
to_void_cast
(
variance_data
));
run_batch_norm_op
<
typename
bn_fwd_types
::
op_type
>
(
*
batch_norm_fwd_pd
,
src_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
mean_memory
,
(
const
mkldnn
::
primitive
::
at
&
)
variance_memory
,
scaleshift_memory
,
dst_memory
);
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemoryFromPrimitive
(
to_void_cast
(
mean_data
));
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemoryFromPrimitive
(
to_void_cast
(
variance_data
));
batch_norm_p
=
handler
.
AcquireTestTrainingBatchNormFwd
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean_memory
,
variance_memory
,
true
);
}
else
{
// create mkldnn memory for stats (as output)
auto
mean_memory
=
memory
(
batch_norm_fwd_pd
->
mean_primitive_desc
(),
batch_mean_data
);
auto
variance_memory
=
memory
(
batch_norm_fwd_pd
->
variance_primitive_desc
(),
batch_variance_data
);
run_batch_norm_op
<
bn_fwd_types
::
op_type
>
(
*
batch_norm_fwd_pd
,
src_memory
,
scaleshift_memory
,
dst
_memory
,
mean_memory
,
variance_memory
);
std
::
shared_ptr
<
memory
>
mean_memory
=
handler
.
AcquireMeanMemoryFromPrimitive
(
batch_mean_data
);
std
::
shared_ptr
<
memory
>
variance_memory
=
handler
.
AcquireVarianceMemoryFromPrimitive
(
batch_variance_data
);
batch_norm_p
=
handler
.
AcquireTestTrainingBatchNormFwd
(
src_memory
,
scaleshift_memory
,
dst_memory
,
mean
_memory
,
variance_memory
,
false
);
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
platform
::
GetMKLDNNFormat
(
*
dst_memory
));
std
::
vector
<
mkldnn
::
primitive
>
pipeline
;
pipeline
.
push_back
(
*
batch_norm_p
);
mkldnn
::
stream
(
mkldnn
::
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
if
(
!
is_test
)
{
// mkldnn only compute stats for current batch
// so we need compute momentum stats via Eigen lib
...
...
@@ -192,10 +285,6 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
running_variance_e
=
variance_e
*
momentum
+
batch_variance_e
*
one_minus_momentum
;
}
y
->
set_layout
(
DataLayout
::
kMKLDNN
);
y
->
set_format
(
(
memory
::
format
)
dst_memory
.
get_primitive_desc
().
desc
().
data
.
format
);
}
};
...
...
@@ -242,61 +331,48 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const
unsigned
int
ic
=
scale_tz
[
0
];
// Retrieve bn_fwd_pd from device context
const
std
::
string
key
=
ctx
.
op
().
Input
(
"SavedMean"
);
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
auto
batch_norm_fwd_pd
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_batch_norm_fwd_pd
));
PADDLE_ENFORCE
(
batch_norm_fwd_pd
!=
nullptr
,
"Fail to find batch_norm_fwd_pd in device context"
);
using
bn_bwd_types
=
bn_type_traits
<
mkldnn
::
batch_normalization_backward
>
;
// create mkldnn memory from input diff_y tensor
mkldnn
::
memory
::
format
dst_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
diff_y
->
format
());
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
// create mkldnn memory from input x tensor
mkldnn
::
memory
::
format
input_format
=
platform
::
MKLDNNFormatForSize
(
src_tz
.
size
(),
x
->
format
());
auto
src_memory
=
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
to_void_cast
(
x_data
));
unsigned
flags
=
mkldnn
::
use_scale_shift
;
// for diff_dst, try to use same format as dst in forward pass
auto
diff_dst_pd
=
batch_norm_fwd_pd
.
get
()
->
dst_primitive_desc
();
auto
diff_dst_md
=
diff_dst_pd
.
desc
();
// keys from forward pass
const
std
::
string
key
=
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
,
ctx
.
op
().
Input
(
"SavedMean"
));
const
std
::
string
key_batch_norm_fwd_pd
=
key
+
"@bn_fwd_pd"
;
// keys for primitives reuse
const
std
::
string
key_with_hash
=
key
+
BatchNormMKLDNNHandler
::
GetHash
(
src_tz
,
epsilon
,
flags
,
false
,
input_format
);
const
std
::
string
key_batch_norm_bwd_p
=
key_with_hash
+
"@batch_norm_bwd_p"
;
const
std
::
string
key_batch_norm_src_mem_p
=
key_with_hash
+
"@batch_norm_bwd_src_mem_p"
;
const
std
::
string
key_batch_norm_mean_mem_p
=
key_with_hash
+
"@batch_norm_bwd_mean_mem_p"
;
const
std
::
string
key_batch_norm_variance_mem_p
=
key_with_hash
+
"@batch_norm_bwd_variance_mem_p"
;
const
std
::
string
key_batch_norm_scaleshift_mem_p
=
key_with_hash
+
"@batch_norm_bwd_scaleshift_mem_p"
;
const
std
::
string
key_batch_norm_diff_scaleshift_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_scaleshift_mem_p"
;
const
std
::
string
key_batch_norm_diff_src_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_src_mem_p"
;
const
std
::
string
key_batch_norm_diff_dst_mem_p
=
key_with_hash
+
"@batch_norm_bwd_diff_dst_mem_p"
;
// create primitive descriptor for batch norm backward
unsigned
flags
=
mkldnn
::
use_scale_shift
;
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_memory
.
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
batch_norm_bwd_desc
,
mkldnn_engine
,
*
batch_norm_fwd_pd
};
// reorder user_diff_dst if it's not in preferred format
auto
diff_dst_memory
=
user_diff_dst_memory
;
primitive
reorder_diff_dst
;
bool
is_diff_dst_reordered
=
false
;
if
(
diff_dst_pd
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
memory
(
diff_dst_pd
);
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto
mean_memory
=
memory
(
batch_norm_bwd_pd
.
mean_primitive_desc
(),
to_void_cast
(
batch_mean_data
));
auto
variance_memory
=
memory
(
batch_norm_bwd_pd
.
variance_primitive_desc
(),
to_void_cast
(
batch_variance_data
));
auto
user_diff_dst_memory
=
memory
(
{{{
diff_dst_tz
},
memory
::
data_type
::
f32
,
dst_format
},
mkldnn_engine
},
to_void_cast
(
diff_y_data
));
// MKLDNN requires a single piece of memory for scale and shift/bias data
const
size_t
scaleshift_size
=
2
*
ic
;
...
...
@@ -306,30 +382,118 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
copy_to_weights
(
scale_data
,
scale_data
+
ic
,
shift_data
,
shift_data
+
ic
,
&
scaleshift_data
);
// create mkldnn memory for input tensors (scale/shift)
auto
scaleshift_memory
=
memory
(
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
());
// create mkldnn memory for output diff weights (combined scale/shift)
std
::
vector
<
T
>
diff_scaleshift_data
;
diff_scaleshift_data
.
reserve
(
scaleshift_size
);
auto
diff_scaleshift_memory
=
memory
(
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
diff_scaleshift_data
.
data
());
// here assume diff_src is in the same format of src
auto
diff_src_memory
=
memory
(
src_memory
.
get_primitive_desc
(),
diff_x_data
);
auto
batch_norm_fwd_pd
=
std
::
static_pointer_cast
<
batch_norm_fwd
::
primitive_desc
>
(
dev_ctx
.
GetBlob
(
key_batch_norm_fwd_pd
));
PADDLE_ENFORCE
(
batch_norm_fwd_pd
!=
nullptr
,
"Fail to find batch_norm_fwd_pd in device context"
);
// finally create batch_norm backward primitive
auto
batch_norm_bwd_prim
=
batch_norm_bwd
(
batch_norm_bwd_pd
,
src_memory
,
mean_memory
,
variance_memory
,
diff_dst_memory
,
scaleshift_memory
,
diff_src_memory
,
diff_scaleshift_memory
);
auto
batch_norm_bwd_p
=
std
::
static_pointer_cast
<
batch_norm_bwd
>
(
dev_ctx
.
GetBlob
(
key_batch_norm_bwd_p
));
if
(
batch_norm_bwd_p
==
nullptr
)
{
auto
src_memory
=
std
::
shared_ptr
<
memory
>
(
new
memory
(
{{{
src_tz
},
memory
::
data_type
::
f32
,
input_format
},
mkldnn_engine
},
to_void_cast
(
x_data
)));
// for diff_dst, try to use same format as dst in forward pass
auto
diff_dst_pd
=
batch_norm_fwd_pd
.
get
()
->
dst_primitive_desc
();
auto
diff_dst_md
=
diff_dst_pd
.
desc
();
// create primitive descriptor for batch norm backward
auto
batch_norm_bwd_desc
=
bn_bwd_types
::
op_desc
{
mkldnn
::
prop_kind
::
backward
,
diff_dst_md
,
src_memory
->
get_primitive_desc
().
desc
(),
epsilon
,
flags
};
auto
batch_norm_bwd_pd
=
bn_bwd_types
::
op_prim
{
batch_norm_bwd_desc
,
mkldnn_engine
,
*
batch_norm_fwd_pd
};
// reorder user_diff_dst if it's not in preferred format
auto
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
user_diff_dst_memory
);
if
(
diff_dst_pd
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
diff_dst_memory
=
std
::
make_shared
<
memory
>
(
diff_dst_pd
);
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// create mkldnn memory for input tensors (src/mean/variance)
auto
mean_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
mean_primitive_desc
(),
to_void_cast
(
batch_mean_data
));
auto
variance_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
variance_primitive_desc
(),
to_void_cast
(
batch_variance_data
));
// create mkldnn memory for input tensors (scale/shift)
auto
scaleshift_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
weights_primitive_desc
(),
scaleshift_data
.
data
());
// create mkldnn memory for output diff weights (combined scale/shift)
auto
diff_scaleshift_memory
=
std
::
make_shared
<
memory
>
(
batch_norm_bwd_pd
.
diff_weights_primitive_desc
(),
diff_scaleshift_data
.
data
());
// here assume diff_src is in the same format of src
auto
diff_src_memory
=
std
::
make_shared
<
memory
>
(
src_memory
->
get_primitive_desc
(),
diff_x_data
);
// finally create batch_norm backward primitive
batch_norm_bwd_p
=
std
::
make_shared
<
batch_norm_bwd
>
(
batch_norm_bwd_pd
,
*
src_memory
,
*
mean_memory
,
*
variance_memory
,
*
diff_dst_memory
,
*
scaleshift_memory
,
*
diff_src_memory
,
*
diff_scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_bwd_p
,
batch_norm_bwd_p
);
dev_ctx
.
SetBlob
(
key_batch_norm_src_mem_p
,
src_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_mean_mem_p
,
mean_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_variance_mem_p
,
variance_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_scaleshift_mem_p
,
scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_scaleshift_mem_p
,
diff_scaleshift_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_src_mem_p
,
diff_src_memory
);
dev_ctx
.
SetBlob
(
key_batch_norm_diff_dst_mem_p
,
diff_dst_memory
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
else
{
// primitives already exist
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_src_mem_p
,
to_void_cast
(
x_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_mean_mem_p
,
to_void_cast
(
batch_mean_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_variance_mem_p
,
to_void_cast
(
batch_variance_data
));
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_scaleshift_mem_p
,
scaleshift_data
.
data
());
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_scaleshift_mem_p
,
diff_scaleshift_data
.
data
());
auto
diff_src_memory
=
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_src_mem_p
,
to_void_cast
(
diff_x_data
));
auto
diff_dst_memory
=
UpdateMemoryData
(
dev_ctx
,
key_batch_norm_diff_dst_mem_p
,
to_void_cast
(
diff_y_data
));
// reorder user_diff_dst if it's not in preferred format
if
(
diff_dst_memory
->
get_primitive_desc
()
!=
user_diff_dst_memory
.
get_primitive_desc
())
{
reorder_diff_dst
=
reorder
(
user_diff_dst_memory
,
*
diff_dst_memory
);
is_diff_dst_reordered
=
true
;
}
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
->
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
// execute optional reorder and batch_norm backward primitive
std
::
vector
<
primitive
>
pipeline
;
if
(
is_diff_dst_reordered
)
pipeline
.
push_back
(
reorder_diff_dst
);
pipeline
.
push_back
(
batch_norm_bwd_prim
);
pipeline
.
push_back
(
*
batch_norm_bwd_p
);
stream
(
stream
::
kind
::
eager
).
submit
(
pipeline
).
wait
();
// copy back diff sacle/shift to output tensors (diff scale/shift)
...
...
@@ -338,12 +502,6 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std
::
copy
(
it
,
std
::
next
(
it
,
ic
),
diff_scale_data
);
std
::
copy
(
std
::
next
(
it
,
ic
),
std
::
end
(
diff_scaleshift_data
),
diff_shift_data
);
// set layout/format of output tensors
diff_x
->
set_layout
(
DataLayout
::
kMKLDNN
);
diff_x
->
set_format
((
memory
::
format
)
diff_src_memory
.
get_primitive_desc
()
.
desc
()
.
data
.
format
);
}
};
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录