提交 3b128337 编写于 作者: M mozga-intel

The mkldnn batch norm supports other data format

上级 ae0d0c41
...@@ -115,8 +115,15 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -115,8 +115,15 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
auto src_memory = mkldnn::memory::format input_format = x->format();
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine}, if (src_tz.size() == 1) {
input_format = mkldnn::memory::format::x;
} else if (src_tz.size() == 2) {
input_format = mkldnn::memory::format::nc;
}
auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
to_void_cast(x_data)); to_void_cast(x_data));
// create primitive descriptor for batch norm forward // create primitive descriptor for batch norm forward
...@@ -251,14 +258,27 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -251,14 +258,27 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>; using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
// create mkldnn memory from input diff_y tensor // create mkldnn memory from input diff_y tensor
auto user_diff_dst_memory =
memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()}, mkldnn::memory::format dst_format = x->format();
mkldnn_engine}, if (diff_dst_tz.size() == 1) {
dst_format = mkldnn::memory::format::x;
} else if (diff_dst_tz.size() == 2) {
dst_format = mkldnn::memory::format::nc;
}
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data)); to_void_cast(diff_y_data));
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
auto src_memory = mkldnn::memory::format input_format = x->format();
memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine}, if (src_tz.size() == 1) {
input_format = mkldnn::memory::format::x;
} else if (src_tz.size() == 2) {
input_format = mkldnn::memory::format::nc;
}
auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
to_void_cast(x_data)); to_void_cast(x_data));
// for diff_dst, try to use same format as dst in forward pass // for diff_dst, try to use same format as dst in forward pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册