From ff7f911b4d6b1794a2823ee323edd524ca40ee1e Mon Sep 17 00:00:00 2001
From: Zhaolong Xing
Date: Tue, 21 May 2019 10:28:57 +0800
Subject: [PATCH] add quant_dequant_moving_avg_max_abs op (#17480)

* add quant_dequant_moving_avg_max_abs op
test=develop

* add more note for quantdequant op
test=develop
---
 paddle/fluid/operators/fake_quantize_op.cc   | 70 +++++++++----
 paddle/fluid/operators/fake_quantize_op.cu   | 41 ++++++++
 paddle/fluid/operators/fake_quantize_op.h    | 44 ++++++++-
 .../tests/unittests/test_fake_quantize_op.py | 97 +++++++++++--------
 4 files changed, 190 insertions(+), 62 deletions(-)

diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index 25ca1f7e0a0..034f3c7dceb 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -68,6 +68,23 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
 template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
 
+template <typename T>
+struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    T s = scale.data<T>()[0];
+    platform::Transform<platform::CPUDeviceContext> trans;
+    trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
+          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+    out_e.device(*ctx.eigen_device()) =
+        (s / bin_cnt) * (bin_cnt / s * out_e).round();
+  }
+};
+template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
+                                               float>;
+
 template <typename T>
 struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
@@ -324,24 +341,26 @@ $$Out = round(X/scale * range)$$
   }
 };
 
-class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
+class FakeQuantOrWithDequantMovingAverageAbsMaxOp
+    : public framework::OperatorWithKernel {
  public:
-  FakeQuantizeMovingAverageAbsMaxOp(const std::string& type,
-                                    const framework::VariableNameMap& inputs,
-                                    const framework::VariableNameMap& outputs,
-                                    const framework::AttributeMap& attrs)
+  FakeQuantOrWithDequantMovingAverageAbsMaxOp(
+      const std::string& type, const framework::VariableNameMap& inputs,
+      const framework::VariableNameMap& outputs,
+      const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
+                   "should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
+                   "should not be null.");
     PADDLE_ENFORCE(
-        ctx->HasInput("X"),
-        "Input(X) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasOutput("Out"),
-        "Output(Out) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
-                   "Output(OutScale) of FakeQuantizeMovingAverageAbsMaxOp "
-                   "should not be null");
+        ctx->HasOutput("OutScale"),
+        "Output(OutScale) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
+        "should not be null.");
     if (ctx->HasOutput("OutState")) {
       ctx->SetOutputDim("OutState", {1});
     }
@@ -361,7 +380,7 @@ class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
   }
 };
 
-class FakeQuantizeMovingAverageAbsMaxOpMaker
+class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -386,12 +405,19 @@ class FakeQuantizeMovingAverageAbsMaxOpMaker
                   "for training. 
Some layers may run faster when this is true.")
         .SetDefault(false);
     AddComment(R"DOC(
-FakeQuantize operator is used in static quantization.
+This is a base op which supports both FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
+The FakeQuantMovingAverageAbsMaxOp operator is used in static quantization.
 
 $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
 $$range = 2^{bit\_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
 
+The FakeQuantDequantMovingAverageAbsMaxOp operator applies the same moving_average_abs_max quantization and then dequantizes the result:
+
+$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
+$$range = 2^{bit\_length - 1} - 1$$
+$$Out = round(X/scale * range) * scale / range$$
+
 )DOC");
   }
 };
@@ -477,11 +503,21 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
 
 REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
-                  ops::FakeQuantizeMovingAverageAbsMaxOp,
-                  ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
                   paddle::framework::EmptyGradOpMaker);
+
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
+
+REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
+
 REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                   ops::FakeChannelWiseQuantizeAbsMaxOp,
                   ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
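To make the two formula blocks in the DOC comment above concrete, here is a minimal standalone NumPy sketch of one update step, showing both the quantize-only output and the quantize-dequantize output. It is illustrative only, not part of the patch; moving_rate, bit_length, the initial accum/state values, and the input shape are assumptions chosen for the example.

import numpy as np

# One step of the moving-average abs-max scale update from the DOC block.
moving_rate, bit_length = 0.9, 8               # assumed attribute values
accum, state = 1.0, 1.0                        # assumed initial running sums
x = np.random.uniform(-1, 1, (4, 4)).astype("float32")

accum = moving_rate * accum + np.abs(x).max()  # numerator of the scale
state = moving_rate * state + 1                # denominator of the scale
scale = accum / state
bin_cnt = (1 << (bit_length - 1)) - 1          # the DOC's "range": 2^{bit_length-1} - 1

x_c = np.clip(x, -scale, scale)                # the kernels clip to [-scale, scale] first
quant = np.round(x_c / scale * bin_cnt)        # fake_quantize_moving_average_abs_max
dequant = quant * scale / bin_cnt              # extra step of the new quant-dequant op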
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index 6e1d40cac76..3d24e8986df 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -129,6 +129,23 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   }
 }
 
+template <typename T>
+__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
+                                          const int bin_cnt, const int n,
+                                          T* out) {
+  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int tid = threadIdx.x;
+
+  T s = scale[0];
+  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+    T x = in[i];
+    T v = x > s ? s : x;
+    v = v < -s ? -s : v;
+    v = bin_cnt / s * v;
+    out[i] = round(v) * s / bin_cnt;
+  }
+}
+
 template <typename T>
 struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
@@ -149,6 +166,27 @@ struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
 
 template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
 
+template <typename T>
+struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, out_data);
+  }
+};
+
+template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
+                                               float>;
+
 template <typename T>
 __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
                                           const int bin_cnt, const int n,
@@ -302,3 +340,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
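A note on the launch configuration in the functor above: grid = (block - 1 + num) / block is integer ceiling division, so grid * block >= num and the grid-stride loop in ClipAndQuantDequantKernel visits every element exactly once. A small Python emulation of that indexing (illustrative only; num and block are assumed values):

# Emulate the kernel's grid-stride indexing and check element coverage.
num, block = 10000, 1024
grid = (block - 1 + num) // block   # same ceiling division as the functor

visited = [0] * num
for blk in range(grid):             # blockIdx.x
    for thr in range(block):        # threadIdx.x
        i = thr + blk * block       # "bid" in the kernel
        while i < num:              # for (int i = bid; i < n; i += blockDim.x * gridDim.x)
            visited[i] += 1
            i += block * grid       # stride = blockDim.x * gridDim.x
assert all(v == 1 for v in visited)

Because the grid already covers num, the stride-loop body runs at most once per thread here; keeping the loop form simply leaves the kernel correct if the grid were ever capped below ceil(num / block).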
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 87bcece5824..422d99dd433 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -35,6 +35,13 @@ struct ClipAndFakeQuantFunctor {
                   framework::Tensor* out);
 };
 
+template <typename DeviceContext, typename T>
+struct ClipAndFakeQuantDequantFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& scale, const int bin_cnt,
+                  framework::Tensor* out);
+};
+
 template <typename DeviceContext, typename T>
 struct FindRangeAbsMaxFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,
@@ -150,8 +157,13 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 };
 
 template <typename DeviceContext, typename T>
-class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
+class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
+  virtual ~FakeMovingAverageAbsMaxKernelBase() {}
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& in_scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
     auto* in_scale = context.Input<framework::Tensor>("InScale");
@@ -165,8 +177,7 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
 
     // testing
     if (is_test) {
-      ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
-                                                  bin_cnt, out);
+      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out);
       return;
     }
 
@@ -193,8 +204,31 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
         dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
         out_accum, out_scale);
 
-    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeMovingAverageAbsMaxKernel
+    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
+ public:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& in_scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, bin_cnt,
+                                                out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
+    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
+ public:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& in_scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale,
+                                                       bin_cnt, out);
   }
 };
 
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index 8d82438c15c..8fe814dc50d 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -90,46 +90,6 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
         self.check_output()
 
 
-class TestFakeQuantizeMovingOp(OpTest):
-    def setUp(self):
-        self.op_type = "fake_quantize_moving_average_abs_max"
-        self.attrs = {
-            'bit_length': int(5),
-            'moving_rate': float(0.9),
-            'is_test': False
-        }
-        accum = np.zeros(1).astype("float32")
-        accum[0] = 1
-        state = np.zeros(1).astype("float32")
-        state[0] = 1
-        scale = np.zeros(1).astype("float32")
-        scale[0] = 0.001
-        self.inputs = {
-            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
-            'InScale': scale,
-            'InAccum': accum,
-            'InState': state,
-        }
-
-        out_accum = np.zeros(1).astype("float32")
-        out_state = np.zeros(1).astype("float32")
-        out_scale = np.zeros(1).astype("float32")
-        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
-            np.abs(self.inputs['X'])).astype("float32")
-        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
-        out_scale = out_accum / out_state
-        self.outputs = {
-            'Out': np.round(self.inputs['X'] / out_scale * (
-                (1 << (self.attrs['bit_length'] - 1)) - 1)),
-            'OutAccum': out_accum,
-            'OutState': out_state,
-            'OutScale': out_scale,
-        }
-
-    def test_check_output(self):
-        self.check_output()
-
-
 class TestMovingAverageAbsMaxScaleOp(OpTest):
     def setUp(self):
         self.op_type = "moving_average_abs_max_scale"
@@ -193,5 +153,62 @@ class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
         self.check_output(no_check_set=set(['OutScale', 'OutScales']))
 
 
+class TestMovingOpBase(OpTest):
+    def setUp(self):
+        self.init_type()
+        self.attrs = {
+            'bit_length': int(5),
+            'moving_rate': float(0.9),
+            'is_test': False
+        }
+        accum = np.zeros(1).astype("float32")
+        accum[0] = 1
+        state = np.zeros(1).astype("float32")
+        state[0] = 1
+        scale = np.zeros(1).astype("float32")
+        scale[0] = 0.001
+        self.inputs = {
+            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'InScale': scale,
+            'InAccum': accum,
+            'InState': state,
+        }
+
+        out_accum = np.zeros(1).astype("float32")
+        out_state = np.zeros(1).astype("float32")
+        out_scale = np.zeros(1).astype("float32")
+        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
+            np.abs(self.inputs['X'])).astype("float32")
+        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
+        out_scale = out_accum / out_state
+        out_data = self.calc_output(out_scale)
+        self.outputs = {
+            'Out': out_data,
+            'OutAccum': out_accum,
+            'OutState': out_state,
+            'OutScale': out_scale,
+        }
+
+    def init_type(self):
+        self.op_type = "fake_quantize_moving_average_abs_max"
+
+    def calc_output(self, out_scale):
+        return np.round(self.inputs['X'] / out_scale * (
+            (1 << (self.attrs['bit_length'] - 1)) - 1))
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestFakeQuantDequantMovingOp(TestMovingOpBase):
+    def init_type(self):
+        self.op_type = "fake_quantize_dequantize_moving_average_abs_max"
+
+    def calc_output(self, out_scale):
+        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
+        return np.round(self.inputs['X'] / out_scale *
+                        range_v) * out_scale / range_v
+
+
 if __name__ == "__main__":
     unittest.main()
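To tie the new tests back to the DOC formulas, here is one step worked by hand with TestMovingOpBase's initial values (accum = state = 1, moving_rate = 0.9, bit_length = 5). The abs-max and the sample element are assumed numbers for illustration, since the test draws X at random:

moving_rate, bit_length = 0.9, 5
accum, state, abs_max = 1.0, 1.0, 0.95       # abs_max: assumed max(abs(X))

out_accum = moving_rate * accum + abs_max    # 0.9 * 1 + 0.95 = 1.85
out_state = moving_rate * state + 1          # 0.9 * 1 + 1    = 1.9
out_scale = out_accum / out_state            # ~0.97368
range_v = (1 << (bit_length - 1)) - 1        # 2^4 - 1 = 15

x = 0.5                                      # one assumed element of X
quant = round(x / out_scale * range_v)       # round(7.7027) = 8 -> 'Out' of the quant op
dequant = quant * out_scale / range_v        # ~0.51930 -> 'Out' of the quant-dequant op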
-- 
GitLab