Unverified commit ec11135d, authored by Zhen Wang, committed by GitHub

Merge pull request #16341 from wzzju/add_channel_wise_in_quant_pass

Add channel wise in quant pass.
@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
}
};
template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, framework::Tensor* out) {
if (scale_num == 1) {
const int channel = in->dims()[0];
const T* scale_factor = scales[0]->data<T>();
for (int i = 0; i < channel; i++) {
T s = scale_factor[i];
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (s / max_range) * in_e;
}
} else if (scale_num == 2) {
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (s * scale_two[0] / max_range) * in_e;
}
}
}
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, double>;
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
......
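For the single-scale case, the CPU ChannelDequantizeFunctor above rescales each output channel of a quantized weight by its own scale divided by max_range. A minimal NumPy sketch of that math (the function name and shapes are illustrative, not part of this patch):

import numpy as np

def channel_dequantize_one_scale(x, scale, max_range):
    # x: quantized weight with layout [out_channels, ...]; scale: per-channel abs-max values.
    y = x.astype(np.float32).copy()
    for c in range(x.shape[0]):
        y[c] = scale[c] / max_range * x[c]
    return y

w_q = np.random.randint(-127, 128, size=(4, 3, 3, 3)).astype(np.float32)
scales = np.random.rand(4).astype(np.float32)
w_fp = channel_dequantize_one_scale(w_q, scales, max_range=127.0)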
@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
int num, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale[blockIdx.x] / max_range;
}
}
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
int batch_size, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / (batch_size * channel);
int scale_index = blockIdx.x % channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
}
}
template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) {
int num = in->numel();
int channel = in->dims()[0];
const T* scale_factor = scales[0]->data<T>();
int block = 1024;
int grid = channel;
DequantizeOneScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, channel, out_data);
} else if (scale_num == 2) {
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int block = 1024;
int grid = batch_size * channel;
DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_one, scale_two, max_range, num, batch_size, channel,
out_data);
}
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;
} // namespace operators
} // namespace paddle
......
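The CUDA DequantizeTwoScale kernel above launches one block per (batch, channel) slice and multiplies every element by that channel's weight scale and the single activation scale. A rough NumPy equivalent under the same [batch, channels, ...] layout assumption (names are illustrative):

import numpy as np

def channel_dequantize_two_scales(x, weight_scales, act_scale, max_range):
    # x: quantized conv output with layout [batch, channels, ...].
    y = x.astype(np.float32).copy()
    for n in range(x.shape[0]):
        for c in range(x.shape[1]):
            y[n, c] = x[n, c] * weight_scales[c] * act_scale / max_range
    return y

out_q = np.random.randn(2, 4, 8, 8).astype(np.float32)
w_scales = np.random.rand(4).astype(np.float32)
y = channel_dequantize_two_scales(out_q, w_scales, act_scale=0.5,
                                  max_range=127.0 * 127.0)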
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -28,6 +29,13 @@ struct DequantizeFunctor {
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
@@ -54,32 +62,33 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto scales = ctx.MultiInput<framework::Tensor>("Scales");
auto* out = ctx.Output<framework::Tensor>("Out");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
int max_range = 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace());
int scale_num = scales.size();
if (scale_num == 1) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only one "
"element.");
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements.");
PADDLE_ENFORCE_EQ(
scales[1]->numel(), 1,
"The second scale tensor should only have one value at now.");
max_range *= (std::pow(2, quant_bits[0] - 1) - 1) *
(std::pow(2, quant_bits[1] - 1) - 1);
}
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range), out);
}
};
......
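In the kernel above, max_range starts at 1 and is multiplied by (2^(bits-1) - 1) for every scale that is present, so with quant_bits of [8, 8] the single-scale path divides by 127 while the two-scale path divides by 127 * 127 = 16129. A quick check of that arithmetic:

quant_bits = [8, 8]
max_range_one_scale = 2 ** (quant_bits[0] - 1) - 1
max_range_two_scales = (2 ** (quant_bits[0] - 1) - 1) * (2 ** (quant_bits[1] - 1) - 1)
print(max_range_one_scale, max_range_two_scales)  # 127 16129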
@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* in,
const int num, const int channel, T* out) {
const int channel_size = num / channel;
for (int i = 0; i < channel; i++) {
auto* start = in + i * channel_size;
auto* end = in + (i + 1) * channel_size;
out[i] = std::abs(*(std::max_element(start, end, Compare<T>())));
}
}
};
template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int channel,
framework::Tensor* out) {
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
const int channel_size = in.numel() / channel;
platform::Transform<platform::CPUDeviceContext> trans;
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
}
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
@@ -169,10 +214,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("OutScale"),
"Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
ctx->ShareLoD("X", /*->*/ "Out");
}
@@ -192,7 +237,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
AddOutput("Out",
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
......
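FindChannelAbsMaxFunctor and ChannelClipAndFakeQuantFunctor together implement per-channel fake quantization on the CPU: the abs-max of each output channel becomes its scale, values are clipped to [-s, s], then rounded onto bin_cnt levels. A hedged NumPy reference of that composition (names and shapes are illustrative):

import numpy as np

def channel_wise_fake_quant(x, bit_length=8):
    bin_cnt = 2 ** (bit_length - 1) - 1
    # One scale per output channel (dim 0), taken as that channel's abs-max.
    scales = np.abs(x.reshape(x.shape[0], -1)).max(axis=1)
    y = np.empty_like(x)
    for c, s in enumerate(scales):
        clipped = np.clip(x[c], -s, s)
        y[c] = np.round(clipped * bin_cnt / s)
    return y, scales

w = np.random.randn(4, 3, 3, 3).astype(np.float32)
w_q, w_scales = channel_wise_fake_quant(w)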
@@ -74,6 +74,45 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
extern __shared__ T shared_max_data[];
shared_max_data[tid] = T(0);
for (int i = tid; i < channel_size; i += blockDim.x) {
T tmp = fabs(in_c[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
shared_max_data[tid] = shared_max_data[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[blockIdx.x] = shared_max_data[0];
}
}
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* in,
const int num, const int channel, T* out) {
int block = 1024;
int grid = channel;
FindChannelAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
in, num, channel, out);
}
};
template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n, T* out) {
@@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
T s = scale[0];
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out[i] = round(v);
}
}
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n,
const int c, T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out_c[i] = round(v);
}
}
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int channel,
framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = channel;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ChannelClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, channel, out_data);
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
float>;
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
const T* last_scale,
@@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
} // namespace operators
} // namespace paddle
......
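FindChannelAbsMaxKernel above maps one CUDA block to each channel: every thread strides over that channel's num / channel elements and the block then tree-reduces its partial maxima in shared memory. A small NumPy reference of what each block ultimately computes (illustrative name):

import numpy as np

def find_channel_abs_max_ref(flat_in, num, channel):
    # Each of the `channel` blocks reduces `num // channel` contiguous values.
    channel_size = num // channel
    out = np.empty(channel, dtype=flat_in.dtype)
    for c in range(channel):
        chunk = flat_in[c * channel_size:(c + 1) * channel_size]
        out[c] = np.abs(chunk).max()
    return out

x = np.random.randn(4 * 3 * 8 * 8).astype(np.float32)
print(find_channel_abs_max_ref(x, x.size, 4))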
@@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor {
framework::Tensor* scales_arr, framework::Tensor* out_scale);
};
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num,
const int channel, T* out);
};
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int channel, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum,
@@ -78,29 +91,18 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>();
FindChannelAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, in->data<T>(), in->numel(), in->dims()[0], out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out);
}
};
......
@@ -22,6 +22,7 @@ from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Program
from ....initializer import Constant
from ....initializer import NumpyArrayInitializer
from .... import unique_name
__all__ = [
@@ -54,14 +55,15 @@ class QuantizationTransformPass(object):
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activation,
now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
If use 'abs_max' mode, the quantization scale will be calculated
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
weight_quantize_type (str): quantization type for weights,
support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
usually is not used for weight, since weights are fixed once the
model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
@@ -84,7 +86,11 @@ class QuantizationTransformPass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
quant_type = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
]
assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ",
@@ -93,7 +99,7 @@ class QuantizationTransformPass(object):
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(weight_quantize_type))
self._activation_quantize_type = activation_quantize_type
@@ -103,6 +109,7 @@ class QuantizationTransformPass(object):
self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
@@ -135,10 +142,26 @@ class QuantizationTransformPass(object):
else self._activation_bits
quant_type = self._weight_quantize_type if var_node.name() \
in persistable_vars else self._activation_quantize_type
if quant_type == 'channel_wise_abs_max':
assert var_node.name(
) in persistable_vars, "'channel_wise_abs_max' can only be applied on weights."
if op.name() in self._conv_ops:
quant_var_node, scale_var_node = self._insert_channel_quant_op(
graph, var_node, quant_bits)
dequant_var_node = self._insert_channel_dequant_op(
graph, quant_var_node, [scale_var_node],
[quant_bits])
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, 'abs_max')
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node,
quant_bits)
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, quant_type)
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits)
dequantized_vars[var_node.name()] = dequant_var_node
graph.update_input_link(var_node, dequant_var_node, op)
@@ -244,7 +267,7 @@ class QuantizationTransformPass(object):
scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
@@ -384,6 +407,36 @@ class QuantizationTransformPass(object):
return quant_var_node, scale_out_node
def _insert_channel_quant_op(self, graph, var_node, quant_bits):
"""
Insert fake_channel_wise_quantize_abs_max op in the graph.
"""
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()),
var_type=var_node.type(),
shape=[var_node.shape()[0]],
var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node(
op_type='fake_channel_wise_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node},
outputs={'Out': quant_var_node,
'OutScale': scale_var_node})
graph.link_to(var_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node
def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
"""
Insert fake_dequantize_op in the graph.
@@ -410,6 +463,33 @@ class QuantizationTransformPass(object):
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
quant_bits):
"""
Insert fake_channel_wise_dequantize_max_abs in the graph.
"""
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(var_node.name()),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_channel_wise_dequantize_max_abs',
attrs={
'quant_bits': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node,
'Scales': scale_var_nodes},
outputs={'Out': dequant_var_node})
graph.link_to(var_node, dequant_op_node)
for scale_n in scale_var_nodes:
graph.link_to(scale_n, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
@@ -442,7 +522,7 @@ class QuantizationFreezePass(object):
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the
model is well trained.
"""
@@ -463,11 +543,15 @@ class QuantizationFreezePass(object):
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max',
'fake_channel_wise_quantize_abs_max'
]
self._fake_dequant_op_names = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict()
@@ -489,20 +573,27 @@ class QuantizationFreezePass(object):
if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param))
elif self._weight_quantize_type == 'channel_wise_abs_max':
param = self._load_var(input_arg_name)
if len(param.shape) == 4:  # conv2d or depthwise_conv2d
scale_v = []
for i in range(param.shape[0]):
scale_v.append(np.max(np.abs(param[i])))
else:
scale_v = np.max(np.abs(param))
else:
scale_v = self._load_var(
op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v
self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore
param_v = self._load_var(input_arg_name)
quantized_param_v = self._quant(param_v, scale_v,
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
else:
scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
ops = graph.all_op_nodes()
for op_node in ops:
@@ -514,7 +605,10 @@ class QuantizationFreezePass(object):
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
@@ -538,9 +632,73 @@ class QuantizationFreezePass(object):
self._op_input_rename_map[k] = self._op_input_rename_map[v]
graph.safe_remove_nodes(op_node)
def _insert_post_channel_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
if original_var_name in persistable_vars:
assert isinstance(
scale_v,
list), 'The scale of parameter %s is not a list.' % (
original_var_name)
channel_scale = np.array(scale_v)
else:
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = op_node.outputs[0]
weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[channel_scale.shape[0]],
var_dtype=output_var_node.dtype())
init_program = Program()
weight_scale_var = init_program.global_block().create_var(
name=weight_scale_node.name(),
shape=weight_scale_node.shape(),
dtype=weight_scale_node.dtype(),
type=weight_scale_node.type(),
lod_level=weight_scale_node.var().lod_level(),
persistable=weight_scale_node.persistable())
initializer = NumpyArrayInitializer(value=channel_scale)
initializer(weight_scale_var, init_program.global_block())
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
shape=output_var_node.shape(),
var_dtype=output_var_node.dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_channel_wise_dequantize_max_abs',
attrs={
'quant_bits': [self._weight_bits, self._activation_bits],
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={
'X': output_var_node,
'Scales': [weight_scale_node, scale_var_node]
},
outputs={'Out': dequant_var_node})
graph.link_to(output_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(weight_scale_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name()] = dequant_var_node
return dequant_var_node
def _insert_post_dequant_op(self, graph, op_node):
max_range = None
scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
@@ -637,7 +795,12 @@ class QuantizationFreezePass(object):
or isinstance(v, np.float64)
def _quant(self, x, scale, num_bits):
if isinstance(scale, list):
for i, s in enumerate(scale):
x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
return x
else:
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
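The list branch of _quant above is what makes channel-wise quantization worthwhile: when output channels have very different magnitudes, a single global abs-max scale wastes most of the int8 range on the small channels. A small, self-contained comparison on synthetic data (not from the patch) illustrating the effect:

import numpy as np

np.random.seed(0)
w = np.concatenate([np.random.randn(1, 64) * 10.0,   # large-magnitude channel
                    np.random.randn(1, 64) * 0.1])    # small-magnitude channel

def fake_quant_dequant(x, scale, bits=8):
    bin_cnt = (1 << (bits - 1)) - 1
    return np.round(x / scale * bin_cnt) * scale / bin_cnt

global_scale = np.abs(w).max()
per_channel_scales = np.abs(w).max(axis=1, keepdims=True)

# Reconstruction error on the small channel: per-channel scaling is far more accurate.
err_global = np.abs(fake_quant_dequant(w[1], global_scale) - w[1]).max()
err_channel = np.abs(fake_quant_dequant(w[1], per_channel_scales[1]) - w[1]).max()
print(err_global, err_channel)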
class ConvertToInt8Pass(object):
@@ -731,9 +894,13 @@ class TransformForMobilePass(object):
def __init__(self):
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max',
'fake_channel_wise_quantize_abs_max'
]
self._fake_dequant_op_names = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
def apply(self, graph):
"""
......
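Putting the Python pass changes together, the intended usage mirrors what the tests below do: build a program, apply QuantizationTransformPass with weight_quantize_type='channel_wise_abs_max' for quantization-aware training, then apply QuantizationFreezePass with the same weight type before inference. A rough sketch against the Paddle 1.x fluid API (the layer choice and variable names are illustrative; import paths follow the tests and may differ between Paddle versions):

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    conv = fluid.layers.conv2d(img, num_filters=4, filter_size=3)
    loss = fluid.layers.mean(conv)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)

graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
    scope=fluid.global_scope(),
    place=place,
    activation_quantize_type='moving_average_abs_max',
    weight_quantize_type='channel_wise_abs_max')
transform_pass.apply(graph)

# ... train the transformed graph, then freeze it with the matching weight type:
freeze_pass = QuantizationFreezePass(
    scope=fluid.global_scope(),
    place=place,
    weight_quantize_type='channel_wise_abs_max')
freeze_pass.apply(graph)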
@@ -127,7 +127,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, activation_quant_type, for_ci=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
@@ -140,14 +140,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
place=place,
activation_quantize_type=activation_quant_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + activation_quant_type,
marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
@@ -156,7 +157,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + activation_quant_type,
val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.linear_fc_quant('abs_max', for_ci=True)
@@ -167,7 +169,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def test_linear_fc_quant_moving_average_abs_max(self):
self.linear_fc_quant('moving_average_abs_max', for_ci=True)
def residual_block_quant(self, activation_quant_type, for_ci=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
@@ -180,14 +182,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
place=place,
activation_quantize_type=activation_quant_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + activation_quant_type,
marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
@@ -196,7 +199,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + activation_quant_type,
val_marked_nodes)
def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', for_ci=True)
@@ -209,7 +213,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self,
use_cuda,
seed,
activation_quant_type,
weight_quant_type='abs_max',
for_ci=False):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
@@ -243,7 +252,12 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
@@ -252,12 +266,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + activation_quant_type + '_'
+ weight_quant_type, marked_nodes)
marked_nodes = set()
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + activation_quant_type + '_'
+ weight_quant_type, marked_nodes)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
@@ -282,8 +298,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(data),
fetch_list=[loss])
if not for_ci:
print('{}: {}'.format('loss' + dev_name +
activation_quant_type + '_' +
weight_quant_type, loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
@@ -296,14 +313,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list=[loss, w_var])
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(
scope=scope, place=place, weight_quantize_type=weight_quant_type)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
if not for_ci:
marked_nodes = set()
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name +
activation_quant_type + '_' + weight_quant_type,
marked_nodes)
server_program = test_graph.to_program()
@@ -313,18 +333,20 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
if not for_ci:
print(
'{}: {}'.format('test_loss1' + dev_name + activation_quant_type
+ '_' + weight_quant_type, test_loss1))
print(
'{}: {}'.format('test_loss2' + dev_name + activation_quant_type
+ '_' + weight_quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if not for_ci:
print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
+ '_' + weight_quant_type, np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + activation_quant_type +
'_' + weight_quant_type, np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
@@ -334,26 +356,28 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + activation_quant_type
+ '_' + weight_quant_type, marked_nodes)
server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file.
with fluid.scope_guard(scope):
fluid.io.save_inference_model(
'server_int8' + dev_name + activation_quant_type + '_' +
weight_quant_type, ['image', 'label'], [loss], exe,
server_program_int8)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model(
'server_int8' + dev_name + activation_quant_type + '_' +
weight_quant_type, exe)
# Check the loaded 8-bit weight.
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
if not for_ci:
print('{}: {}'.format('w_8bit' + dev_name + activation_quant_type +
'_' + weight_quant_type, np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
+ '_' + weight_quant_type, np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
@@ -362,42 +386,103 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name +
activation_quant_type + '_' + weight_quant_type,
marked_nodes)
mobile_program = test_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model(
'mobile_int8' + dev_name + activation_quant_type + '_' +
weight_quant_type, ['image', 'label'], [loss], exe,
mobile_program)
def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(
True,
seed=1,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=True)
with fluid.unique_name.guard():
self.freeze_graph(
True,
seed=1,
activation_quant_type='abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
self.freeze_graph(
False,
seed=2,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(
True,
seed=1,
activation_quant_type='range_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
True,
seed=1,
activation_quant_type='moving_average_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
True,
seed=1,
activation_quant_type='range_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
self.freeze_graph(
True,
seed=1,
activation_quant_type='moving_average_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(
False,
seed=2,
activation_quant_type='range_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='moving_average_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='range_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='moving_average_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
if __name__ == '__main__':
......
@@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range):
return y
def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False):
scales = []
if not use_second_dim:
for i in range(x.shape[0]):
scales.append(np.max(np.abs(x[i])).astype("float32"))
y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1
for i, scale in enumerate(scales):
y[i] = np.round(x[i] / scale * max_range)
else:
for i in range(x.shape[0]):
s = []
for j in range(x.shape[1]):
s.append(np.max(np.abs(x[i][j])).astype("float32"))
scales.append(s)
scales = np.amax(np.array(scales), axis=0)
y = x.copy()
max_range = math.pow(2, quant_bit - 1) - 1
for i in range(x.shape[0]):
for j, scale in enumerate(scales):
y[i][j] = np.round(x[i][j] / scale * max_range)
return y, scales
@@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x,
scales,
quant_bits,
activation_scale=None):
if activation_scale is None:
y = x.copy()
for i in range(x.shape[0]):
y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i]
else:
y = x.copy()
for i in range(x.shape[0]):
for j in range(x.shape[1]):
y[i][j] = (scales[j] /
(math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j]
y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
return y
@@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
self.set_args()
self.op_type = "fake_channel_wise_dequantize_max_abs"
x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
yq, scales = channel_wise_quantize_max_abs(
x, self.quant_bits[0], use_second_dim=True)
ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
self.activation_scale)
......
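The two reference helpers above are designed to round-trip: quantizing with use_second_dim=True and then dequantizing with the per-channel scales plus an activation scale should approximately reproduce the input when the activation scale cancels the second divisor. A hedged, self-contained check (re-declaring compact versions of the helpers for illustration):

import math
import numpy as np

def quantize_channel_dim1(x, quant_bit=8):
    # Mirrors channel_wise_quantize_max_abs(..., use_second_dim=True):
    # per-channel abs-max over dim 1, reduced across the batch dim.
    scales = np.abs(x).max(axis=(0, 2, 3))
    max_range = math.pow(2, quant_bit - 1) - 1
    return np.round(x / scales[None, :, None, None] * max_range), scales

def dequantize_channel_dim1(y, scales, quant_bits, activation_scale):
    x = y * scales[None, :, None, None] / (math.pow(2, quant_bits[0] - 1) - 1)
    return x * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)

x = np.random.randn(4, 3, 8, 8).astype("float32")
yq, s = quantize_channel_dim1(x)
# With activation_scale equal to 2^(bits-1) - 1 the second factor is 1, so the
# round trip only leaves the per-channel quantization error.
xr = dequantize_channel_dim1(yq, s, [8, 8], activation_scale=127.0)
print(np.abs(xr - x).max())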
@@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
self.outputs = {
'Out': outputs,
'OutScale': np.array(scales).astype("float32"),
}
def test_check_output(self):
......