diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..44e8281424ba6937dad2c2dee1db4dee96b3b2eb --- /dev/null +++ b/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/data_layout_transform.h" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/requantize_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::reorder; +using platform::to_void_cast; +using Tensor = framework::Tensor; +using framework::DataLayout; +using mkldnn::stream; +using platform::GetMKLDNNFormat; + +template +class ReQuantOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto scale_in = ctx.Attr("Scale_in"); + auto scale_out = ctx.Attr("Scale_out"); + auto* output = ctx.Output("Output"); + auto& dev_ctx = + ctx.template device_context(); + const auto& engine = dev_ctx.GetEngine(); + + std::vector pipeline; + std::vector src_tz = paddle::framework::vectorize2int(input->dims()); + std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + mkldnn::memory::data_type src_dt = + paddle::framework::ToMKLDNNDataType(input->type()); + mkldnn::memory::data_type dst_dt = src_dt; // TODO(Xiaoli) support + // requantize from different + // data type (e.g., s8 to u8) + mkldnn::memory::format src_fmt = memory::format::nhwc; + mkldnn::memory::format dst_fmt = memory::format::nhwc; + + const T* input_data = input->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + float scale_shift = scale_out / scale_in; + + mkldnn::primitive_attr attri; + int mask = 0; + attri.set_output_scales(mask, {scale_shift}); + + auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); + auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); + auto src_memory = + std::make_shared(src_pd, to_void_cast(input_data)); + std::shared_ptr src_memory_p = + std::shared_ptr(new primitive::at(*src_memory)); + + auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt); + auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); + auto dst_memory = mkldnn::memory(dst_pd, to_void_cast(output_data)); + + auto reorder_pd = std::shared_ptr( + new reorder::primitive_desc(src_pd, dst_pd, attri)); + + auto reorder_p = std::shared_ptr( + new reorder(*reorder_pd, *src_memory_p, dst_memory)); + pipeline.push_back(*reorder_p); + stream(stream::kind::eager).submit(pipeline).wait(); + + output->set_layout(DataLayout::kMKLDNN); + output->set_format(GetMKLDNNFormat(dst_memory)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(requantize, MKLDNN, ::paddle::platform::CPUPlace, + ops::ReQuantOpKernel, ops::ReQuantOpKernel); diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..08ba1470aaddf146fe3685ff6c3cd9f3d7e16d75 --- /dev/null +++ b/paddle/fluid/operators/requantize_op.cc @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/requantize_op.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +namespace paddle { +namespace operators { + +framework::OpKernelType ReQuantOp::GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library_ = framework::LibraryType::kMKLDNN; + framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; + + return framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_); +} + +void ReQuantOpMaker::Make() { + AddInput("Input", "input data"); + AddOutput("Output", "output data"); + AddAttr("Scale_in", "scale in data").SetDefault({1.0f}); + AddAttr("Scale_out", "scale out data").SetDefault({1.0f}); + AddComment( + R"DOC(This op will re-quantize data from INT8 with scale_in to INT8 with scale_out)DOC"); +} + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, + paddle::framework::DefaultGradOpDescMaker); diff --git a/paddle/fluid/operators/requantize_op.h b/paddle/fluid/operators/requantize_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c2b154db11dc713fdce1b9ef2f2616428bc09202 --- /dev/null +++ b/paddle/fluid/operators/requantize_op.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class ReQuantOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + +class ReQuantOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py b/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py index 871f8403f812c87ac493b82482fe01fdf61037d4..57a5714fc7853905703e9db31bc143fb5cabfacb 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/mkldnn_op_test.py @@ -70,3 +70,17 @@ def check_if_mkldnn_primitives_exist_in_bwd(test_case, op_type, x, out, fetch_list=['x@GRAD', 'out']) __assert_close(x_grad, out[0], 'x@GRAD') + + +def format_reorder(out, size): + in_n = size[0] + out_h = size[2] + out_w = size[3] + out_c = size[1] + out_tmp = np.zeros((in_n, out_h, out_w, out_c)) + for n in range(in_n): + for i in range(out_h): + for j in range(out_w): + for m in range(out_c): + out_tmp[n, i, j, m] = out[n, m, i, j] + return out_tmp.reshape(in_n, out_c, out_h, out_w) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index 100a03cea0f740a615c4a08810d4ad9e8c974d7a..c7b8a096bf1a7e2f5b63b136c7036edad863c888 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -20,6 +20,7 @@ import numpy as np import paddle.fluid.core as core from paddle.fluid.tests.unittests.op_test import OpTest from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, TestConv2dOp +from mkldnn_op_test import format_reorder def conv2d_forward_refer(input, filter, group, conv_param): @@ -29,20 +30,6 @@ def conv2d_forward_refer(input, filter, group, conv_param): return format_reorder(out, size) -def format_reorder(out, size): - in_n = size[0] - out_h = size[2] - out_w = size[3] - out_c = size[1] - out_tmp = np.zeros((in_n, out_h, out_w, out_c)) - for n in range(in_n): - for i in range(out_h): - for j in range(out_w): - for m in range(out_c): - out_tmp[n, i, j, m] = out[n, m, i, j] - return out_tmp.reshape(in_n, out_c, out_h, out_w) - - class TestConv2dInt8Op(TestConv2dOp): def setUp(self): self.op_type = "conv2d" diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a4683558539d3f9daa6a1146355acc3ff2bab7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_requantize_mkldnn_op.py @@ -0,0 +1,93 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest +from mkldnn_op_test import format_reorder + + +class TestReQuantizeOp(OpTest): + def setUp(self): + self.op_type = 'requantize' + self.scale_in = 2.0 + self.scale_out = 1.5 + self.input_size = [1, 1, 5, 5] + self.data_type = 'int8' + self.set_scale() + self.set_data_type() + + scale_shift = self.scale_out / self.scale_in + + if self.data_type == 'int8': + input = (np.random.randint(0, 100, self.input_size) - 50 + ).astype(self.data_type) + output_tmp = np.round(input.astype('float32') * + scale_shift).astype('int8') + else: + input = (np.random.randint(0, 100, + self.input_size)).astype(self.data_type) + output_tmp = np.round(input.astype('float32') * + scale_shift).astype('uint8') + + output = format_reorder(output_tmp, self.input_size) + + self.inputs = {'Input': OpTest.np_dtype_to_fluid_dtype(input)} + + self.outputs = {'Output': output} + + self.attrs = {'Scale_in': self.scale_in, 'Scale_out': self.scale_out} + + def test_check_output(self): + self.check_output() + + def set_scale(self): + pass + + def set_data_type(OpTest): + pass + + +#--------------------test requantize with s8 input-------------------- + + +class TestReQuantizeOp1(TestReQuantizeOp): + def set_scale(self): + self.scale_in = 1.5 + self.scale_out = 1.5 + + +class TestReQuantizeOp2(TestReQuantizeOp): + def set_scale(self): + self.scale_in = 0.1 + self.scale_out = 0.2 + + +#--------------------test requantize with u8 input-------------------- + + +class TestReQuantizeOp3(TestReQuantizeOp1): + def set_data_type(self): + self.data_type = 'uint8' + + +class TestReQuantizeOp4(TestReQuantizeOp2): + def set_data_type(self): + self.data_type = 'uint8' + + +if __name__ == '__main__': + unittest.main()