Commit 89dee160 authored by Zhen Wang

add channel wise dequantize op.

Parent 545247d7
......@@ -76,6 +76,70 @@ $$Out = \frac{scale*X}{ max_range }$$
}
};
class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("WeightScales"),
"Input(WeightScales) of FakeChannelWiseDequantizeMaxAbsOp "
"should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FakeChannelWiseDequantizeMaxAbsOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
AddInput("ActivationScale",
"(float) The activation scale in quantization stage.")
.AsDispensable();
AddInput("WeightScales",
"(float array) The weight scales in quantization stage.");
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<int>("activation_bits", "Quantization bit number for activation.")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'activation_bits' should be between 1 and 16.");
});
AddAttr<int>("weight_bits", "Quantization bit number for weights.")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'weight_bits' should be between 1 and 16.");
});
AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.
This calculation is the inverse operation of FakeChannelWiseQuantizeMaxAbsOp:
$$Out_c = \frac{ActivationScale*WeightScale_c*X_c}{(2^{weight\_bits-1}-1)*(2^{activation\_bits-1}-1)}$$
In the above formula, the range of c is as follows:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
Notes: The per-channel quantization is only applied to weights (one scale per channel),
and the activations use per-layer quantization (a single scale).
)DOC");
}
};
} // namespace operators
} // namespace paddle
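
As a quick numeric sanity check of the dequantize formula in the comment above, a minimal Python sketch with illustrative values (these numbers are invented for this example, not taken from the op's tests):

weight_bits, activation_bits = 8, 8
# Denominator of the formula: (2^7 - 1) * (2^7 - 1) = 127 * 127 = 16129.
denom = (2 ** (weight_bits - 1) - 1) * (2 ** (activation_bits - 1) - 1)
x_c, weight_scale_c, activation_scale = 100.0, 0.5, 0.2
out_c = activation_scale * weight_scale_c * x_c / denom
print(out_c)  # 10 / 16129, roughly 6.2e-4
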
......@@ -88,3 +152,11 @@ REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
REGISTER_OPERATOR(fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsOp,
ops::FakeChannelWiseDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, float>,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CPU, double>);
......@@ -55,3 +55,7 @@ using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_dequantize_max_abs,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeChannelWiseDequantizeMaxAbsKernel<CUDA, double>);
......@@ -45,5 +45,56 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* weight_scales = ctx.Input<framework::Tensor>("WeightScales");
auto* out = ctx.Output<framework::Tensor>("Out");
PADDLE_ENFORCE_EQ(weight_scales->numel(), in->dims()[0],
"The weight uses the per-channel quantization type, so "
"the number of weight scale values must be the same as "
"the first dimension of Input(X).");
int activation_bits = ctx.Attr<int>("activation_bits");
int weight_bits = ctx.Attr<int>("weight_bits");
int range = std::pow(2, weight_bits - 1) - 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace());
auto dequant = DequantizeFunctor<DeviceContext, T>();
if (ctx.HasInput("ActivationScale")) {
auto* activation_scale = ctx.Input<framework::Tensor>("ActivationScale");
PADDLE_ENFORCE_EQ(activation_scale->numel(), 1,
"The activation uses per-layer quantization type, so "
"it must have only one value.");
framework::Tensor cpu_weight_scales;
framework::TensorCopy(*weight_scales, platform::CPUPlace(),
&cpu_weight_scales);
dev_ctx.Wait();
const T* weight_scales_data = cpu_weight_scales.data<T>();
// Fold both bit ranges into a single range; the per-channel weight scale is
// divided out below, so the functor only multiplies by the activation scale.
range *= (std::pow(2, activation_bits - 1) - 1);
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto max_range = range / weight_scales_data[i];
dequant(dev_ctx, &one_channel_in, activation_scale,
static_cast<T>(max_range), &one_channel_out);
}
} else {
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = weight_scales->Slice(i, i + 1);
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(range), &one_channel_out);
}
}
}
};
} // namespace operators
} // namespace paddle
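
The kernel above takes two paths depending on whether ActivationScale is fed. A minimal NumPy sketch of both paths, assuming (from the base op's formula $$Out = \frac{scale*X}{max\_range}$$ and the calls above) that DequantizeFunctor computes out = in * scale / max_range; the helper name is hypothetical, not the Paddle API:

import numpy as np

def fake_channel_wise_dequantize(x, weight_scales, weight_bits=8,
                                 activation_bits=8, activation_scale=None):
    out = np.empty_like(x, dtype=np.float32)
    w_range = 2 ** (weight_bits - 1) - 1
    if activation_scale is not None:
        # Path 1: both bit ranges and the per-channel weight scale are folded
        # into max_range; the functor's scale argument is the activation scale.
        full_range = w_range * (2 ** (activation_bits - 1) - 1)
        for c in range(x.shape[0]):
            max_range = full_range / weight_scales[c]
            out[c] = x[c] * activation_scale / max_range
    else:
        # Path 2: the functor's scale argument is the per-channel weight scale
        # and max_range is just the weight-bits range.
        for c in range(x.shape[0]):
            out[c] = x[c] * weight_scales[c] / w_range
    return out
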
......@@ -180,11 +180,10 @@ The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In the above three formulas, the range of c is as follows:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
}
};
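
For example, with illustrative numbers not taken from the source: bit_length = 8 gives range = 127; if max(abs(X_c)) = 0.5, then scale_c = 0.5 and an entry 0.3 in channel c quantizes to round(0.3 * 127 / 0.5) = round(76.2) = 76.
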
......
......@@ -31,6 +31,77 @@ def dequantize_max_abs(x, scale, max_range):
return y
def channel_wise_quantize_max_abs(x, max_range):
scales = []
for i in range(x.shape[0]):
scales.append(np.max(np.abs(x[i])).astype("float32"))
y = x.copy()
for i, scale in enumerate(scales):
y[i] = np.round(y[i] / scale * max_range)
return y, scales
def channel_wise_dequantize_max_abs(x, scales, max_range):
y = x.copy()
for i in range(x.shape[0]):
y[i] = (scales[i] / max_range) * y[i]
return y
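
A minimal round-trip sketch using the two reference helpers above (an illustrative check, not part of the test file; the tolerance follows from rounding at most half a quantization step per element):

import numpy as np

x = np.random.randn(4, 3, 8, 8).astype("float32")
max_range = 2 ** (8 - 1) - 1  # 8-bit signed range: 127
yq, scales = channel_wise_quantize_max_abs(x, max_range)
ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
# Quantize then dequantize should recover x up to half a step per channel.
assert np.allclose(x, ydq, atol=0.5 * max(scales) / max_range + 1e-6)
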
class TestFakeChannelWiseDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.weight_bits = 8
self.activation_bits = 2
self.data_type = "float32"
def setUp(self):
self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
max_range = math.pow(2, self.weight_bits - 1) - 1
yq, scales = channel_wise_quantize_max_abs(x, max_range)
ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
self.inputs = {
'X': yq,
'ActivationScale': np.array(1.0).astype(self.data_type),
'WeightScales': np.array(scales).astype(self.data_type)
}
self.attrs = {
'weight_bits': self.weight_bits,
'activation_bits': self.activation_bits
}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestFakeChannelWiseDequantizeMaxAbsOpNoActivationScale(OpTest):
def set_args(self):
self.weight_bits = 8
self.data_type = "float32"
def setUp(self):
self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
max_range = math.pow(2, self.weight_bits - 1) - 1
yq, scales = channel_wise_quantize_max_abs(x, max_range)
ydq = channel_wise_dequantize_max_abs(yq, scales, max_range)
self.inputs = {
'X': yq,
'WeightScales': np.array(scales).astype(self.data_type)
}
self.attrs = {'weight_bits': self.weight_bits}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
......