From a914d9b116af76142a20059f91068ed9c4835f57 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Tue, 7 May 2019 11:16:23 +0800
Subject: [PATCH] Quant output scale (#17215)

* Add MovingAverageAbsMaxScale operator which is only used for calculating the quantization scale.

* test=develop

* change the output into inplace. test=develop

* Revert "test=develop"

This reverts commit 696cf62699ba1e1c98f61f7345ac7060010eb29a.

* Revert "change the output into inplace. test=develop"

This reverts commit a19acd20f07eee82622701a3015e6e9c073a5e0b.

* test=develop.

* update the MovingAverageAbsMaxScaleOp test. test=develop
---
 paddle/fluid/operators/fake_quantize_op.cc    | 72 ++++++++++++++++++-
 paddle/fluid/operators/fake_quantize_op.cu    |  2 +
 paddle/fluid/operators/fake_quantize_op.h     | 42 +++++++++++
 .../tests/unittests/test_fake_quantize_op.py  | 32 +++++++++
 4 files changed, 146 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index 054ef4658cc..25ca1f7e0a0 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -388,14 +388,76 @@ class FakeQuantizeMovingAverageAbsMaxOpMaker
     AddComment(R"DOC(
 FakeQuantize operator is used in static quantization.
 
-$$scale = (0.9*max(abs(x))+accum)/(0.9*state+1)$$
-$$range = 2^{bit_length - 1} - 1$$
+$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
+$$range = 2^{bit\_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
 
 )DOC");
   }
 };
 
+class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput("X"),
+        "Input(X) of MovingAverageAbsMaxScaleOp should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("Out"),
+        "Output(Out) of MovingAverageAbsMaxScaleOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
+                   "Output(OutScale) of MovingAverageAbsMaxScaleOp "
+                   "should not be null.");
+    if (ctx->HasOutput("OutState")) {
+      ctx->SetOutputDim("OutState", {1});
+    }
+    if (ctx->HasOutput("OutAccum")) {
+      ctx->SetOutputDim("OutAccum", {1});
+    }
+    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("OutScale", {1});
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("X")->type(), ctx.GetPlace());
+  }
+};
+
+class MovingAverageAbsMaxScaleOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) Input is float data type.");
+    AddInput("InAccum", "Last accum.").AsDispensable();
+    AddInput("InState", "Last state.").AsDispensable();
+    AddOutput("Out",
+              "(Tensor) Output tensor is just equivalent to the input tensor.");
+    AddOutput("OutScale", "(Tensor) Current scale.");
+    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
+    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
+    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
+        .SetDefault(0.9);
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set true for inference only and false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+MovingAverageAbsMaxScale operator is only used for calculating the
+quantization scale. It does not quantize the input tensor.
+
+$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
+$$Out = X$$
+
+)DOC");
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -426,3 +488,9 @@ REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
+
+REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
+                  ops::MovingAverageAbsMaxScaleOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
+                       ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
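The scale recurrence in the DOC comment above is easiest to see in action numerically. A minimal NumPy sketch (illustrative only; the seed, loop length, and shapes are made up for demonstration):

    import numpy as np

    # accum <- moving_rate * accum + max(abs(x));  state <- moving_rate * state + 1
    # scale  = accum / state
    moving_rate = 0.9
    accum, state = 1.0, 1.0  # the unit test below also starts from accum = state = 1
    np.random.seed(0)
    for step in range(10):
        x = np.random.uniform(-1.0, 1.0, size=(8, 16)).astype("float32")
        accum = moving_rate * accum + float(np.max(np.abs(x)))
        state = moving_rate * state + 1.0
        print(step, accum / state)  # the running quantization scale

Because state accumulates the same geometric weights that accum applies to max(abs(x)), the ratio accum/state is a normalized exponential moving average of max(abs(x)) rather than a raw accumulator.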
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index 7d551106756..6e1d40cac76 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -300,3 +300,5 @@ REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_moving_average_abs_max,
     ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
+                        ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 5ab38b086df..87bcece5824 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
@@ -197,5 +198,46 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+
+    bool is_test = context.Attr<bool>("is_test");
+    // testing
+    if (is_test) {
+      return;
+    }
+
+    // training
+    auto* in_accum = context.Input<framework::Tensor>("InAccum");
+    auto* in_state = context.Input<framework::Tensor>("InState");
+    auto& allocator =
+        platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
+    auto cur_scale = allocator.Allocate(1 * sizeof(T));
+    T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
+
+    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
+                                          cur_scale_data);
+
+    auto* out_state = context.Output<framework::Tensor>("OutState");
+    auto* out_accum = context.Output<framework::Tensor>("OutAccum");
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    out_state->mutable_data<T>(context.GetPlace());
+    out_accum->mutable_data<T>(context.GetPlace());
+    out_scale->mutable_data<T>(context.GetPlace());
+    float moving_rate = context.Attr<float>("moving_rate");
+
+    FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(
+        dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
+        out_accum, out_scale);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
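Before the unit test, here is what the kernel above computes, mirrored in NumPy (a sketch under the assumption of float32 data; the function name is invented for illustration, not an API of the operator):

    import numpy as np

    def moving_average_abs_max_scale(x, in_accum, in_state,
                                     moving_rate=0.9, is_test=False):
        # Out is always a plain copy of X; nothing is quantized here.
        out = x.copy()
        if is_test:  # inference: pass through without updating statistics
            return out, in_accum, in_state, None
        cur_scale = np.max(np.abs(x))                   # FindAbsMaxFunctor
        out_accum = moving_rate * in_accum + cur_scale  # FindMovingAverageAbsMaxFunctor
        out_state = moving_rate * in_state + 1.0
        out_scale = out_accum / out_state
        return out, out_accum, out_state, out_scale

The test that follows checks exactly this arithmetic against the C++ kernels.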
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index 07038b0441d..8d82438c15c 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -130,6 +130,38 @@ class TestFakeQuantizeMovingOp(OpTest):
         self.check_output()
 
 
+class TestMovingAverageAbsMaxScaleOp(OpTest):
+    def setUp(self):
+        self.op_type = "moving_average_abs_max_scale"
+        self.attrs = {'moving_rate': float(0.9), 'is_test': False}
+        accum = np.zeros(1).astype("float32")
+        accum[0] = 1
+        state = np.zeros(1).astype("float32")
+        state[0] = 1
+        self.inputs = {
+            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'InAccum': accum,
+            'InState': state,
+        }
+
+        out_accum = np.zeros(1).astype("float32")
+        out_state = np.zeros(1).astype("float32")
+        out_scale = np.zeros(1).astype("float32")
+        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
+            np.abs(self.inputs['X'])).astype("float32")
+        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
+        out_scale = out_accum / out_state
+        self.outputs = {
+            'Out': self.inputs['X'],
+            'OutAccum': out_accum,
+            'OutState': out_state,
+            'OutScale': out_scale,
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
     def setUp(self):
         self.op_type = "fake_quantize_range_abs_max"
-- 
GitLab
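A quick sanity check of the expected values in TestMovingAverageAbsMaxScaleOp: with moving_rate = 0.9, InAccum = InState = 1, and (picking a number just for the arithmetic) max(abs(X)) = 0.75, the op should produce OutAccum = 0.9 * 1 + 0.75 = 1.65, OutState = 0.9 * 1 + 1 = 1.9, and OutScale = 1.65 / 1.9 ≈ 0.868, while Out is bit-for-bit equal to X.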