diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index eae65968285703f5882d910e29bc5d8e1511cba6..43aba94dd2b45c763dfa6681b13a3852dc4be325 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -278,6 +278,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
     auto* output = ctx.Output<Tensor>("Output");
 
+    bool is_INT8 = ctx.HasInput("Scale_in")? true : false;
+    auto* scale_in = ctx.HasInput("Scale_in") ? ctx.Input<Tensor>("Scale_in") : nullptr;
+    auto* scale_in_eltwise = ctx.HasInput("Scale_in_eltwise")? ctx.Input<Tensor>("Scale_in_eltwise") : nullptr;
+    auto* scale_weights = ctx.HasInput("Scale_weights")? ctx.Input<Tensor>("Scale_weights") : nullptr;
+    auto* scale_out = ctx.HasInput("Scale_out")? ctx.Input<Tensor>("Scale_out") : nullptr;
+
     PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
                        input->format() != memory::format::format_undef,
                    "Wrong layout/format set for Input tensor");
@@ -329,6 +335,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
+    std::vector<T> output_shift_scale;
+    T sum_scale = 1.0f;
+    if(is_INT8){
+      int count = g>1? weights_tz[1]*weights_tz[0] : weights_tz[0];
+      T scale_in_data = *(scale_in->data<T>());
+      T scale_in_eltwise_data = *(scale_in_eltwise->data<T>());
+      std::vector<T> scale_weights_data(count);
+      for(int i=0; i<count; i++){
+        scale_weights_data[i] = *(scale_weights->data<T>() + i);
+      }
+      T scale_out_data = *(scale_out->data<T>());
+
+      output_shift_scale.resize(count);
+      for(int i=0; i<count; i++){
+        output_shift_scale[i] = scale_out_data / (scale_in_data * scale_weights_data[i]);
+      }
+
+      sum_scale = scale_out_data / scale_in_eltwise_data;
+    }
+
@@ ... @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                     strides, paddings, mkldnn_engine,
-                                     fuse_relu, fuse_eltwise);
+      if(is_INT8){
+        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+                                       strides, paddings, mkldnn_engine,
+                                       fuse_relu, fuse_eltwise,
+                                       output_shift_scale, sum_scale);
+      } else{
+        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+                                       strides, paddings, mkldnn_engine,
+                                       fuse_relu, fuse_eltwise);
+      }
     } else {
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                               mkldnn_engine, fuse_relu, fuse_eltwise);
+      if(is_INT8){
+        conv_pd =
+            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+                                 mkldnn_engine, fuse_relu, fuse_eltwise,
+                                 output_shift_scale, sum_scale);
+      } else{
+        conv_pd =
+            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+                                 mkldnn_engine, fuse_relu, fuse_eltwise);
+      }
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -423,76 +466,149 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
 
  private:
-  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
-                                       bool fuse_eltwise) const {
-    mkldnn::primitive_attr conv_attr;
-    mkldnn::post_ops post_operations;
-    // Fusion with Elementwise layer relies on adding a sum post-operation with
-    // the scale parameter. It is assumed that when fuse_eltwise is true, the
-    // Output tensor contains the data coming from residual connection. The
-    // result of this post_op is: Output = scale * Output + Conv_Out.
-    if (fuse_eltwise) {
-      post_operations.append_sum(1.0f);
+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_eltwise,
+                  const std::vector<T> output_shift_scale, T sum_scale) const {
+      mkldnn::primitive_attr conv_attr;
+      mkldnn::post_ops post_operations;
+      int mask = 0;
+      conv_attr.set_output_scales(mask, output_shift_scale);
+      if (fuse_eltwise) {
+        post_operations.append_sum(sum_scale);
+      }
+      if (fuse_relu) {
+        constexpr float scale = 1.0f;
+        constexpr float negative_slope = 0.0f;
+        constexpr float placeholder = 0.0f; //beta
+        post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                       negative_slope, placeholder);
+      }
+      conv_attr.set_post_ops(post_operations);
+      return conv_attr;
     }
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
-    if (fuse_relu) {
-      constexpr float scale = 1.0f;
-      constexpr float negative_slope = 0.0f;
-      constexpr float placeholder = 0.0f;
-      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
-                                     negative_slope, placeholder);
+
+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_eltwise) const {
+
+      mkldnn::primitive_attr conv_attr;
+      mkldnn::post_ops post_operations;
+      // Fusion with Elementwise layer relies on adding a sum post-operation with
+      // the scale parameter. It is assumed that when fuse_eltwise is true, the
+      // Output tensor contains the data coming from residual connection. The
+      // result of this post_op is: Output = scale * Output + Conv_Out.
+
+      if (fuse_eltwise) {
+        post_operations.append_sum(1.0f);
+      }
+      // Fusion with ReLU layer is executed through the PostOps feature. Create a
+      // PostOps object and configure it to execute an eltwise relu operation.
+      if (fuse_relu) {
+        constexpr float scale = 1.0f;
+        constexpr float negative_slope = 0.0f;
+        constexpr float placeholder = 0.0f;
+        post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                       negative_slope, placeholder);
+      }
+      conv_attr.set_post_ops(post_operations);
+      return conv_attr;
     }
-    conv_attr.set_post_ops(post_operations);
-    return conv_attr;
-  }
 
   std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
-  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
-                       const memory::desc& dst, const std::vector<int>& strides,
-                       const std::vector<int>& paddings,
-                       const mkldnn::engine& engine, const bool fuse_relu,
-                       const bool fuse_eltwise) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
-
-    auto conv_desc = mkldnn::convolution_forward::desc(
-        mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
-        dst, stride_dims, padding_dims, padding_dims,
-        mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
-
-    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
-        conv_desc, conv_attr, engine);
-
-    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
-        p_conv_pd);
-  }
+  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
+                       const memory::desc& dst, const std::vector<int>& strides,
+                       const std::vector<int>& paddings,
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise,
+                       const std::vector<T> output_shift_scale, const T sum_scale) const {
+      memory::dims stride_dims = {strides[0], strides[1]};
+      memory::dims padding_dims = {paddings[0], paddings[1]};
+
+      auto conv_desc = mkldnn::convolution_forward::desc(
+          mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
+          dst, stride_dims, padding_dims, padding_dims,
+          mkldnn::padding_kind::zero);
+
+      mkldnn::primitive_attr conv_attr =
+          CreatePostOps(fuse_relu, fuse_eltwise, output_shift_scale, sum_scale);
+
+      auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
+          conv_desc, conv_attr, engine);
+
+      return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
+          p_conv_pd);
+    }
 
   std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
-  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
-                       const memory::desc& bias, const memory::desc& dst,
-                       const std::vector<int>& strides,
-                       const std::vector<int>& paddings,
-                       const mkldnn::engine& engine, const bool fuse_relu,
-                       const bool fuse_eltwise) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
-
-    auto conv_desc = mkldnn::convolution_forward::desc(
-        mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
-        bias, dst, stride_dims, padding_dims, padding_dims,
-        mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
-
-    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
-        conv_desc, conv_attr, engine);
-
-    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
-        p_conv_pd);
-  }
+  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
+                       const memory::desc& dst, const std::vector<int>& strides,
+                       const std::vector<int>& paddings,
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const{
+      memory::dims stride_dims = {strides[0], strides[1]};
+      memory::dims padding_dims = {paddings[0], paddings[1]};
+
+      auto conv_desc = mkldnn::convolution_forward::desc(
+          mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
+          dst, stride_dims, padding_dims, padding_dims,
+          mkldnn::padding_kind::zero);
+
+      mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
+
+      auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
+          conv_desc, conv_attr, engine);
+
+      return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
+          p_conv_pd);
+    }
+
+  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
+  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
+                       const memory::desc& bias, const memory::desc& dst,
+                       const std::vector<int>& strides,
+                       const std::vector<int>& paddings,
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise,
+                       const std::vector<T> output_shift_scale, const T sum_scale) const {
+      memory::dims stride_dims = {strides[0], strides[1]};
+      memory::dims padding_dims = {paddings[0], paddings[1]};
+
+      auto conv_desc = mkldnn::convolution_forward::desc(
+          mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
+          bias, dst, stride_dims, padding_dims, padding_dims,
+          mkldnn::padding_kind::zero);
+
+      mkldnn::primitive_attr conv_attr =
+          CreatePostOps(fuse_relu, fuse_eltwise, output_shift_scale, sum_scale);
+
+      auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
+          conv_desc, conv_attr, engine);
+
+      return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
+          p_conv_pd);
+    }
+
+  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
+  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
+                       const memory::desc& bias, const memory::desc& dst,
+                       const std::vector<int>& strides,
+                       const std::vector<int>& paddings,
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const{
+      memory::dims stride_dims = {strides[0], strides[1]};
+      memory::dims padding_dims = {paddings[0], paddings[1]};
+
+      auto conv_desc = mkldnn::convolution_forward::desc(
+          mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
+          bias, dst, stride_dims, padding_dims, padding_dims,
+          mkldnn::padding_kind::zero);
+
+      mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
+
+      auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
+          conv_desc, conv_attr, engine);
+
+      return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
+          p_conv_pd);
+    }
+
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 8f84bf71a7f77606bed6672f0830e3fc80165a42..ffcab41c739d264946cb0216a89881fcf3ffdf74 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -128,6 +128,21 @@ void Conv2DOpMaker::Make() {
            "The format of output tensor is X (one-dimensional) of size equal"
            "to the number of output channels. Only used with MKL-DNN.")
       .AsDispensable();
+  AddInput("Scale_in",
+           "(Tensor) Scale_in to be used for int8 input data. Only used with INT8.")
+      .AsDispensable();
+  AddInput("Scale_in_eltwise",
+           "(Tensor) Scale_in_eltwise to be used for int8 eltwise input data."
+           "Only used with MKL-DNN.")
+      .AsDispensable();
+  AddInput("Scale_weights",
+           "(Tensor) Scale_weights to be used for int8 weights data."
+           "Only used with MKL-DNN.")
+      .AsDispensable();
+  AddInput("Scale_out",
+           "(Tensor) Scale_out to be used for int8 output data."
+           "Only used with MKL-DNN.")
+      .AsDispensable();
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator. "
             "The format of output tensor is also NCHW.")
diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1ee5ed61ec0ed6be2d6b32b9cb8d031175ac6ae7
--- /dev/null
+++ b/paddle/fluid/operators/dequantize_op.cc
@@ -0,0 +1,113 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/operators/dequantize_op.h"
+#include "paddle/fluid/framework/data_layout_transform.h"
+
+namespace paddle {
+namespace operators {
+
+using mkldnn::memory;
+using mkldnn::primitive;
+using mkldnn::reorder;
+using platform::to_void_cast;
+using Tensor = framework::Tensor;
+using framework::DataLayout;
+using mkldnn::stream;
+using platform::GetMKLDNNFormat;
+//using MKLDNNDataType = mkldnn::memory::data_type;
+
+template <typename T>
+class DeQuantOpKernel : public framework::OpKernel<T> {
+ public:
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* output = ctx.Output<Tensor>("Output");
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& engine = dev_ctx.GetEngine();
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    //T scale_data = *(scale->data<T>());
+    std::vector<T> scale_data = {*(scale->data<T>())};
+
+    std::vector<primitive> pipeline;
+    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
+    mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
+
+    mkldnn::primitive_attr attri;
+    int mask = 0;
+    attri.set_output_scales(mask, scale_data);
+
+    auto src_md = platform::MKLDNNMemDesc(
+        {src_tz}, src_dt, src_fmt);
+    auto src_pd = mkldnn::memory::primitive_desc{src_md, engine};
+    auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
+    std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
+
+    auto dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, memory::data_type::f32, memory::format::nchw);
+    auto dst_pd = mkldnn::memory::primitive_desc{dst_md, engine};
+    auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
+
+    auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
+        new reorder::primitive_desc(dst_pd, src_pd, attri));
+    auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
+    pipeline.push_back(*reorder_p);
+
+  }
+};
+
+framework::OpKernelType DeQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
+    layout_ = framework::DataLayout::kMKLDNN;
+  }
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
+}
+
+void DeQuantOpMaker::Make() {
+  AddInput("Input","input");
+  AddInput("Scale","scale...");
+  AddOutput("Output","output");
+AddComment(R"DOC(
+This op will dequantize data from INT8 to FP32
+)DOC");
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(dequantize, ops::DeQuantOpKernel<float>);
+
diff --git a/paddle/fluid/operators/dequantize_op.h b/paddle/fluid/operators/dequantize_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..350b0376c2fcd19f2a1e59a33a7702b3e904be7e
--- /dev/null
+++ b/paddle/fluid/operators/dequantize_op.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::OpKernelType;
+using framework::Tensor;
+
+class DeQuantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {}
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+
+};
+
+class DeQuantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+class DeQuantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..06336b025e4cb4ad7886a94597d23aa472041358
--- /dev/null
+++ b/paddle/fluid/operators/quantize_op.cc
@@ -0,0 +1,113 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/operators/quantize_op.h"
+
+namespace paddle {
+namespace operators {
+
+using mkldnn::memory;
+using mkldnn::primitive;
+using mkldnn::reorder;
+using platform::to_void_cast;
+using Tensor = framework::Tensor;
+using framework::DataLayout;
+using mkldnn::stream;
+using platform::GetMKLDNNFormat;
+
+template <typename T>
+class QuantOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* output = ctx.Output<Tensor>("Output");
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& engine = dev_ctx.GetEngine();
+
+    std::vector<primitive> pipeline;
+    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    std::vector<T> scale_data = {*(scale->data<T>())};
+
+    mkldnn::primitive_attr attri;
+    int mask = 0;
+    attri.set_output_scales(mask, scale_data);
+
+    auto src_md = platform::MKLDNNMemDesc(
+        {src_tz}, memory::data_type::f32, input->format());
+    auto src_pd = mkldnn::memory::primitive_desc{src_md, engine};
+    auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
+    std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
+
+    auto dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, memory::data_type::u8, memory::format::nhwc);
+    auto dst_pd = mkldnn::memory::primitive_desc{dst_md, engine};
+    auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
+
+    auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
+        new reorder::primitive_desc(dst_pd, src_pd, attri));
+    auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
+    pipeline.push_back(*reorder_p);
+  }
+};
+
+framework::OpKernelType QuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
+    layout_ = framework::DataLayout::kMKLDNN;
+  }
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
+  //ctx.device_context());
+}
+
+void QuantOpMaker::Make() {
+  AddInput("Input","input");
+  AddInput("Scale","scale...");
+  AddOutput("Output","output");
+  AddComment(R"DOC(
+This op will quantize data from FP32 to INT8
+)DOC");
+}
+
+}  // namespace operators
+}  // namespace paddle
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(quantize, ops::QuantOp, ops::QuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(quantize, ops::QuantOpKernel<float>);
+
+//REGISTER_OP_KERNEL(quantization, MKLDNN, paddle::platform::CPUPlace, ops::QuantOpKernel<float>);
+
+
+
diff --git a/paddle/fluid/operators/quantize_op.h b/paddle/fluid/operators/quantize_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..6df9a71b5f9c83d7db215ea5e7ecb02c767a290d
--- /dev/null
+++ b/paddle/fluid/operators/quantize_op.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::OpKernelType;
+using framework::Tensor;
+
+class QuantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override{}
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+
+};
+
+class QuantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+//void Make() {
+//  AddInput("Input","input");
+//  AddInput("Scale","scale...");
+//  AddOutput("Output","output");
+//}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9abc0a4ea45ef8ebcd8a96571823d9dace05fa7f
--- /dev/null
+++ b/paddle/fluid/operators/requantize_op.cc
@@ -0,0 +1,113 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/operators/requantize_op.h"
+#include "paddle/fluid/framework/data_layout_transform.h"
+
+namespace paddle {
+namespace operators {
+
+using mkldnn::memory;
+using mkldnn::primitive;
+using mkldnn::reorder;
+using platform::to_void_cast;
+using Tensor = framework::Tensor;
+using framework::DataLayout;
+using mkldnn::stream;
+using platform::GetMKLDNNFormat;
+
+template <typename T>
+class ReQuantOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* output = ctx.Output<Tensor>("Output");
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& engine = dev_ctx.GetEngine();
+
+    std::vector<primitive> pipeline;
+    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type());
+    mkldnn::memory::data_type dst_dt = paddle::framework::ToMKLDNNDataType(output->type());
+    mkldnn::memory::format src_fmt = memory::format::nhwc;//input->format();
+    mkldnn::memory::format dst_fmt = memory::format::nhwc;//output->format();
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    //T scale_data = *(scale->data<T>());
+    std::vector<T> scale_data = {*(scale->data<T>())};
+
+    mkldnn::primitive_attr attri;
+    int mask = 0;
+    attri.set_output_scales(mask, scale_data);
+    //attri.set_int_output_round_mode(round_nearest); //FIX ME
+
+    auto src_md = platform::MKLDNNMemDesc(
+        {src_tz}, src_dt, src_fmt); //FIX ME WITH S8
+    auto src_pd = mkldnn::memory::primitive_desc{src_md, engine};
+    auto src_memory = std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
+    std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
+
+    auto dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, dst_dt, dst_fmt);
+    auto dst_pd = mkldnn::memory::primitive_desc{dst_md, engine};
+    auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
+
+    auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
+        new reorder::primitive_desc(dst_pd, src_pd, attri));
+    auto reorder_p= std::shared_ptr<reorder>(new reorder(*reorder_pd, *src_memory_p, dst_memory));
+    pipeline.push_back(*reorder_p);
+
+  }
+};
+
+framework::OpKernelType ReQuantOp::GetExpectedKernelType(const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library_{framework::LibraryType::kPlain};
+  std::string data_format = ctx.Attr<std::string>("data_format");
+  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
+    layout_ = framework::DataLayout::kMKLDNN;
+  }
+  return framework::OpKernelType(
+      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),ctx.GetPlace(),layout_, library_);
+}
+
+void ReQuantOpMaker::Make() {
+  AddInput("Input","input");
+  AddInput("Scale","scale...");
+  AddOutput("Output","output");
+AddComment(R"DOC(
+This op will requantize data from INT8 to INT8
+)DOC");
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OP_CPU_KERNEL(requantize, ops::ReQuantOpKernel<float>);
+
diff --git a/paddle/fluid/operators/requantize_op.h b/paddle/fluid/operators/requantize_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f96df360691a4f6f1b8c66f3402f3c5c9a62bcf3
--- /dev/null
+++ b/paddle/fluid/operators/requantize_op.h
@@ -0,0 +1,45 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::OpKernelType;
+using framework::Tensor;
+
+class ReQuantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {}
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override;
+};
+
+class ReQuantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override;
+};
+
+}  // namespace operators
+}  // namespace paddle
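
Note on the scale handling (not part of the patch): the INT8 convolution path above passes `output_shift_scale` to `primitive_attr::set_output_scales` and `sum_scale` to the sum post-op. The snippet below is a minimal standalone sketch of that arithmetic using the usual MKL-DNN INT8 requantization convention; the variable names mirror the kernel, but the concrete calibration values are made up purely for illustration.

```cpp
// scale_math_sketch.cc -- illustrative only, not part of the patch.
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical calibration scales (FP32 range -> INT8 range factors).
  float scale_in = 50.0f;          // scale of the INT8 conv input
  float scale_in_eltwise = 25.0f;  // scale of the residual tensor read by the sum post-op
  float scale_out = 30.0f;         // requested scale of the INT8 conv output
  std::vector<float> scale_weights = {100.0f, 80.0f};  // per-output-channel weight scales

  // The int32 accumulator effectively carries a factor of
  // scale_in * scale_weights[i], so rescaling it to scale_out needs
  // scale_out / (scale_in * scale_weights[i]).
  std::vector<float> output_shift_scale(scale_weights.size());
  for (size_t i = 0; i < scale_weights.size(); ++i) {
    output_shift_scale[i] = scale_out / (scale_in * scale_weights[i]);
  }

  // The sum post-op re-reads the residual tensor (stored at
  // scale_in_eltwise), so it must be brought to scale_out as well.
  float sum_scale = scale_out / scale_in_eltwise;

  for (size_t i = 0; i < output_shift_scale.size(); ++i) {
    std::printf("output_shift_scale[%zu] = %g\n", i, output_shift_scale[i]);
  }
  std::printf("sum_scale = %g\n", sum_scale);
  return 0;
}
```

With these values, both the convolution result (accumulated at `scale_in * scale_weights[i]`) and the residual input (stored at `scale_in_eltwise`) end up expressed at `scale_out` before saturation back to INT8, which is what the fused kernel relies on.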