diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 5d6488c67e0db440c8d4609736523643dd666dcc..68c7227e5a7123e1e751dd55e243ee481bf36540 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fake_dequantize_op.h" #include +#include namespace paddle { namespace operators { @@ -76,6 +77,63 @@ $$Out = \frac{scale*X}{ max_range }$$ } }; +class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("X"), + "Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null."); + PADDLE_ENFORCE(ctx->HasInputs("Scales"), + "Input(Scales) of FakeChannelWiseDequantizeMaxAbsOp " + "should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null."); + + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class FakeChannelWiseDequantizeMaxAbsOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor) The input with float-32/64 type is the " + "low precision tensor."); + AddInput("Scales", + "(Tensors) The scales in quantization stage. " + "Now, `Scales` is a vector with at most two tensors. " + "If Scales has two elements, the second tensor should only have " + "one value.") + .AsDuplicable(); + AddOutput("Out", + "(Tensor) The output is the dequantized high " + "precision tensor."); + AddAttr>( + "quant_bits", + "Quantization bit numbers in quantization stage. " + "The size of `quant_bits` should be equal to the size of `Scales`.") + .SetDefault({8}); + + AddComment(R"DOC( +FakeChannelWiseDequantizeMaxAbsOp operator. + +This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp: + +$$Out_c = \frac{X_c\prod_{i=1}^{n}Scales_{ic}}{\prod_{i=1}^{n}(2^{quant\_bits_i-1}-1)}$$ + +In the above formula, the range value of $c$ can be represented as $0 \leq c \lt \ the\ channel\ number\ of\ X$. +Besides, the size of $quant\_bits$ should be equal to the size of $Scales$, and it is called $n$ in the formula. + +Notes: In general, the per-channel quantization is only applied to weights and the activations use per-layer quantization. +)DOC"); + } +}; + } // namespace operators } // namespace paddle @@ -88,3 +146,11 @@ REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp, REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsKernel, ops::FakeDequantizeMaxAbsKernel); + +REGISTER_OPERATOR(fake_channel_wise_dequantize_max_abs, + ops::FakeChannelWiseDequantizeMaxAbsOp, + ops::FakeChannelWiseDequantizeMaxAbsOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(fake_channel_wise_dequantize_max_abs, + ops::FakeChannelWiseDequantizeMaxAbsKernel, + ops::FakeChannelWiseDequantizeMaxAbsKernel); diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 225bcc45bc65bc9268d1e866a4358731eaf0c3ef..35dcc69279d0119e75c4c5072e7817c839b9e819 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -55,3 +55,7 @@ using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsKernel, ops::FakeDequantizeMaxAbsKernel); +REGISTER_OP_CUDA_KERNEL( + fake_channel_wise_dequantize_max_abs, + ops::FakeChannelWiseDequantizeMaxAbsKernel, + ops::FakeChannelWiseDequantizeMaxAbsKernel); diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index d9923a10daa01ca06ebabb27cf9285b0628634bc..d05f2038531bbe9c35da54c94d2ef4d659acca70 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -45,5 +46,42 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel { } }; +template +class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { + public: + virtual void Compute(const framework::ExecutionContext& ctx) const { + auto* in = ctx.Input("X"); + auto scales = ctx.MultiInput("Scales"); + auto* out = ctx.Output("Out"); + + PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0], + "The number of first scale values must be the same with " + "first dimension value of Input(X)."); + + auto quant_bits = ctx.Attr>("quant_bits"); + int max_range = std::pow(2, quant_bits[0] - 1) - 1; + + auto& dev_ctx = ctx.template device_context(); + out->mutable_data(dev_ctx.GetPlace()); + + auto dequant = DequantizeFunctor(); + for (int64_t i = 0; i < in->dims()[0]; i++) { + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1); + dequant(dev_ctx, &one_channel_in, &one_channel_scale, + static_cast(max_range), &one_channel_out); + } + + if (scales.size() == 2) { + PADDLE_ENFORCE_EQ( + scales[1]->numel(), 1, + "The second scale tensor should only have one value at now."); + max_range = std::pow(2, quant_bits[1] - 1) - 1; + dequant(dev_ctx, out, scales[1], static_cast(max_range), out); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 3bb07d383548e6f4be810c96d2a916c0fe5e45f5..70186e5efa29b1324ff7f3954720276156fddaf1 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -134,6 +134,60 @@ $$Out = round(X/scale * range)$$ } }; +class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FakeChannelWiseQuantizeOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FakeChannelWiseQuantizeOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("OutScales"), + "Output(Scales) of FakeChannelWiseQuantizeOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]}); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class FakeChannelWiseQuantizeAbsMaxOpMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input is float data type."); + AddOutput("Out", + "(Tensor) Output of quantized low level tensor, " + "but also saved as float data type."); + AddOutput("OutScales", "(Tensor) Current channel wise scale"); + AddAttr("bit_length", "(int, default 8)") + .SetDefault(8) + .AddCustomChecker([](const int& bit_length) { + PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, + "'bit_length' should be between 1 and 16."); + }); + AddComment(R"DOC( +The scale of FakeChannelWiseQuantize operator is a vector. +In detail, each channel of the input X has a scale value. + +$$scale_c = max(abs(X_c))$$ +$$range = 2^{bit\_length - 1} - 1$$ +$$Out_c = round(\frac{X_c * range} {scale_c})$$ +In above three formulas, the range value of c is as follow: +$$0 \leq c \lt \ the\ channel\ number\ of\ X$$ +)DOC"); + } +}; + class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { public: FakeQuantizeRangeAbsMaxOp(const std::string& type, @@ -218,3 +272,10 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxKernel); + +REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max, + ops::FakeChannelWiseQuantizeAbsMaxOp, + ops::FakeChannelWiseQuantizeAbsMaxOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max, + ops::FakeChannelWiseQuantizeAbsMaxKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index a0ff6396210c2b3a7f8bd6b9f274b875d7fd4933..5da16a7c7314c62034bff67bcc8d099e2799c3de 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -174,5 +174,7 @@ namespace ops = paddle::operators; using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max, ops::FakeQuantizeAbsMaxKernel); +REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max, + ops::FakeChannelWiseQuantizeAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 7ace7573ec5c03ab8788cfc0aab614b7f80ea073..8b47600e7d99ad9e4e40ae162582d4c8461224ad 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -63,6 +63,39 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel { } }; +template +class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in = context.Input("X"); + + auto* out = context.Output("Out"); + auto* out_scales = context.Output("OutScales"); + T* out_scales_data = out_scales->mutable_data(context.GetPlace()); + out->mutable_data(context.GetPlace()); + + int bit_length = context.Attr("bit_length"); + int bin_cnt = std::pow(2, bit_length - 1) - 1; + + auto& dev_ctx = context.template device_context(); + auto find_abs_max = FindAbsMaxFunctor(); + for (int64_t i = 0; i < in->dims()[0]; i++) { + framework::Tensor one_channel = in->Slice(i, i + 1); + const T* one_channel_data = one_channel.data(); + find_abs_max(dev_ctx, one_channel_data, one_channel.numel(), + &out_scales_data[i]); + } + auto clip_quant = ClipAndFakeQuantFunctor(); + for (int64_t i = 0; i < in->dims()[0]; i++) { + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1); + clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt, + &one_channel_out); + } + } +}; + template class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel { public: diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index 1bb4662e8d83ac0c34b209e4e7a605869fdb59d5..32cb23cbfa9bdef4728e85d0014123652e4aefea 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -31,6 +31,80 @@ def dequantize_max_abs(x, scale, max_range): return y +def channel_wise_quantize_max_abs(x, quant_bit=8): + scales = [] + for i in range(x.shape[0]): + scales.append(np.max(np.abs(x[i])).astype("float32")) + + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + for i, scale in enumerate(scales): + y[i] = np.round(y[i] / scale * max_range) + return y, scales + + +def channel_wise_dequantize_max_abs(x, + scales, + quant_bits, + activation_scale=None): + y = x.copy() + for i in range(x.shape[0]): + y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * y[i] + if activation_scale is not None: + y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) + return y + + +class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest): + def set_args(self): + self.quant_bits = [8, 8] + self.data_type = "float32" + self.activation_scale = 0.7861 + + def setUp(self): + self.set_args() + self.op_type = "fake_channel_wise_dequantize_max_abs" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0]) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + self.activation_scale) + + self.inputs = { + 'X': yq, + 'Scales': [("scales0", np.array(scales).astype(self.data_type)), + ("scales1", np.array( + [self.activation_scale]).astype(self.data_type))] + } + self.attrs = {'quant_bits': self.quant_bits} + self.outputs = {'Out': ydq} + + def test_check_output(self): + self.check_output() + + +class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest): + def set_args(self): + self.quant_bits = [8] + self.data_type = "float32" + + def setUp(self): + self.set_args() + self.op_type = "fake_channel_wise_dequantize_max_abs" + x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0]) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits) + + self.inputs = { + 'X': yq, + 'Scales': [("scales0", np.array(scales).astype(self.data_type))] + } + self.attrs = {'quant_bits': self.quant_bits} + self.outputs = {'Out': ydq} + + def test_check_output(self): + self.check_output() + + class TestFakeDequantizeMaxAbsOp(OpTest): def set_args(self): self.num_bits = 8 diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 4582b2a0eed401235835374d4cd58782d8d3a68f..90a90112bd5f0e24374111073514b20dd1231edb 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -35,6 +35,30 @@ class TestFakeQuantizeOp(OpTest): self.check_output() +class TestFakeChannelWiseQuantizeOp(OpTest): + def setUp(self): + self.op_type = "fake_channel_wise_quantize_abs_max" + self.attrs = {'bit_length': 8} + self.inputs = { + 'X': np.random.random((4, 3, 64, 64)).astype("float32"), + } + scales = [] + for i in range(self.inputs['X'].shape[0]): + scales.append(np.max(np.abs(self.inputs['X'][i])).astype("float32")) + outputs = self.inputs['X'].copy() + for i, scale in enumerate(scales): + outputs[i] = np.round(outputs[i] / scale * ( + (1 << (self.attrs['bit_length'] - 1)) - 1)) + + self.outputs = { + 'Out': outputs, + 'OutScales': np.array(scales).astype("float32"), + } + + def test_check_output(self): + self.check_output() + + class TestFakeQuantizeRangeAbsMaxOp(OpTest): def setUp(self): self.op_type = "fake_quantize_range_abs_max"