未验证 提交 98e00793 编写于 作者: H HongyuJia 提交者: GitHub

[Opt transpose2] Opt GetExpectedKernelType code of transpose2 (#46692)

* solve transpose2, follow #22402

* fix CI cmake

* update REGISTER_OP_KERNEL of transpose2
上级 46595d6b
......@@ -518,13 +518,6 @@ function(op_library TARGET)
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, S8);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "transpose_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, FP32);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, S8);\n")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(transpose2, MKLDNN, U8);\n")
elseif(${MKLDNN_FILE} STREQUAL "fc_mkldnn_op")
file(APPEND ${pybind_file}
"USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(fc, MKLDNN, FP32);\n")
......
......@@ -172,35 +172,6 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
ops::kTransposeMKLDNNFP32,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(transpose2,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
ops::kTransposeMKLDNNINT8,
ops::TransposeMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
transpose2,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
ops::kTransposeMKLDNNFP32,
ops::TransposeMKLDNNOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(transpose,
MKLDNN,
::paddle::platform::CPUPlace,
......@@ -210,3 +181,11 @@ REGISTER_OP_KERNEL(transpose_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::TransposeMKLDNNGradOpKernel<float>);
REGISTER_OP_KERNEL(transpose2,
MKLDNN,
::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>,
ops::TransposeMKLDNNOpKernel<uint8_t>,
ops::TransposeMKLDNNOpKernel<int8_t>,
ops::TransposeMKLDNNOpKernel<paddle::platform::bfloat16>);
......@@ -249,18 +249,10 @@ class Transpose2Op : public TransposeOp {
OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, data_type)) {
using framework::proto::VarType;
auto input_data_type = framework::TransToProtoVarType(
ctx.Input<phi::DenseTensor>("X")->dtype());
int customized_type_value = (input_data_type == VarType::INT8 ||
input_data_type == VarType::UINT8)
? kTransposeMKLDNNINT8
: kTransposeMKLDNNFP32;
return framework::OpKernelType(data_type,
ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN,
customized_type_value);
framework::LibraryType::kMKLDNN);
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册