提交 e818fa10 编写于 作者: X xiaolil1 提交者: Tao Luo

Enable INT8 transpose kernel for MobileNet-SSD improvement. (#16159)

* Enable INT8 transpose kernel for MobileNet-SSD improvement.
test=develop

* Refine the license year.
test=develop

* Delete redundant code.
test=develop

* Add axis check.
test=develop
上级 374abcf3
...@@ -73,6 +73,29 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -73,6 +73,29 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
}; };
template <typename T>
class TransposeINT8MKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
std::vector<int> axis_int8 = {0, 2, 3, 1};
if (axis.size() != 1) {
PADDLE_ENFORCE_EQ(axis.size(), axis_int8.size());
for (size_t i = 0; i < axis.size(); i++) {
PADDLE_ENFORCE_EQ(axis[i], axis_int8[i],
"Current INT8 MKLDNN Transpose kernel only surpport "
"axis with [0, 2, 3, 1] due to MKL-DNN kernel "
"implementation.");
}
}
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
output->ShareDataWith(*input);
output->set_layout(DataLayout::kMKLDNN);
output->set_format(input->format());
}
};
template <typename T> template <typename T>
class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -140,7 +163,10 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -140,7 +163,10 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>); ops::TransposeMKLDNNOpKernel<float>,
ops::TransposeINT8MKLDNNOpKernel<uint8_t>,
ops::TransposeINT8MKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>); ops::TransposeMKLDNNOpKernel<float>);
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from mkldnn_op_test import format_reorder
class TestTransposeOp(OpTest):
def setUp(self):
self.init_op_type()
self.initTestCase()
self.initInputData()
self.use_mkldnn = True
self.axis = (0, 2, 3, 1)
self.inputs = {
'X': format_reorder(self.input_data, self.shape)
} #transform data format to 'NHWC' for INT8 transpose specially.
self.attrs = {
'axis': list(self.axis),
'use_mkldnn': self.use_mkldnn,
}
self.outputs = {
'XShape': np.random.random(self.shape).astype('int8'),
'Out': self.inputs['X'].transpose(self.axis)
}
def init_op_type(self):
self.op_type = "transpose2"
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
def initTestCase(self):
self.shape = (2, 3, 4, 5)
def initInputData(self):
self.input_data = (
np.random.randint(0, 100, self.shape) - 50).astype('int8')
class TestINT8Case(TestTransposeOp):
def initTestCase(self):
self.shape = (2, 4, 6, 8)
def initInputData(self):
self.input_data = (
np.random.randint(0, 100, self.shape) - 50).astype('int8')
class TestUINT8Case(TestTransposeOp):
def initTestCase(self):
self.shape = (1, 3, 5, 7)
def initDataType(self):
self.input_data = (np.random.randint(0, 100,
self.shape)).astype('uint8')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册