diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc index e41bfb80dfc0452955f7978f74ccfea184886b69..4debc7ca5ec90d6cc781d10e817e9ed8650f12aa 100644 --- a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc @@ -73,6 +73,29 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { } }; +template +class TransposeINT8MKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + std::vector axis = ctx.Attr>("axis"); + std::vector axis_int8 = {0, 2, 3, 1}; + if (axis.size() != 1) { + PADDLE_ENFORCE_EQ(axis.size(), axis_int8.size()); + for (size_t i = 0; i < axis.size(); i++) { + PADDLE_ENFORCE_EQ(axis[i], axis_int8[i], + "Current INT8 MKLDNN Transpose kernel only surpport " + "axis with [0, 2, 3, 1] due to MKL-DNN kernel " + "implementation."); + } + } + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + output->ShareDataWith(*input); + output->set_layout(DataLayout::kMKLDNN); + output->set_format(input->format()); + } +}; + template class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { public: @@ -140,7 +163,10 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace, - ops::TransposeMKLDNNOpKernel); + ops::TransposeMKLDNNOpKernel, + ops::TransposeINT8MKLDNNOpKernel, + ops::TransposeINT8MKLDNNOpKernel); + REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNOpKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_int8_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a8127bcc781378fa5ef4a189a0b14d079a793946 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_transpose_int8_mkldnn_op.py @@ -0,0 +1,78 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest +from mkldnn_op_test import format_reorder + + +class TestTransposeOp(OpTest): + def setUp(self): + self.init_op_type() + self.initTestCase() + self.initInputData() + self.use_mkldnn = True + self.axis = (0, 2, 3, 1) + + self.inputs = { + 'X': format_reorder(self.input_data, self.shape) + } #transform data format to 'NHWC' for INT8 transpose specially. + + self.attrs = { + 'axis': list(self.axis), + 'use_mkldnn': self.use_mkldnn, + } + + self.outputs = { + 'XShape': np.random.random(self.shape).astype('int8'), + 'Out': self.inputs['X'].transpose(self.axis) + } + + def init_op_type(self): + self.op_type = "transpose2" + + def test_check_output(self): + self.check_output(no_check_set=['XShape']) + + def initTestCase(self): + self.shape = (2, 3, 4, 5) + + def initInputData(self): + self.input_data = ( + np.random.randint(0, 100, self.shape) - 50).astype('int8') + + +class TestINT8Case(TestTransposeOp): + def initTestCase(self): + self.shape = (2, 4, 6, 8) + + def initInputData(self): + self.input_data = ( + np.random.randint(0, 100, self.shape) - 50).astype('int8') + + +class TestUINT8Case(TestTransposeOp): + def initTestCase(self): + self.shape = (1, 3, 5, 7) + + def initDataType(self): + self.input_data = (np.random.randint(0, 100, + self.shape)).astype('uint8') + + +if __name__ == '__main__': + unittest.main()