未验证 提交 0b2c2f82 编写于 作者: H HongyuJia 提交者: GitHub

clean mkldnn hardcode (#49130)

上级 a6dcaf64
......@@ -85,16 +85,6 @@ class FlattenOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -167,16 +157,6 @@ class FlattenGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -241,16 +221,6 @@ class Flatten2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -303,16 +273,6 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -255,16 +255,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......@@ -621,16 +611,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -124,15 +124,6 @@ class SqueezeOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -152,15 +143,6 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -220,15 +202,6 @@ class Squeeze2Op : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......@@ -269,15 +242,6 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -34,27 +34,7 @@ static const std::unordered_set<std::string> mkldnn_white_list = {
"pad3d",
"slice",
"slice_grad",
"split",
// NOTE(jiahongyu): squeeze MKLDNN kernel are disabled
// (https://github.com/PaddlePaddle/Paddle/pull/35781). If these MKLDNN
// kernels and codes are deleted in the future, attributes `use_mkldnn`
// should be removed from function declaration
"squeeze",
"squeeze_grad",
"squeeze2",
"squeeze2_grad",
// NOTE(jiahongyu): reshape and flatten have attribute use_mkldnn and they
// are registered in paddle, but they didn't change the ExpectedKernelType
// of tensor. Actually, mkldnn kernel of squeeze, reshape, and flatten
// should never be called.
"reshape",
"reshape_grad",
"reshape2",
"reshape2_grad",
"flatten",
"flatten_grad",
"flatten2",
"flatten2_grad"};
"split"};
inline bool in_mkldnn_white_list(const std::string& op_name) {
return mkldnn_white_list.find(op_name) != mkldnn_white_list.end();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册