提交 b8a04c2f 编写于 作者: M mozga-intel

Duplicated code was moved to common function

上级 3b128337
......@@ -115,12 +115,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;
// create mkldnn memory from input x tensor
mkldnn::memory::format input_format = x->format();
if (src_tz.size() == 1) {
input_format = mkldnn::memory::format::x;
} else if (src_tz.size() == 2) {
input_format = mkldnn::memory::format::nc;
}
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
......@@ -259,23 +255,16 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// create mkldnn memory from input diff_y tensor
mkldnn::memory::format dst_format = x->format();
if (diff_dst_tz.size() == 1) {
dst_format = mkldnn::memory::format::x;
} else if (diff_dst_tz.size() == 2) {
dst_format = mkldnn::memory::format::nc;
}
mkldnn::memory::format dst_format =
platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
auto user_diff_dst_memory = memory(
{{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
to_void_cast(diff_y_data));
// create mkldnn memory from input x tensor
mkldnn::memory::format input_format = x->format();
if (src_tz.size() == 1) {
input_format = mkldnn::memory::format::x;
} else if (src_tz.size() == 2) {
input_format = mkldnn::memory::format::nc;
}
mkldnn::memory::format input_format =
platform::MKLDNNFormatForSize(src_tz.size(), x->format());
auto src_memory = memory(
{{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
......
......@@ -228,7 +228,7 @@ class MKLDNNHandler {
return dstr;
};
return dims2str(operand_dims) + suffix;
};
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
......@@ -237,5 +237,15 @@ class MKLDNNHandler {
bool is_reusing_;
};
inline mkldnn::memory::format MKLDNNFormatForSize(
size_t dims_size, mkldnn::memory::format data_format) {
if (dims_size == 1) {
return mkldnn::memory::format::x;
} else if (dims_size == 2) {
return mkldnn::memory::format::nc;
}
return data_format;
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册