提交 08c96d1b 编写于 作者: H heqiaozhi

remove mkldnn & fix commit

test=develop
上级 725b98f2
...@@ -15,6 +15,9 @@ limitations under the License. */ ...@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/data_norm_op.h" #include "paddle/fluid/operators/data_norm_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -94,6 +97,13 @@ class DataNormOp : public framework::OperatorWithKernel { ...@@ -94,6 +97,13 @@ class DataNormOp : public framework::OperatorWithKernel {
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::LibraryType library = framework::LibraryType::kPlain; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library); library);
...@@ -251,6 +261,14 @@ class DataNormGradOp : public framework::OperatorWithKernel { ...@@ -251,6 +261,14 @@ class DataNormGradOp : public framework::OperatorWithKernel {
framework::LibraryType library = framework::LibraryType::kPlain; framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout, library); ctx.GetPlace(), layout, library);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册