未验证 提交 5420cf95 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #16070 from wzzju/channel_wise_quant_op

Add channel wise quant op and channel wise dequant op.
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -76,6 +77,63 @@ $$Out = \frac{scale*X}{ max_range }$$
}
};
// Operator definition for the channel-wise max-abs dequantize op. It only
// performs shape inference: the required variables must be present and Out
// takes its dims and LoD from X.
class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies that X, Scales and Out are wired up (in that order), then
  // shares the dims and LoD of X with Out.
  void InferShape(framework::InferShapeContext* ctx) const override {
    // Required single input.
    PADDLE_ENFORCE(
        ctx->HasInput("X"),
        "Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");

    // Scales is a multi-variable input, hence HasInputs rather than HasInput.
    PADDLE_ENFORCE(
        ctx->HasInputs("Scales"),
        "Input(Scales) of FakeChannelWiseDequantizeMaxAbsOp should not be "
        "null.");

    // Required single output.
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};
// Declares the proto of the channel-wise max-abs dequantize op: input X,
// the duplicable input Scales, output Out, the quant_bits attribute and the
// operator documentation string.
class FakeChannelWiseDequantizeMaxAbsOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
// Registers the inputs, outputs, attributes and doc comment of the op.
void Make() override {
// The tensor to be dequantized.
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
// Marked duplicable, so several scale tensors can be bound to this input.
AddInput("Scales",
"(Tensors) The scales in quantization stage. "
"Now, `Scales` is a vector with at most two tensors. "
"If Scales has two elements, the second tensor should only have "
"one value.")
.AsDuplicable();
// The dequantized result.
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
// One bit width per scale tensor; defaults to a single entry of 8.
AddAttr<std::vector<int>>(
"quant_bits",
"Quantization bit numbers in quantization stage. "
"The size of `quant_bits` should be equal to the size of `Scales`.")
.SetDefault({8});
// The raw string below is the op's documentation text and is emitted
// verbatim; it is not a C++ comment.
AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp:
$$Out_c = \frac{X_c\prod_{i=1}^{n}Scales_{ic}}{\prod_{i=1}^{n}(2^{quant\_bits_i-1}-1)}$$
In the above formula, the range value of $c$ can be represented as $0 \leq c \lt \ the\ channel\ number\ of\ X$.
Besides, the size of $quant\_bits$ should be equal to the size of $Scales$, and it is called $n$ in the formula.
Notes: In general, the per-channel quantization is only applied to weights and the activations use per-layer quantization.
)DOC");
}
};
} // namespace operators
} // namespace paddle
......@@ -88,3 +146,11 @@ REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
// CPU kernels of fake_dequantize_max_abs for float and double.
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
// Registers the channel-wise dequantize operator together with its proto
// maker. EmptyGradOpMaker is supplied as the grad-op maker
// (NOTE(review): presumably this means no gradient op is generated - confirm).
REGISTER_OPERATOR(fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsOp,
ops::FakeChannelWiseDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker);
// CPU kernels of the channel-wise dequantize op for float and double.
REGISTER_OP_CPU_KERNEL(fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, float>,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, double>);
......@@ -55,3 +55,7 @@ using CUDA = paddle::platform::CUDADeviceContext;
// CUDA kernels of fake_dequantize_max_abs for float and double.
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
// CUDA kernels of the channel-wise dequantize op for float and double,
// mirroring the CPU registration of the same op.
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>);
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -45,5 +46,42 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
}
};
// Kernel of the channel-wise max-abs dequantize op.
//
// Stage 1: every slice of X along its first dimension is dequantized with the
// matching entry of Scales[0] and the range 2^(quant_bits[0] - 1) - 1.
// Stage 2 (only when exactly two scale tensors are given): the whole output
// is dequantized again, in place, with the single value in Scales[1] and the
// range 2^(quant_bits[1] - 1) - 1.
template <typename DeviceContext, typename T>
class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto scales = ctx.MultiInput<framework::Tensor>("Scales");
    auto* out = ctx.Output<framework::Tensor>("Out");
    auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");

    // scales[0] and quant_bits[0] are indexed unconditionally below, so both
    // containers must hold at least one element.
    PADDLE_ENFORCE(!scales.empty(),
                   "Input(Scales) should contain at least one tensor.");
    PADDLE_ENFORCE(!quant_bits.empty(),
                   "Attr(quant_bits) should contain at least one value.");

    PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
                      "The number of first scale values must be the same with "
                      "first dimension value of Input(X).");

    int max_range = std::pow(2, quant_bits[0] - 1) - 1;

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    out->mutable_data<T>(dev_ctx.GetPlace());

    auto dequant = DequantizeFunctor<DeviceContext, T>();
    const int64_t channel_num = in->dims()[0];
    for (int64_t i = 0; i < channel_num; i++) {
      // Slices select the i-th entry along the first dimension.
      framework::Tensor one_channel_in = in->Slice(i, i + 1);
      framework::Tensor one_channel_out = out->Slice(i, i + 1);
      framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
      dequant(dev_ctx, &one_channel_in, &one_channel_scale,
              static_cast<T>(max_range), &one_channel_out);
    }

    if (scales.size() == 2) {
      // quant_bits defaults to a single element; without this check a caller
      // that passes two scales but leaves quant_bits at its default would
      // trigger an out-of-range read of quant_bits[1].
      PADDLE_ENFORCE(quant_bits.size() >= 2,
                     "Attr(quant_bits) should provide a bit number for every "
                     "tensor of Input(Scales).");
      PADDLE_ENFORCE_EQ(
          scales[1]->numel(), 1,
          "The second scale tensor should only have one value at now.");
      max_range = std::pow(2, quant_bits[1] - 1) - 1;
      // Second pass runs in place on the already dequantized output.
      dequant(dev_ctx, out, scales[1], static_cast<T>(max_range), out);
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -134,6 +134,60 @@ $$Out = round(X/scale * range)$$
}
};
// Operator definition for the channel-wise abs-max fake quantize op.
// Out has the shape of X; OutScales is 1-D with one entry per slice of X
// along its first dimension.
class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that X, Out and OutScales are present and sets the output dims.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of FakeChannelWiseQuantizeOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
    // The message names the output that is actually checked ("OutScales");
    // it previously referred to a non-existent output "Scales".
    PADDLE_ENFORCE(
        ctx->HasOutput("OutScales"),
        "Output(OutScales) of FakeChannelWiseQuantizeOp should not be null.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel is selected from the data type of X and the execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.GetPlace());
  }
};
// Declares the proto of the channel-wise abs-max fake quantize op: input X,
// outputs Out and OutScales, the bit_length attribute and the operator
// documentation string.
class FakeChannelWiseQuantizeAbsMaxOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
// Registers the inputs, outputs, attributes and doc comment of the op.
void Make() override {
AddInput("X", "(Tensor) Input is float data type.");
AddOutput("Out",
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScales", "(Tensor) Current channel wise scale");
// bit_length defaults to 8; the custom checker rejects values outside
// the closed range [1, 16].
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
// The raw string below is the op's documentation text and is emitted
// verbatim; it is not a C++ comment.
AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
}
};
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeRangeAbsMaxOp(const std::string& type,
......@@ -218,3 +272,10 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
paddle::framework::EmptyGradOpMaker);
// CPU kernel of fake_quantize_range_abs_max; only float is registered.
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
// Registers the channel-wise quantize operator together with its proto
// maker. EmptyGradOpMaker is supplied as the grad-op maker
// (NOTE(review): presumably this means no gradient op is generated - confirm).
REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker);
// CPU kernel of the channel-wise quantize op; only float is registered.
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
......@@ -174,5 +174,7 @@ namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
// CUDA kernel registrations of the fake quantize ops; each op is registered
// for float only.
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
// Channel-wise variant, mirroring its CPU registration.
REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
......@@ -63,6 +63,39 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
}
};
// Kernel of the channel-wise abs-max fake quantize op.
//
// Pass 1 stores the abs-max of every slice of X (along its first dimension)
// into OutScales. Pass 2 clips and quantizes every slice with its own scale
// and the bin count 2^(bit_length - 1) - 1, writing the result to Out.
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* quantized = context.Output<framework::Tensor>("Out");
    auto* scale_tensor = context.Output<framework::Tensor>("OutScales");

    // Allocate both outputs before any computation is issued.
    T* scale_buf = scale_tensor->mutable_data<T>(context.GetPlace());
    quantized->mutable_data<T>(context.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    const int64_t channel_num = x->dims()[0];

    // Pass 1: one abs-max value per channel, written to scale_buf[c].
    auto find_abs_max = FindAbsMaxFunctor<DeviceContext, T>();
    for (int64_t c = 0; c < channel_num; ++c) {
      framework::Tensor channel = x->Slice(c, c + 1);
      find_abs_max(dev_ctx, channel.data<T>(), channel.numel(),
                   &scale_buf[c]);
    }

    // Pass 2: quantize each channel with the scale computed in pass 1.
    auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
    for (int64_t c = 0; c < channel_num; ++c) {
      framework::Tensor channel_in = x->Slice(c, c + 1);
      framework::Tensor channel_out = quantized->Slice(c, c + 1);
      framework::Tensor channel_scale = scale_tensor->Slice(c, c + 1);
      clip_quant(dev_ctx, channel_in, channel_scale, bin_cnt, &channel_out);
    }
  }
};
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public:
......
......@@ -31,6 +31,80 @@ def dequantize_max_abs(x, scale, max_range):
return y
def channel_wise_quantize_max_abs(x, quant_bit=8):
    """Reference channel-wise max-abs quantization.

    Each slice ``x[i]`` along the first axis gets its own scale, the maximum
    absolute value of that slice (cast to float32). The slice is then mapped
    to ``round(x[i] / scale * (2 ** (quant_bit - 1) - 1))``.

    Args:
        x: numpy array to quantize; it is copied, not modified.
        quant_bit: bit width used to derive the quantization range.

    Returns:
        A tuple ``(y, scales)``: the quantized copy of ``x`` and the list of
        per-slice float32 scales.
    """
    scales = [np.max(np.abs(channel)).astype("float32") for channel in x]
    y = x.copy()
    max_range = math.pow(2, quant_bit - 1) - 1
    for idx in range(len(scales)):
        y[idx] = np.round(y[idx] / scales[idx] * max_range)
    return y, scales
def channel_wise_dequantize_max_abs(x,
                                    scales,
                                    quant_bits,
                                    activation_scale=None):
    """Reference channel-wise max-abs dequantization.

    Each slice ``x[i]`` along the first axis is multiplied by
    ``scales[i] / (2 ** (quant_bits[0] - 1) - 1)``. When ``activation_scale``
    is given, the whole result is additionally multiplied by
    ``activation_scale / (2 ** (quant_bits[1] - 1) - 1)``.

    Args:
        x: numpy array to dequantize; it is copied, not modified.
        scales: per-slice scales, indexed like the first axis of ``x``.
        quant_bits: bit widths; index 0 is used for ``scales``, index 1 for
            ``activation_scale``.
        activation_scale: optional extra scalar scale.

    Returns:
        The dequantized copy of ``x``.
    """
    y = x.copy()
    for idx in range(x.shape[0]):
        weight_range = math.pow(2, quant_bits[0] - 1) - 1
        step = scales[idx] / weight_range
        y[idx] = step * y[idx]
    if activation_scale is None:
        return y
    activation_range = math.pow(2, quant_bits[1] - 1) - 1
    y *= activation_scale / activation_range
    return y
class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
    """Checks fake_channel_wise_dequantize_max_abs with two scale inputs.

    The first scale input holds the per-channel scales, the second a single
    activation scale value.
    """

    def set_args(self):
        # One bit width per scale input.
        self.quant_bits = [8, 8]
        self.data_type = "float32"
        self.activation_scale = 0.7861

    def setUp(self):
        self.set_args()
        self.op_type = "fake_channel_wise_dequantize_max_abs"

        # Quantize random data with the numpy reference, then compute the
        # dequantized result the operator is expected to reproduce.
        raw = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        quantized, channel_scales = channel_wise_quantize_max_abs(
            raw, self.quant_bits[0])
        expected = channel_wise_dequantize_max_abs(
            quantized, channel_scales, self.quant_bits, self.activation_scale)

        scale_inputs = [
            ("scales0", np.array(channel_scales).astype(self.data_type)),
            ("scales1",
             np.array([self.activation_scale]).astype(self.data_type)),
        ]
        self.inputs = {'X': quantized, 'Scales': scale_inputs}
        self.attrs = {'quant_bits': self.quant_bits}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()
class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
    """Checks fake_channel_wise_dequantize_max_abs with one scale input.

    Only the per-channel scales are passed; no activation scale is used.
    """

    def set_args(self):
        self.quant_bits = [8]
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "fake_channel_wise_dequantize_max_abs"

        # Quantize random data with the numpy reference, then compute the
        # dequantized result the operator is expected to reproduce.
        raw = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        quantized, channel_scales = channel_wise_quantize_max_abs(
            raw, self.quant_bits[0])
        expected = channel_wise_dequantize_max_abs(quantized, channel_scales,
                                                   self.quant_bits)

        scale_inputs = [
            ("scales0", np.array(channel_scales).astype(self.data_type)),
        ]
        self.inputs = {'X': quantized, 'Scales': scale_inputs}
        self.attrs = {'quant_bits': self.quant_bits}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()
class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
......
......@@ -35,6 +35,30 @@ class TestFakeQuantizeOp(OpTest):
self.check_output()
class TestFakeChannelWiseQuantizeOp(OpTest):
    """Checks fake_channel_wise_quantize_abs_max against a numpy reference."""

    def setUp(self):
        self.op_type = "fake_channel_wise_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {
            'X': np.random.random((4, 3, 64, 64)).astype("float32"),
        }

        x = self.inputs['X']
        # One abs-max scale per slice along the first axis.
        scales = [
            np.max(np.abs(x[c])).astype("float32") for c in range(x.shape[0])
        ]

        # Expected output: round(x_c / scale_c * (2^(bit_length - 1) - 1)).
        expected = x.copy()
        for c, scale in enumerate(scales):
            max_range = (1 << (self.attrs['bit_length'] - 1)) - 1
            expected[c] = np.round(expected[c] / scale * max_range)

        self.outputs = {
            'Out': expected,
            'OutScales': np.array(scales).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()
class TestFakeQuantizeRangeAbsMaxOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_range_abs_max"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册