From bb45af02acee45f8dc38c2ef4f16d8e341850cf9 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Thu, 9 Jul 2020 13:59:23 +0800
Subject: [PATCH] add the c++ part of Imperative QAT. test=develop (#25446)

---
 paddle/fluid/operators/fake_dequantize_op.cc  |   6 +-
 paddle/fluid/operators/fake_quantize_op.cc    | 105 ++++++++++++++----
 paddle/fluid/operators/fake_quantize_op.cu    |   6 +-
 paddle/fluid/operators/fake_quantize_op.h     |  73 ++++++++++--
 paddle/fluid/platform/dynload/cusolver.h      |   1 +
 paddle/fluid/pybind/op_function_generator.cc  |   1 +
 .../tests/unittests/test_fake_quantize_op.py  |  30 +++++
 7 files changed, 185 insertions(+), 37 deletions(-)

diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc
index a1cf30ca7f..0d2b951ee1 100644
--- a/paddle/fluid/operators/fake_dequantize_op.cc
+++ b/paddle/fluid/operators/fake_dequantize_op.cc
@@ -29,7 +29,7 @@ struct DequantizeFunctor {
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     auto& dev = *dev_ctx.eigen_device();
-    out_e.device(dev) = scale_factor[0] * in_e / max_range;
+    out_e.device(dev) = in_e * scale_factor[0] / max_range;
   }
 };
 
@@ -48,7 +48,7 @@ struct ChannelDequantizeFunctor {
         auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
         auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
         auto& dev = *dev_ctx.eigen_device();
-        out_e.device(dev) = s * in_e / max_range;
+        out_e.device(dev) = in_e * s / max_range;
       }
     } else if (scale_num == 2) {
       int batch_size = in->dims()[0];
@@ -67,7 +67,7 @@ struct ChannelDequantizeFunctor {
           auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
           auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
           auto& dev = *dev_ctx.eigen_device();
-          out_e.device(dev) = (s * scale_two[0]) * in_e / max_range;
+          out_e.device(dev) = in_e * s * scale_two[0] / max_range;
         }
       }
     }
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index 8c07e445a6..401cc448ac 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -82,7 +82,7 @@ struct ClipAndFakeQuantDequantFunctor {
                    out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     out_e.device(*ctx.eigen_device()) =
-        (s / bin_cnt) * (bin_cnt * inv_s * out_e).round();
+        (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
   }
 };
 template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
                                                float>;
 template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
                                                float>;
 
-class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
+class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
  public:
-  FakeQuantizeAbsMaxOp(const std::string& type,
-                       const framework::VariableNameMap& inputs,
-                       const framework::VariableNameMap& outputs,
-                       const framework::AttributeMap& attrs)
+  FakeQuantOrWithDequantAbsMaxOp(const std::string& type,
+                                 const framework::VariableNameMap& inputs,
+                                 const framework::VariableNameMap& outputs,
+                                 const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
+                   "FakeQuantOrWithDequantAbsMaxOp");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
-                   "FakeQuantizeAbsMax");
+                   "FakeQuantOrWithDequantAbsMaxOp");
     OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
-                   "FakeQuantizeAbsMax");
+                   "FakeQuantOrWithDequantAbsMaxOp");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("OutScale", {1}); ctx->ShareLoD("X", /*->*/ "Out"); @@ -199,7 +200,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel { } }; -class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { +class FakeQuantOrWithDequantAbsMaxOpMaker + : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) Input is float data type."); @@ -217,12 +219,19 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker { bit_length)); }); AddComment(R"DOC( -FakeQuantize operator +This is a Base Op which support FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker. +FakeQuantAbsMaxOp operator is used in the dynamic quantization. $$scale = max(abs(X))$$ $$range = 2^{bit_length - 1} - 1$$ $$Out = round(X/scale * range)$$ +FakeQuantDequantAbsMaxOp operator do the abs_max quant and then dequant. + +$$scale = max(abs(X))$$ +$$range = 2^{bit\_length - 1} - 1$$ +$$Out = round(X/scale * range) * scale / range$$ + )DOC"); } }; @@ -414,14 +423,14 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker "for training. Some layers may run faster when this is true.") .SetDefault(false); AddComment(R"DOC( -This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp -FakeQuantMovingAverageAbsMaxOp operator is used in static quantization. +This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp. +FakeQuantMovingAverageAbsMaxOp operator is used in the static quantization. $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$ $$range = 2^{bit\_length - 1} - 1$$ $$Out = round(X/scale * range)$$ -FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max op quant and then dequant. +FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max quant and then dequant. 
 $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
 $$range = 2^{bit\_length - 1} - 1$$
@@ -490,6 +499,46 @@ $$Out = X$$
   }
 };
 
+class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto out_grad_name = framework::GradVarName("Out");
+    OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
+                   "FakeQuantDequantGradOp");
+
+    auto x_grad_name = framework::GradVarName("X");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput(x_grad_name), true,
+        platform::errors::PreconditionNotMet(
+            "FakeQuantDequantGradOp doesn't have the output named %s.",
+            x_grad_name));
+    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("fake_quantize_dequantize_grad");
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
@@ -497,13 +546,21 @@
 namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(
-    fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
-    ops::FakeQuantizeAbsMaxOpMaker,
+    fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CPU, float>);
 
+REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
+                  ops::FakeQuantOrWithDequantAbsMaxOp,
+                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
+                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
+                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
+                       ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
+
 REGISTER_OPERATOR(
     fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
     ops::FakeQuantizeRangeAbsMaxOpMaker,
@@ -518,16 +575,14 @@ REGISTER_OPERATOR(
     ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
 
-REGISTER_OPERATOR(
-    fake_quantize_dequantize_moving_average_abs_max,
-    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
-    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
+                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@@ -547,3 +602,7 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
+
+REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
+REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
+                       ops::FakeQuantDequantGradKernel<CPU, float>);
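Note: the formulas documented above for fake_quantize_dequantize_abs_max, together with the straight-through fake_quantize_dequantize_grad op registered at the end of this file, can be sketched in NumPy as follows. This is only an illustration of the math, not part of the patch; the helper name and shapes are arbitrary.

    import numpy as np

    def fake_quant_dequant_abs_max(x, bit_length=8):
        # scale = max(abs(X)), range = 2^(bit_length - 1) - 1
        scale = np.max(np.abs(x))
        bin_cnt = (1 << (bit_length - 1)) - 1
        # Out = round(X / scale * range) * scale / range
        out = np.round(x / scale * bin_cnt) * scale / bin_cnt
        return out, scale

    x = np.random.uniform(-1, 1, (4, 4)).astype("float32")
    out, scale = fake_quant_dequant_abs_max(x)
    # Backward (fake_quantize_dequantize_grad) is a straight-through copy:
    # dX = dOut, with no rescaling.
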
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index 6813c03933..75a55fa821 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -138,9 +138,9 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int tid = threadIdx.x;
 
   T s = scale[0];
+  T inv_s = inverse(s);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
-    T inv_s = inverse(s);
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt * inv_s * v;
@@ -335,6 +335,8 @@
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                         ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
+                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                         ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
@@ -347,3 +349,5 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
+                        ops::FakeQuantDequantGradKernel<CUDA, float>);
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 5c27ee8748..fa5048852e 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/hostdevice.h"
+#include "paddle/fluid/platform/transform.h"
 
 namespace paddle {
 namespace operators {
@@ -81,7 +82,7 @@ struct FindMovingAverageAbsMaxFunctor {
 };
 
 template <typename DeviceContext, typename T>
-class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
+class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");
@@ -95,8 +96,38 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel {
     auto& dev_ctx = context.template device_context<DeviceContext>();
     const T* in_data = in->data<T>();
     FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
-    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
+  }
+
+  virtual ~FakeAbsMaxKernelBase() = default;
+
+ protected:
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
+ protected:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, scale, bin_cnt,
+                                                out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeDequantizeAbsMaxKernel
+    : public FakeAbsMaxKernelBase<DeviceContext, T> {
+ protected:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, scale,
+                                                       bin_cnt, out);
   }
 };
 
@@ -167,11 +198,6 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel {
 template <typename DeviceContext, typename T>
 class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
-  ~FakeMovingAverageAbsMaxKernelBase() {}
-  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
-                              const framework::Tensor& in,
-                              const framework::Tensor& in_scale, int bin_cnt,
-                              framework::Tensor* out) const = 0;
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
     auto* in_scale = context.Input<framework::Tensor>("InScale");
@@ -212,12 +238,20 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel {
 
     RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
   }
+
+  virtual ~FakeMovingAverageAbsMaxKernelBase() = default;
+
+ protected:
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& in_scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
 };
 
 template <typename DeviceContext, typename T>
 class FakeQuantizeMovingAverageAbsMaxKernel
     : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
- public:
+ protected:
   void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                       const framework::Tensor& in_scale, int bin_cnt,
                       framework::Tensor* out) const override {
@@ -229,7 +263,7 @@ class FakeQuantizeMovingAverageAbsMaxKernel
 template <typename DeviceContext, typename T>
 class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
     : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
- public:
+ protected:
   void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                       const framework::Tensor& in_scale, int bin_cnt,
                       framework::Tensor* out) const override {
@@ -277,5 +311,24 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* d_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto x_grad_name = framework::GradVarName("X");
+    auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        d_x, platform::errors::PreconditionNotMet(
+                 "FakeQuantDequantGradOp doesn't have the output named %s.",
+                 x_grad_name));
+
+    // Initialize dx as same as d_out
+    d_x->mutable_data<T>(context.GetPlace());
+    framework::TensorCopy(*d_out, context.GetPlace(), d_x);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
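Note: the moving-average kernels reorganized above in fake_quantize_op.h keep a running scale instead of recomputing max(abs(X)) from scratch each step. A minimal NumPy sketch of that bookkeeping plus the quant-dequant step, mirroring the formula in the op comment (the function name, argument layout, and the moving_rate default are illustrative, not part of the patch):

    import numpy as np

    def fake_quant_dequant_moving_average_abs_max(x, accum, state,
                                                  moving_rate=0.9,
                                                  bit_length=8):
        # scale = (moving_rate*accum + max(abs(x))) / (moving_rate*state + 1)
        state = moving_rate * state + 1
        accum = moving_rate * accum + np.max(np.abs(x))
        scale = accum / state
        # Clip to [-scale, scale], quantize, then dequantize with the same scale.
        bin_cnt = (1 << (bit_length - 1)) - 1
        clipped = np.clip(x, -scale, scale)
        out = np.round(clipped / scale * bin_cnt) * scale / bin_cnt
        return out, scale, accum, state
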
diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h
index 379bf78d0a..8eda352529 100644
--- a/paddle/fluid/platform/dynload/cusolver.h
+++ b/paddle/fluid/platform/dynload/cusolver.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
 
+#include
 #include <cusolverDn.h>
 #include <mutex>  // NOLINT
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index ee9fa26b2f..7412eede11 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -80,6 +80,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"matmul", {"Out"}},
     {"fake_quantize_dequantize_moving_average_abs_max",
      {"Out", "OutScale", "OutAccum", "OutState"}},
+    {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
     {"amp_check_finite_and_scale", {"Out", "FoundInfinite"}},
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index 6943f3d0ff..4314faaf39 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -242,6 +242,36 @@ class TestFakeQuantDequantMovingOp(TestMovingOpBase):
         return np.round(self.inputs['X'] / out_scale * range_v) * out_scale / range_v
 
+    def test_check_grad(self):
+        x = self.inputs["X"]
+        gradient = [np.ones(x.shape) / np.product(x.shape)]
+        self.check_grad(["X"], "Out", user_defined_grads=gradient)
+
+
+class TestFakeQuantDequantAbsOp(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_dequantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        out_data = self.calc_output(scale)
+        self.outputs = {
+            'Out': out_data,
+            'OutScale': np.array(scale).astype("float32"),
+        }
+
+    def calc_output(self, scale):
+        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
+        return np.round(self.inputs['X'] / scale * range_v) * scale / range_v
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        x = self.inputs["X"]
+        gradient = [np.ones(x.shape) / np.product(x.shape)]
+        self.check_grad(["X"], "Out", user_defined_grads=gradient)
+
 
 if __name__ == "__main__":
     unittest.main()
--
GitLab
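Note on the new gradient checks: because fake_quantize_dequantize_grad copies dOut into dX (a straight-through estimator), the expected gradient of a mean-of-Out objective, which is effectively what check_grad exercises here, is 1/numel(X) for every element. That is where the user_defined_grads value in the tests above comes from; a small NumPy check (variable names are illustrative):

    import numpy as np

    x = np.random.random((124, 240)).astype("float32")
    # Upstream gradient of a mean-of-Out objective: 1 / numel for every element.
    d_out = np.full(x.shape, 1.0 / x.size, dtype="float32")
    # The straight-through backward just copies dOut into dX.
    d_x = d_out.copy()
    assert np.allclose(d_x, np.ones(x.shape) / np.prod(x.shape))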