Unverified commit b0378963, authored by Tao Luo, committed by GitHub

Merge pull request #11666 from mozga-intel/mozga-intel/Batch_norm_support_other_type

The MKLDNN batch norm supports other data formats
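In short: the format-for-rank helper moves from the framework header into platform/mkldnn_helper.h (as platform::MKLDNNFormatForSize) and is now applied in the MKLDNN batch-norm forward and backward kernels, so 1-D and 2-D inputs get a memory format consistent with their rank rather than whatever the tensor happens to report. Below is a minimal standalone sketch of the mapping rule the helper implements; the Format enum is a hypothetical stand-in for mkldnn::memory::format, and the real helper appears verbatim in the final hunk.

#include <cassert>
#include <cstddef>

// Stand-in for mkldnn::memory::format; only the members needed here.
enum class Format { x, nc, nchw };

// Same rule as platform::MKLDNNFormatForSize in this PR:
// rank 1 -> x, rank 2 -> nc, anything else keeps the caller's format.
Format FormatForSize(std::size_t dims_size, Format data_format) {
  if (dims_size == 1) return Format::x;
  if (dims_size == 2) return Format::nc;
  return data_format;
}

int main() {
  assert(FormatForSize(1, Format::nchw) == Format::x);     // 1-D tensor
  assert(FormatForSize(2, Format::nchw) == Format::nc);    // 2-D tensor
  assert(FormatForSize(4, Format::nchw) == Format::nchw);  // 4-D passthrough
  return 0;
}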
@@ -147,9 +147,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
                  "Input tensor type is not supported: ", in.type().name());

   memory::data_type out_type = in_type;
-  auto in_format = MKLDNNFormatForSize(in_tz.size(), in.format());
-  auto out_format =
-      MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
+  auto in_format = platform::MKLDNNFormatForSize(in_tz.size(), in.format());
+  auto out_format =
+      platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));

   void* in_data = GetDataFromTensor(in, in_type);
...
@@ -62,12 +62,6 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
   return MKLDNNDataType::data_undef;
 }

-inline MKLDNNFormat MKLDNNFormatForSize(size_t dims_size,
-                                        MKLDNNFormat default_format) {
-  return (dims_size == 1
-              ? mkldnn::memory::format::x
-              : dims_size == 2 ? mkldnn::memory::format::nc : default_format);
-}
 #endif

 void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
...
@@ -18,6 +18,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/data_type_transform.h"

+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 namespace paddle {
 namespace framework {
...
@@ -48,8 +52,8 @@ void TransformData(const OpKernelType &expected_kernel_type,
         // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
         // Just set layout/format. No real transform occur
-        auto out_format =
-            MKLDNNFormatForSize(in.dims().size(), ToMKLDNNFormat(lin));
+        auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
+                                                        ToMKLDNNFormat(lin));

         out.ShareDataWith(input_tensor);
         out.set_layout(DataLayout::kMKLDNN);
...
@@ -115,9 +115,12 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;

     // create mkldnn memory from input x tensor
-    auto src_memory =
-        memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
-               to_void_cast(x_data));
+    mkldnn::memory::format input_format =
+        platform::MKLDNNFormatForSize(src_tz.size(), x->format());
+
+    auto src_memory = memory(
+        {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
+        to_void_cast(x_data));

     // create primitive descriptor for batch norm forward
     using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
...
@@ -251,15 +254,21 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;

     // create mkldnn memory from input diff_y tensor
-    auto user_diff_dst_memory =
-        memory({{{diff_dst_tz}, memory::data_type::f32, diff_y->format()},
-                mkldnn_engine},
-               to_void_cast(diff_y_data));
+    mkldnn::memory::format dst_format =
+        platform::MKLDNNFormatForSize(src_tz.size(), diff_y->format());
+
+    auto user_diff_dst_memory = memory(
+        {{{diff_dst_tz}, memory::data_type::f32, dst_format}, mkldnn_engine},
+        to_void_cast(diff_y_data));

     // create mkldnn memory from input x tensor
-    auto src_memory =
-        memory({{{src_tz}, memory::data_type::f32, x->format()}, mkldnn_engine},
-               to_void_cast(x_data));
+    mkldnn::memory::format input_format =
+        platform::MKLDNNFormatForSize(src_tz.size(), x->format());
+
+    auto src_memory = memory(
+        {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
+        to_void_cast(x_data));

     // for diff_dst, try to use same format as dst in forward pass
     auto diff_dst_pd = batch_norm_fwd_pd.get()->dst_primitive_desc();
...
@@ -228,7 +228,7 @@ class MKLDNNHandler {
       return dstr;
     };

     return dims2str(operand_dims) + suffix;
-  };
+  }

  protected:
   const MKLDNNDeviceContext& dev_ctx_;
@@ -237,5 +237,15 @@ class MKLDNNHandler {
   bool is_reusing_;
 };

+inline mkldnn::memory::format MKLDNNFormatForSize(
+    size_t dims_size, mkldnn::memory::format data_format) {
+  if (dims_size == 1) {
+    return mkldnn::memory::format::x;
+  } else if (dims_size == 2) {
+    return mkldnn::memory::format::nc;
+  }
+  return data_format;
+}
+
 }  // namespace platform
 }  // namespace paddle
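Taken together, the call-site pattern repeated in both batch-norm kernels is: normalize the tensor's reported format to its rank first, then describe the MKLDNN memory with the normalized format. The snippet below is condensed from the forward-kernel hunk above; the locals (src_tz, x, x_data, mkldnn_engine) belong to that kernel and are assumed to be in scope.

// Normalize the reported format for the tensor's rank ...
mkldnn::memory::format input_format =
    platform::MKLDNNFormatForSize(src_tz.size(), x->format());

// ... then build the memory primitive with the normalized format.
auto src_memory = memory(
    {{{src_tz}, memory::data_type::f32, input_format}, mkldnn_engine},
    to_void_cast(x_data));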