diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index fffbdd90e74c8d9598921632798a963d60101174..45bce6e5203f8c1dbb744e0f954f7f0a71c53372 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -15,6 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/data_norm_op.h" #include #include "paddle/fluid/framework/data_layout.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -94,6 +97,13 @@ class DataNormOp : public framework::OperatorWithKernel { // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, library); @@ -251,6 +261,14 @@ class DataNormGradOp : public framework::OperatorWithKernel { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), layout, library); }