未验证 提交 0eaab803 编写于 作者: J jakpiase 提交者: GitHub

Disabled oneDNN reshape1/2 and squeeze1/2 kernels (#35781)

* disabled matmul_v2 grad

* Revert "disabled matmul_v2 grad"

This reverts commit b569bcef162116ca9f7963f3975b4a412f9e8555.

* reverted disabling matmul_v2, disabled reshape and squeeze
上级 2c781455
...@@ -249,11 +249,11 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -249,11 +249,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -367,11 +367,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -367,11 +367,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -558,11 +558,11 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -558,11 +558,11 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -114,11 +114,11 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -114,11 +114,11 @@ class SqueezeOp : public framework::OperatorWithKernel {
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -141,11 +141,11 @@ class SqueezeGradOp : public framework::OperatorWithKernel { ...@@ -141,11 +141,11 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -242,11 +242,11 @@ class Squeeze2Op : public framework::OperatorWithKernel { ...@@ -242,11 +242,11 @@ class Squeeze2Op : public framework::OperatorWithKernel {
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
...@@ -288,11 +288,11 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { ...@@ -288,11 +288,11 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN, // framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN); // framework::LibraryType::kMKLDNN);
} // }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册