diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index 480167f43525bc19defab0faab31460e6c179eff..88b96dcde8a7e74fb3be01da3383a7711efe6b74 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -68,29 +68,6 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { } }; -template -class TransposeINT8MKLDNNOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - std::vector axis = ctx.Attr>("axis"); - std::vector axis_int8 = {0, 2, 3, 1}; - if (axis.size() != 1) { - PADDLE_ENFORCE_EQ(axis.size(), axis_int8.size()); - for (size_t i = 0; i < axis.size(); i++) { - PADDLE_ENFORCE_EQ(axis[i], axis_int8[i], - "Current INT8 MKLDNN Transpose kernel only surpport " - "axis with [0, 2, 3, 1] due to MKL-DNN kernel " - "implementation."); - } - } - auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); - output->ShareDataWith(*input); - output->set_layout(DataLayout::kMKLDNN); - output->set_format(input->format()); - } -}; - template class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { public: @@ -148,9 +125,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace, - ops::TransposeMKLDNNOpKernel, - ops::TransposeINT8MKLDNNOpKernel, - ops::TransposeINT8MKLDNNOpKernel); + ops::TransposeMKLDNNOpKernel); REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNOpKernel);