提交 b8a04c2f 编写于 作者: M mozga-intel

Duplicated code was moved to common function

上级 3b128337
...@@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu; if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
mkldnn::memory::format input_format = x->format(); mkldnn::memory::format input_format =
if (src_tz.size() == 1) { platform::MKLDNNFormatForSize(src_tz.size(), x->format());
input_format = mkldnn::memory::format::x;
} else if (src_tz.size() == 2) {
input_format = mkldnn::memory::format::nc;
}
auto src_memory = memory( auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
...@@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// create mkldnn memory from input diff_y tensor // create mkldnn memory from input diff_y tensor
mkldnn::memory::format dst_format = x->format(); mkldnn::memory::format dst_format =
if (diff_dst_tz.size() == 1) { platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
dst_format = mkldnn::memory::format::x;
} else if (diff_dst_tz.size() == 2) {
dst_format = mkldnn::memory::format::nc;
}
auto user_diff_dst_memory = memory( auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine}, {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data)); to_void_cast(diff_y_data));
// create mkldnn memory from input x tensor // create mkldnn memory from input x tensor
mkldnn::memory::format input_format = x->format(); mkldnn::memory::format input_format =
if (src_tz.size() == 1) { platform::MKLDNNFormatForSize(src_tz.size(), x->format());
input_format = mkldnn::memory::format::x;
} else if (src_tz.size() == 2) {
input_format = mkldnn::memory::format::nc;
}
auto src_memory = memory( auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine}, {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
......
...@@ -228,7 +228,7 @@ class MKLDNNHandler { ...@@ -228,7 +228,7 @@ class MKLDNNHandler {
return dstr; return dstr;
}; };
return dims2str(operand_dims) + suffix; return dims2str(operand_dims) + suffix;
}; }
protected: protected:
const MKLDNNDeviceContext& dev_ctx_; const MKLDNNDeviceContext& dev_ctx_;
...@@ -237,5 +237,15 @@ class MKLDNNHandler { ...@@ -237,5 +237,15 @@ class MKLDNNHandler {
bool is_reusing_; bool is_reusing_;
}; };
inline mkldnn::memory::format MKLDNNFormatForSize(
size_t dims_size, mkldnn::memory::format data_format) {
if (dims_size == 1) {
return mkldnn::memory::format::x;
} else if (dims_size == 2) {
return mkldnn::memory::format::nc;
}
return data_format;
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册