Unverified commit 02606d45, authored by huangxu96 and committed by GitHub

Quant op dev (#25932)

* Finished ChannelWiseQuantDequantAbsMaxOp and Passed unittests.

* Finished channel-wise quantize strategy in imperative quantization.

* Added CUDA code of ChannelWiseQuantDequantMaxAbsOp.

* Add quant_axis for channel_wise quant.

* Fixed a bug in unittests that failed to trigger the axis = 1 case and therefore could not meet the coverage requirement.

* Added some assertion messages and fixed some coding style mistakes.
Parent aa7835ef
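For reference, a minimal NumPy sketch of what the new fake_channel_wise_quantize_dequantize_abs_max op computes, matching the formulas in the op's DOC string and the expected outputs in the new unit tests below. The function name and helper are illustrative only, not part of the patch, and the sketch assumes non-zero per-channel scales.

import numpy as np

def channel_wise_quant_dequant_abs_max(x, bit_length=8, quant_axis=0):
    # range = 2^(bit_length - 1) - 1, e.g. 127 for 8 bits
    bin_cnt = (1 << (bit_length - 1)) - 1
    # per-channel scale: max(|x|) over every axis except quant_axis
    reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    scales = np.abs(x).max(axis=reduce_axes).astype('float32')
    # broadcast the scales back over x
    shape = [1] * x.ndim
    shape[quant_axis] = -1
    s = scales.reshape(shape)
    # quantize to the integer grid, then immediately dequantize
    out = np.round(x * bin_cnt / s) * s / bin_cnt
    return out, scales

# conv2d/depthwise_conv2d/conv2d_transpose weights quantize along axis 0 (cout);
# mul/linear weights quantize along axis 1.
w = np.random.random((3, 4, 64, 64)).astype('float32')
out, out_scale = channel_wise_quant_dequant_abs_max(w, bit_length=8, quant_axis=0)
assert out_scale.shape == (3,)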
@@ -174,7 +174,64 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
if (quant_axis == 0) {
const int64_t channel_size = in.numel() / channel;
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
T inv_s = inverse(s);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) =
(bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
for (int i = 0; i < in_dims[0]; i++) {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) *
s / static_cast<T>(bin_cnt);
}
}
}
}
}
};
template struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
@@ -360,6 +417,75 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
}
};
class FakeChannelWiseQuantizeDequantizeAbsMaxOp
: public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FakeChannelWiseQuantizeDequantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeChannelWiseQuantizeDequantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeChannelWiseQuantizeDequantizeAbsMax");
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input is float data type.");
AddOutput("Out",
"(Tensor) Output of quantized and dequantized low level tensor, "
"saved as float data type.");
AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
});
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int& bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c}) * \frac{scale_c} {range}$$
In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
}
};
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantizeRangeAbsMaxOp(const std::string& type,
@@ -666,3 +792,12 @@ REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CPU, float>);
REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
@@ -417,8 +417,90 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
}
};
template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
float>;
// ChannelClipAndQuantDequantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
const T* in, const T* scale, const int bin_cnt, const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
T inv_s = inverse(s);
for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
}
}
// ChannelClipAndQuantDequantKernel for quant_axis is 1
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
const int cout, T* out) {
T s = scale[blockIdx.x % cout];
T inv_s = inverse(s);
int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
}
}
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
int num = in.numel();
auto in_dims = in.dims();
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis0<
T><<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt,
num, in_dims[0], out_data);
} else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1];
int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis1<
T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
}
}
};
template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
float>;
} // namespace operators
} // namespace paddle
@@ -443,3 +525,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
fake_channel_wise_quantize_dequantize_abs_max,
ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);
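The CUDA launch configuration above can be summarized as follows: for quant_axis == 0 the grid size equals the channel count and block b clips/quantizes one contiguous slice of n / c elements; for quant_axis == 1 the grid size is in_dims[0] * in_dims[1] and block b reads scale[b % cout]. A small Python sketch of this block-to-data mapping, for documentation only, not the kernel itself:

def axis0_blocks(n, c):
    # one block per channel; block b covers elements [b * channel_size, (b + 1) * channel_size)
    channel_size = n // c
    return [(b, b * channel_size, (b + 1) * channel_size) for b in range(c)]

def axis1_blocks(n, cin, cout):
    # one block per (cin, cout) slice; block b uses scale[b % cout]
    wh_size = n // (cin * cout)
    return [(b, b % cout, b * wh_size, (b + 1) * wh_size) for b in range(cin * cout)]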
@@ -72,6 +72,13 @@ struct ChannelClipAndFakeQuantFunctor {
const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum,
@@ -154,6 +161,30 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace());
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis");
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
out_scale_data);
ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
}
};
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public:
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc.h"
...
@@ -111,6 +111,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"fake_quantize_dequantize_moving_average_abs_max",
{"Out", "OutScale", "OutAccum", "OutState"}},
{"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
{"fake_channel_wise_quantize_dequantize_abs_max", {"Out", "OutScale"}},
{"check_finite_and_unscale", {"Out", "FoundInfinite"}},
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
...
@@ -99,7 +99,12 @@ class ImperativeQuantAware(object):
self._activation_bits = activation_bits
self._moving_rate = moving_rate
quant_type = {
'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'
}
assert activation_quantize_type != 'channel_wise_abs_max', \
"The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be "
@@ -108,8 +113,8 @@ class ImperativeQuantAware(object):
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now."
% (str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
...
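From user code the new strategy is selected as in the unit test added later in this patch; channel_wise_abs_max is accepted only for weights, while activations keep using moving_average_abs_max. A minimal sketch, assuming any dygraph model with Conv2D/Linear sublayers:

import paddle.fluid as fluid
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

imperative_qat = ImperativeQuantAware(
    weight_quantize_type='channel_wise_abs_max',
    activation_quantize_type='moving_average_abs_max')
with fluid.dygraph.guard():
    model = Sequential(Linear(input_dim=84, output_dim=10))  # stand-in for the LeNet used by the tests
    imperative_qat.quantize(model)  # Linear weights now go through FakeChannelWiseQuantDequantAbsMax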
@@ -24,7 +24,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype
__all__ = [
'FakeQuantMovingAverage', 'FakeQuantAbsMax', 'QuantizedConv2D',
'QuantizedLinear', 'FakeChannelWiseQuantDequantAbsMax'
]
@@ -209,6 +209,89 @@ class FakeQuantAbsMax(layers.Layer):
return quant_out
class FakeChannelWiseQuantDequantAbsMax(layers.Layer):
def __init__(self,
name=None,
channel_num=None,
quant_bits=8,
quant_axis=0,
dtype='float32',
quant_on_weight=False):
assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
super(FakeChannelWiseQuantDequantAbsMax, self).__init__()
self._quant_bits = quant_bits
self._quant_axis = quant_axis
self._dtype = dtype
self._name = name
self._channel_num = channel_num
scale_prefix = "{}.scale".format(
name) if name else 'quant_dequant.scale'
self._scale_name = unique_name.generate(scale_prefix)
if quant_on_weight:
scale_attr = ParamAttr(
name=self._scale_name,
initializer=Constant(0.0),
trainable=False)
self._scale = self.create_parameter(
shape=[self._channel_num], attr=scale_attr, dtype=self._dtype)
self._scale.stop_gradient = True
else:
self._scale = None
def forward(self, input):
if in_dygraph_mode():
attrs = ('bit_length', self._quant_bits, 'quant_axis',
self._quant_axis)
quant_out = _varbase_creator(
type=input.type,
name="{}.quantized.dequantized".format(input.name),
shape=input.shape,
dtype=input.dtype,
persistable=False)
out_scale = self._scale
if out_scale is None:
out_scale = _varbase_creator(
type=core.VarDesc.VarType.LOD_TENSOR,
name=self._scale_name,
shape=[self._channel_num],
dtype=self._dtype,
persistable=False)
out_scale.stop_gradient = True
out, _, = core.ops.fake_channel_wise_quantize_dequantize_abs_max(
input, quant_out, out_scale, *attrs)
return out
check_variable_and_dtype(input, 'input', ['float32'],
"FakeChannelWiseQuantDequantAbsMax")
attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
inputs = {"X": [input]}
quant_out = self._helper.create_variable(
name="{}.quantized.dequantized".format(input.name),
dtype=input.dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
out_scale = self._scale
if not out_scale:
out_scale = self._helper.create_variable(
name=self._scale_name,
dtype=self._dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=True)
outputs = {"Out": [quant_out], "OutScale": [out_scale]}
self._helper.append_op(
type="fake_channel_wise_quantize_dequantize_abs_max",
inputs=inputs,
outputs=outputs,
attrs=attrs)
return quant_out
def _get_fake_quant_type(quant_type, **kwargs):
call_args = {
"name": kwargs.get("name", None),
@@ -220,10 +303,17 @@ def _get_fake_quant_type(quant_type, **kwargs):
call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
elif quant_type == 'moving_average_abs_max':
call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
elif quant_type == 'channel_wise_abs_max':
call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
call_args["channel_num"] = kwargs.get("channel_num", None)
call_args["quant_axis"] = kwargs.get("quant_axis", 0)
assert call_args["channel_num"] is not None, (
"You need to input channel_num"
"when you use channel_wise_abs_max strategy.")
fake_quant_map = {
'abs_max': FakeQuantAbsMax,
'moving_average_abs_max': FakeQuantMovingAverage,
'channel_wise_abs_max': FakeChannelWiseQuantDequantAbsMax
}
return fake_quant_map[quant_type](**call_args)
@@ -255,19 +345,23 @@ class QuantizedConv2D(layers.Layer):
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._conv2d_quant_axis = 0
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type,
name=self.weight.name,
moving_rate=moving_rate,
quant_bits=weight_bits,
dtype=self._dtype,
quant_on_weight=True,
channel_num=self.weight.shape[self._conv2d_quant_axis],
quant_axis=self._conv2d_quant_axis)
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype,
quant_on_weight=False)
def forward(self, input):
quant_input = self._fake_quant_input(input)
@@ -341,19 +435,23 @@ class QuantizedLinear(layers.Layer):
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._linear_quant_axis = 1
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type,
name=self.weight.name,
moving_rate=moving_rate,
quant_bits=weight_bits,
dtype=self._dtype,
quant_on_weight=True,
channel_num=self.weight.shape[self._linear_quant_axis],
quant_axis=self._linear_quant_axis)
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype,
quant_on_weight=False)
def forward(self, input):
quant_input = self._fake_quant_input(input)
...
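Note the per-layer quant_axis wiring above: conv2d weights of shape [cout, cin, kh, kw] are quantized along axis 0 with channel_num = cout, while mul/linear weights of shape [in_features, out_features] are quantized along axis 1 with channel_num = out_features, so OutScale carries one value per output channel. An illustrative in-module call (the parameter name and shapes are hypothetical, not from the patch):

# Hypothetical example of how QuantizedConv2D builds its weight quantizer.
weight_quant = _get_fake_quant_type(
    'channel_wise_abs_max',
    name='conv2d_0.w_0',   # hypothetical parameter name
    quant_bits=8,
    dtype='float32',
    quant_on_weight=True,
    channel_num=6,         # cout of a [6, 1, 3, 3] conv weight
    quant_axis=0)          # per output channel -> OutScale shape [6]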
@@ -181,7 +181,6 @@ class TestImperativeQat(unittest.TestCase):
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
...
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import os
import numpy as np
import random
import unittest
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.log_helper import get_logger
os.environ["CPU_NUM"] = "1"
if core.is_compiled_with_cuda():
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def StaticLenet(data, num_classes=10, classifier_activation='softmax'):
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
conv1 = fluid.layers.conv2d(
data,
num_filters=6,
filter_size=3,
stride=1,
padding=1,
param_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr)
pool1 = fluid.layers.pool2d(
conv1, pool_size=2, pool_type='max', pool_stride=2)
conv2 = fluid.layers.conv2d(
pool1,
num_filters=16,
filter_size=5,
stride=1,
padding=0,
param_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr)
pool2 = fluid.layers.pool2d(
conv2, pool_size=2, pool_type='max', pool_stride=2)
fc1 = fluid.layers.fc(input=pool2,
size=120,
param_attr=fc_w1_attr,
bias_attr=fc_b1_attr)
fc2 = fluid.layers.fc(input=fc1,
size=84,
param_attr=fc_w2_attr,
bias_attr=fc_b2_attr)
fc3 = fluid.layers.fc(input=fc2,
size=num_classes,
act=classifier_activation,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr)
return fc3
class ImperativeLenet(fluid.dygraph.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1")
conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2")
fc_w1_attr = fluid.ParamAttr(name="fc_w_1")
fc_w2_attr = fluid.ParamAttr(name="fc_w_2")
fc_w3_attr = fluid.ParamAttr(name="fc_w_3")
conv2d_b1_attr = fluid.ParamAttr(name="conv2d_b_1")
conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2")
fc_b1_attr = fluid.ParamAttr(name="fc_b_1")
fc_b2_attr = fluid.ParamAttr(name="fc_b_2")
fc_b3_attr = fluid.ParamAttr(name="fc_b_3")
self.features = Sequential(
Conv2D(
num_channels=1,
num_filters=6,
filter_size=3,
stride=1,
padding=1,
param_attr=conv2d_w1_attr,
bias_attr=conv2d_b1_attr),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
Conv2D(
num_channels=6,
num_filters=16,
filter_size=5,
stride=1,
padding=0,
param_attr=conv2d_w2_attr,
bias_attr=conv2d_b2_attr),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2))
self.fc = Sequential(
Linear(
input_dim=400,
output_dim=120,
param_attr=fc_w1_attr,
bias_attr=fc_b1_attr),
Linear(
input_dim=120,
output_dim=84,
param_attr=fc_w2_attr,
bias_attr=fc_b2_attr),
Linear(
input_dim=84,
output_dim=num_classes,
act=classifier_activation,
param_attr=fc_w3_attr,
bias_attr=fc_b3_attr))
def forward(self, inputs):
x = self.features(inputs)
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class TestImperativeQat(unittest.TestCase):
"""
QAT = quantization-aware training
"""
def test_qat_save(self):
imperative_qat = ImperativeQuantAware(
weight_quantize_type='channel_wise_abs_max',
activation_quantize_type='moving_average_abs_max')
with fluid.dygraph.guard():
lenet = ImperativeLenet()
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=lenet.parameters())
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=32, drop_last=True)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32)
epoch_num = 1
for epoch in range(epoch_num):
lenet.train()
for batch_id, data in enumerate(train_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc = fluid.layers.accuracy(out, label)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
lenet.eval()
for batch_id, data in enumerate(test_reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
acc_top1 = fluid.layers.accuracy(
input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(
input=out, label=label, k=5)
if batch_id % 100 == 0:
_logger.info(
"Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}".
format(epoch, batch_id,
acc_top1.numpy(), acc_top5.numpy()))
# save weights
model_dict = lenet.state_dict()
fluid.save_dygraph(model_dict, "save_temp")
# test the correctness of `paddle.jit.save`
data = next(test_reader())
test_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
test_img = fluid.dygraph.to_variable(test_data)
lenet.eval()
before_save = lenet(test_img)
# save inference quantized model
path = "./mnist_infer_model"
paddle.jit.save(
layer=lenet,
model_path=path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
[inference_program, feed_target_names, fetch_targets] = (
fluid.io.load_inference_model(
dirname=path,
executor=exe,
model_filename="__model__",
params_filename="__variables__"))
after_save, = exe.run(inference_program,
feed={feed_target_names[0]: test_data},
fetch_list=fetch_targets)
self.assertTrue(
np.allclose(after_save, before_save.numpy()),
msg='Failed to save the inference quantized model.')
def test_qat_acc(self):
def _build_static_lenet(main, startup, is_test=False, seed=1000):
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
main.random_seed = seed
startup.random_seed = seed
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
prediction = StaticLenet(img)
if not is_test:
loss = fluid.layers.cross_entropy(
input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
else:
avg_loss = prediction
return img, label, avg_loss
reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=32, drop_last=True)
weight_quantize_type = 'channel_wise_abs_max'
activation_quant_type = 'moving_average_abs_max'
param_init_map = {}
seed = 1000
lr = 0.1
# imperative train
_logger.info(
"--------------------------dynamic graph qat--------------------------"
)
imperative_qat = ImperativeQuantAware(
weight_quantize_type=weight_quantize_type,
activation_quantize_type=activation_quant_type)
with fluid.dygraph.guard():
np.random.seed(seed)
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
lenet = ImperativeLenet()
fixed_state = {}
for name, param in lenet.named_parameters():
p_shape = param.numpy().shape
p_value = param.numpy()
if name.endswith("bias"):
value = np.zeros_like(p_value).astype('float32')
else:
value = np.random.normal(
loc=0.0, scale=0.01, size=np.product(p_shape)).reshape(
p_shape).astype('float32')
fixed_state[name] = value
param_init_map[param.name] = value
lenet.set_dict(fixed_state)
imperative_qat.quantize(lenet)
adam = AdamOptimizer(
learning_rate=lr, parameter_list=lenet.parameters())
dynamic_loss_rec = []
lenet.train()
for batch_id, data in enumerate(reader()):
x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(x_data)
label = fluid.dygraph.to_variable(y_data)
out = lenet(img)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
lenet.clear_gradients()
dynamic_loss_rec.append(avg_loss.numpy()[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', avg_loss.numpy()))
paddle.jit.save(
layer=lenet,
model_path="./dynamic_mnist",
input_spec=[
paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
])
# static graph train
_logger.info(
"--------------------------static graph qat--------------------------"
)
static_loss_rec = []
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
main = fluid.Program()
infer = fluid.Program()
startup = fluid.Program()
static_img, static_label, static_loss = _build_static_lenet(
main, startup, False, seed)
infer_img, _, infer_pre = _build_static_lenet(infer, startup, True,
seed)
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
opt = AdamOptimizer(learning_rate=lr)
opt.minimize(static_loss)
scope = core.Scope()
with fluid.scope_guard(scope):
exe.run(startup)
for param in main.all_parameters():
param_tensor = scope.var(param.name).get_tensor()
param_tensor.set(param_init_map[param.name], place)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
infer_graph = IrGraph(core.Graph(infer.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quantize_type,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
transform_pass.apply(main_graph)
transform_pass.apply(infer_graph)
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = False
binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
loss_name=static_loss.name, build_strategy=build_strategy)
feeder = fluid.DataFeeder(
feed_list=[static_img, static_label], place=place)
with fluid.scope_guard(scope):
for batch_id, data in enumerate(reader()):
loss_v, = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[static_loss])
static_loss_rec.append(loss_v[0])
if batch_id % 100 == 0:
_logger.info('{}: {}'.format('loss', loss_v))
save_program = infer_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model("./static_mnist", [infer_img.name],
[infer_pre], exe, save_program)
rtol = 1e-05
atol = 1e-08
for i, (loss_d,
loss_s) in enumerate(zip(dynamic_loss_rec, static_loss_rec)):
diff = np.abs(loss_d - loss_s)
if diff > (atol + rtol * np.abs(loss_s)):
_logger.info(
"diff({}) at {}, dynamic loss = {}, static loss = {}".
format(diff, i, loss_d, loss_s))
break
self.assertTrue(
np.allclose(
np.array(dynamic_loss_rec),
np.array(static_loss_rec),
rtol=rtol,
atol=atol,
equal_nan=True),
msg='Failed to do the imperative qat.')
if __name__ == '__main__':
unittest.main()
@@ -306,5 +306,70 @@ class TestFakeQuantDequantAbsOp(OpTest):
self.check_grad(["X"], "Out", user_defined_grads=gradient)
class TestChannelWiseFakeQuantDequantOp(OpTest):
def setUp(self):
self.set_arg()
assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
self.op_type = "fake_channel_wise_quantize_dequantize_abs_max"
self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}
scales = []
outputs = self.inputs['X'].copy()
range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
if self.quant_axis == 0:
for i in range(self.inputs['X'].shape[0]):
scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
scales.append(scale_v)
outputs[i] = np.round(outputs[i] * range_v /
scale_v) * scale_v / range_v
elif self.quant_axis == 1:
for i in range(self.inputs['X'].shape[1]):
scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
"float32")
scales.append(scale_v)
outputs[:, i] = np.round(outputs[:, i] * range_v /
scale_v) * scale_v / range_v
self.outputs = {
'Out': outputs,
'OutScale': np.array(scales).astype("float32"),
}
def set_arg(self):
self.quant_axis = 0
self.inputs = {
'X': np.random.random((3, 4, 64, 64)).astype("float32"),
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
x = self.inputs["X"]
gradient = [np.ones(x.shape) / np.product(x.shape)]
self.check_grad(["X"], "Out", user_defined_grads=gradient)
class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {
'X': np.random.random((15, 20, 5, 5)).astype("float32"),
}
class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 0
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
if __name__ == "__main__":
unittest.main()
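A note on the gradient checks above: the new op registers FakeQuantDequantGradMaker, whose backward pass routes the upstream gradient straight through to X, and the tests reuse the pattern of TestFakeQuantDequantAbsOp, where the check differentiates the mean of Out. Under that assumption the expected gradient is uniform, which is what user_defined_grads encodes; a small sketch of that value:

import numpy as np
x = np.random.random((3, 4, 64, 64)).astype('float32')
expected_grad = np.ones(x.shape) / np.prod(x.shape)  # straight-through gradient of mean(Out)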