Unverified commit 491b87b4, authored by Guanghua Yu, committed by GitHub

fix quantization clip and round Attribute (#43764)

Parent 2739bd73
@@ -33,8 +33,10 @@ struct Compare {
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const T *in,
                  const int num,
                  T *out) {
    *out = std::abs(*(std::max_element(in + 0, in + num, Compare<T>())));
  }
};
@@ -43,24 +45,26 @@ template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &in_tensor,
                  const int quant_axis,
                  T *out_abs_max) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    auto *in_data = in_tensor.data<T>();
    auto in_dims = in_tensor.dims();
    const int64_t channel = in_dims[quant_axis];
    if (quant_axis == 0) {
      const int64_t channel_size = in_tensor.numel() / channel;
      for (int64_t i = 0; i < channel; i++) {
        auto *start = in_data + i * channel_size;
        auto *end = in_data + (i + 1) * channel_size;
        out_abs_max[i] =
            std::abs(*(std::max_element(start, end, Compare<T>())));
      }
@@ -72,8 +76,8 @@ struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
      const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]);
      for (int64_t i = 0; i < in_dims[0]; i++) {
        for (int64_t j = 0; j < in_dims[1]; j++) {
          auto *start = in_data + i * step_i + j * step_j;
          auto *end = in_data + i * step_i + (j + 1) * step_j;
          T abs_max = std::abs(*(std::max_element(start, end, Compare<T>())));
          out_abs_max[j] = std::max(out_abs_max[j], abs_max);
        }
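For reference, a minimal standalone sketch of the per-channel abs-max computation that the quant_axis == 0 branch above performs on a contiguous [channel, channel_size] buffer. This is plain C++ with a hypothetical helper name (ChannelAbsMax), not the Paddle functor itself:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Per-channel absolute maximum for a row-major [channel, channel_size] buffer:
// each channel's scale candidate is the largest |x| in that channel's slice.
std::vector<float> ChannelAbsMax(const std::vector<float> &data,
                                 int64_t channel,
                                 int64_t channel_size) {
  std::vector<float> abs_max(channel, 0.0f);
  for (int64_t i = 0; i < channel; ++i) {
    const float *start = data.data() + i * channel_size;
    const float *end = start + channel_size;
    // Same comparison as Compare<T>: order by absolute value.
    abs_max[i] = std::abs(*std::max_element(
        start, end,
        [](float a, float b) { return std::abs(a) < std::abs(b); }));
  }
  return abs_max;
}
```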
@@ -86,16 +90,30 @@ template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out) {
    T s = scale.data<T>()[0];
    T inv_s = inverse(s);
    platform::Transform<platform::CPUDeviceContext> trans;
    if (round_type == 0) {
      trans(ctx,
            in.data<T>(),
            in.data<T>() + in.numel(),
            out->mutable_data<T>(ctx.GetPlace()),
            QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
    } else {
      trans(ctx,
            in.data<T>(),
            in.data<T>() + in.numel(),
            out->mutable_data<T>(ctx.GetPlace()),
            phi::ClipFunctor<T>(-s, s));
      auto out_e = framework::EigenVector<T>::Flatten(*out);
      out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
    }
  }
};
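As a reading aid, a minimal sketch of the two rounding conventions that the round_type attribute distinguishes (0: ties to even, 1: ties away from zero, per the attribute descriptions later in this file). Standard C++ only; this is an illustration, not the roundWithTiesToEven helper used by the CUDA kernels further down:

```cpp
#include <cmath>

// round_type == 0: round half to even, e.g. 1.5 -> 2, 2.5 -> 2.
double RoundTiesToEven(double x) {
  // std::nearbyint honors the current FP rounding mode, which defaults to
  // round-to-nearest-even.
  return std::nearbyint(x);
}

// round_type == 1: round half away from zero, e.g. 1.5 -> 2, 2.5 -> 3.
double RoundTiesAwayFromZero(double x) {
  return std::round(x);
}
```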
@@ -103,19 +121,34 @@ template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out) {
    T s = scale.data<T>()[0];
    T inv_s = inverse(s);
    platform::Transform<platform::CPUDeviceContext> trans;
    if (round_type == 0) {
      trans(ctx,
            in.data<T>(),
            in.data<T>() + in.numel(),
            out->mutable_data<T>(ctx.GetPlace()),
            QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
      auto out_e = framework::EigenVector<T>::Flatten(*out);
      out_e.device(*ctx.eigen_device()) = out_e * s / static_cast<T>(bin_cnt);
    } else {
      trans(ctx,
            in.data<T>(),
            in.data<T>() + in.numel(),
            out->mutable_data<T>(ctx.GetPlace()),
            phi::ClipFunctor<T>(-s, s));
      auto out_e = framework::EigenVector<T>::Flatten(*out);
      out_e.device(*ctx.eigen_device()) =
          (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
    }
  }
};
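A scalar sketch of the round_type == 1 quantize-dequantize path above: clip to [-s, s], scale into the integer grid, round, then map back to float. The assumption that bin_cnt equals 2^(bit_length - 1) - 1 (127 for 8 bits) comes from the bit_length attribute described later, and the helper name is hypothetical:

```cpp
#include <algorithm>
#include <cmath>

// Fake quantize-dequantize of one value with scale s and bin_cnt levels.
float FakeQuantDequant(float x, float s, int bin_cnt) {
  float clipped = std::min(std::max(x, -s), s);   // phi::ClipFunctor(-s, s)
  float q = std::round(bin_cnt * clipped / s);    // quantize and round
  return q * s / static_cast<float>(bin_cnt);     // dequantize back to float
}
```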
template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
@@ -123,20 +156,24 @@ template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    auto *scale_data = scale.data<T>();
    auto *in_data = in.data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());
    auto in_dims = in.dims();
    const int64_t channel = in_dims[quant_axis];
    platform::Transform<platform::CPUDeviceContext> trans;
@@ -144,12 +181,31 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
      const int64_t channel_size = in.numel() / channel;
      for (int64_t i = 0; i < channel; i++) {
        T s = scale_data[i];
        auto *start = in_data + i * channel_size;
        auto *end = in_data + (i + 1) * channel_size;
        T inv_s = inverse(s);
        if (round_type == 0) {
          trans(ctx,
                start,
                end,
                out_data + i * channel_size,
                QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
        } else {
          trans(ctx,
                start,
                end,
                out_data + i * channel_size,
                phi::ClipFunctor<T>(-s, s));
        }
      }
      if (round_type == 1) {
        for (int64_t i = 0; i < channel; i++) {
          T s = scale_data[i];
          T inv_s = inverse(s);
          framework::Tensor one_channel_out = out->Slice(i, i + 1);
          auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
          out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
        }
      }
    } else if (quant_axis == 1) {
      const int64_t step_i = in.numel() / in_dims[0];
@@ -158,12 +214,21 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
        for (int j = 0; j < in_dims[1]; j++) {
          T s = scale_data[j];
          T inv_s = inverse(s);
          auto *start = in_data + i * step_i + j * step_j;
          auto *end = in_data + i * step_i + (j + 1) * step_j;
          auto *cur_out_data = out_data + i * step_i + j * step_j;
          if (round_type == 0) {
            trans(ctx,
                  start,
                  end,
                  cur_out_data,
                  QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
          } else {
            trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
            for (int k = 0; k < step_j; k++) {
              cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
            }
          }
        }
      }
    }
  }
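For orientation, a standalone sketch of the quant_axis == 1 indexing used above: for an input of shape [N, C, ...], channel j of sample i occupies the contiguous run starting at i * step_i + j * step_j, and each run is clipped and rounded with that channel's scale. The function name and plain-vector interface are hypothetical stand-ins for the functor:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Channel-wise fake quantization along axis 1 (the round_type == 1 flavor:
// clip to [-s, s], then round bin_cnt * x / s). out must have in.size() elements.
void ChannelWiseQuantAxis1(const std::vector<float> &in,
                           const std::vector<float> &scales,  // one per channel
                           int64_t n, int64_t c, int64_t step_j, int bin_cnt,
                           std::vector<float> *out) {
  const int64_t step_i = c * step_j;  // elements per sample
  for (int64_t i = 0; i < n; ++i) {
    for (int64_t j = 0; j < c; ++j) {
      const float s = scales[j];
      for (int64_t k = 0; k < step_j; ++k) {
        const int64_t idx = i * step_i + j * step_j + k;
        const float clipped = std::fmin(std::fmax(in[idx], -s), s);
        (*out)[idx] = std::round(bin_cnt * clipped / s);
      }
    }
  }
}
```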
@@ -174,19 +239,23 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
                                                float>;
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    auto *scale_data = scale.data<T>();
    auto *in_data = in.data<T>();
    auto *out_data = out->mutable_data<T>(ctx.GetPlace());
    auto in_dims = in.dims();
    const int64_t channel = in_dims[quant_axis];
    platform::Transform<platform::CPUDeviceContext> trans;
@@ -194,15 +263,35 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
      const int64_t channel_size = in.numel() / channel;
      for (int i = 0; i < channel; i++) {
        T s = scale_data[i];
        auto *start = in_data + i * channel_size;
        auto *end = in_data + (i + 1) * channel_size;
        if (round_type == 0) {
          T inv_s = inverse(s);
          trans(ctx,
                start,
                end,
                out_data + i * channel_size,
                QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
        } else {
          trans(ctx,
                start,
                end,
                out_data + i * channel_size,
                phi::ClipFunctor<T>(-s, s));
        }
      }
      for (int i = 0; i < channel; i++) {
        T s = scale_data[i];
        framework::Tensor one_channel_out = out->Slice(i, i + 1);
        auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
        if (round_type == 0) {
          out_e.device(*ctx.eigen_device()) =
              out_e * s / static_cast<T>(bin_cnt);
        } else {
          T inv_s = inverse(s);
          out_e.device(*ctx.eigen_device()) =
              (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
        }
      }
    } else if (quant_axis == 1) {
      const int64_t step_i = in.numel() / in_dims[0];
@@ -211,14 +300,25 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
        for (int j = 0; j < in_dims[1]; j++) {
          T s = scale_data[j];
          T inv_s = inverse(s);
          auto *start = in_data + i * step_i + j * step_j;
          auto *end = in_data + i * step_i + (j + 1) * step_j;
          auto *cur_out_data = out_data + i * step_i + j * step_j;
          if (round_type == 0) {
            trans(ctx,
                  start,
                  end,
                  cur_out_data,
                  QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
          } else {
            trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
          }
          for (int k = 0; k < step_j; k++) {
            if (round_type == 0) {
              cur_out_data[k] = cur_out_data[k] * s / static_cast<T>(bin_cnt);
            } else {
              cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) *
                                s / static_cast<T>(bin_cnt);
            }
          }
        }
      }
@@ -230,12 +330,14 @@ template struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext,
                                                    float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &cur_scale,
                  const framework::Tensor &last_scale,
                  const framework::Tensor &iter,
                  const int window_size,
                  framework::Tensor *scales_arr,
                  framework::Tensor *out_scale) {
    T *scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
    int64_t it = iter.data<int64_t>()[0];
    int idx = it % window_size;
    T removed = scale_arr[idx];
@@ -247,8 +349,8 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
      max = cur;
    } else if (fabs(removed - max) < 1e-6) {
      int size = (it > window_size) ? window_size : it;
      FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(
          ctx, scale_arr, size, &max);
    }
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
  }
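A minimal sketch of the sliding-window bookkeeping that FindRangeAbsMaxFunctor performs: the current scale overwrites slot it % window_size, and the running maximum is rescanned only when the evicted entry was itself the maximum. The helper name is hypothetical, and it assumes scales are non-negative so a plain max over the window matches the abs-max scan:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Update the windowed range scale and return the new maximum.
float UpdateRangeAbsMax(std::vector<float> *window,  // size == window_size
                        float last_max, float cur_scale, int64_t it) {
  const int64_t window_size = static_cast<int64_t>(window->size());
  const int64_t idx = it % window_size;
  const float removed = (*window)[idx];
  (*window)[idx] = cur_scale;

  float max = last_max;
  if (cur_scale > max) {
    max = cur_scale;
  } else if (std::fabs(removed - max) < 1e-6f) {
    // The evicted value was the previous maximum: rescan the filled part.
    const int64_t size = (it > window_size) ? window_size : it;
    max = *std::max_element(window->begin(), window->begin() + size);
  }
  return max;
}
```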
@@ -258,11 +360,14 @@ template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext &ctx,
                  const framework::Tensor &in_accum,
                  const framework::Tensor &in_state,
                  const T *cur_scale,
                  const float rate,
                  framework::Tensor *out_state,
                  framework::Tensor *out_accum,
                  framework::Tensor *out_scale) {
    T accum = in_accum.data<T>()[0];
    T state = in_state.data<T>()[0];
    T scale = cur_scale[0];
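The rest of this functor falls outside the hunk; as a sketch, and assuming the usual moving-average formulation behind this operator, the running statistics are combined as accum = rate * accum + |x|_max, state = rate * state + 1, and the published scale is accum / state. The struct below is an illustrative stand-in, not Paddle code:

```cpp
// Exponential-moving-average scale tracker (hypothetical helper).
struct MovingAverageScale {
  float accum = 0.0f;  // decayed sum of observed abs-max values
  float state = 0.0f;  // decayed count of observations

  // Feed one abs-max observation and return the current scale estimate.
  float Update(float cur_abs_max, float rate) {
    state = rate * state + 1.0f;
    accum = rate * accum + cur_abs_max;
    return accum / state;  // corresponds to the OutScale value
  }
};
```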
@@ -282,18 +387,22 @@ template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantOrWithDequantAbsMaxOp(const std::string &type,
                                 const framework::VariableNameMap &inputs,
                                 const framework::VariableNameMap &outputs,
                                 const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(
        ctx->HasInput("X"), "Input", "X", "FakeQuantOrWithDequantAbsMaxOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"),
                   "Output",
                   "Out",
                   "FakeQuantOrWithDequantAbsMaxOp");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "FakeQuantOrWithDequantAbsMaxOp");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
@@ -302,7 +411,7 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
@@ -320,8 +429,9 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
    AddOutput("OutScale", "(Tensor) Current scale");
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                            true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
@@ -329,18 +439,22 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
        });
    AddAttr<int>(
        "round_type",
        "(int, default 1) The round type of fp32 to int."
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(2.5)=3")
        .SetDefault(1)
        .AddCustomChecker([](const int &round_type) {
          PADDLE_ENFORCE_EQ(
              round_type == 0 || round_type == 1,
              true,
              platform::errors::InvalidArgument(
                  "'round_type' should be 0 or 1, 0 rounding to "
                  "nearest ties to even and 1 is rounding to nearest "
                  "ties away from zero.but the received is %d",
                  round_type));
        })
        .AsExtra();
    AddComment(R"DOC(
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
@@ -363,12 +477,16 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(
        ctx->HasInput("X"), "Input", "X", "FakeChannelWiseQuantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"),
                   "Output",
                   "Out",
                   "FakeChannelWiseQuantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "FakeChannelWiseQuantizeAbsMax");
    int quant_axis = ctx->Attrs().Get<int>("quant_axis");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
@@ -378,7 +496,7 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
@@ -398,8 +516,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
                  "For conv2d, depthwise_conv2d, conv2d_transpose "
                  "and mul, the quant_axis is equal to the cout axis.")
        .SetDefault(0)
        .AddCustomChecker([](const int &quant_axis) {
          PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
                            true,
                            platform::errors::InvalidArgument(
                                "'quant_axis' should be 0 or 1, but "
                                "the received is %d",
@@ -407,8 +526,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
        });
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                            true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
@@ -416,18 +536,22 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
        });
    AddAttr<int>(
        "round_type",
        "(int, default 1) The round type of fp32 to int."
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(2.5)=3")
        .SetDefault(1)
        .AddCustomChecker([](const int &round_type) {
          PADDLE_ENFORCE_EQ(
              round_type == 0 || round_type == 1,
              true,
              platform::errors::InvalidArgument(
                  "'round_type' should be 0 or 1, 0 rounding to "
                  "nearest ties to even and 1 is rounding to nearest "
                  "ties away from zero.but the received is %d",
                  round_type));
        })
        .AsExtra();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
@@ -450,12 +574,18 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"),
                   "Input",
                   "X",
                   "FakeChannelWiseQuantizeDequantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"),
                   "Output",
                   "Out",
                   "FakeChannelWiseQuantizeDequantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "FakeChannelWiseQuantizeDequantizeAbsMax");
    int quant_axis = ctx->Attrs().Get<int>("quant_axis");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
@@ -465,7 +595,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
@@ -485,8 +615,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
                  "For conv2d, depthwise_conv2d, conv2d_transpose "
                  "and mul, the quant_axis is equal to the cout axis.")
        .SetDefault(0)
        .AddCustomChecker([](const int &quant_axis) {
          PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
                            true,
                            platform::errors::InvalidArgument(
                                "'quant_axis' should be 0 or 1, but "
                                "the received is %d",
@@ -494,8 +625,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
        });
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                            true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
@@ -503,18 +635,22 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
        });
    AddAttr<int>(
        "round_type",
        "(int, default 1) The round type of fp32 to int."
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(2.5)=3")
        .SetDefault(1)
        .AddCustomChecker([](const int &round_type) {
          PADDLE_ENFORCE_EQ(
              round_type == 0 || round_type == 1,
              true,
              platform::errors::InvalidArgument(
                  "'round_type' should be 0 or 1, 0 rounding to "
                  "nearest ties to even and 1 is rounding to nearest "
                  "ties away from zero.but the received is %d",
                  round_type));
        })
        .AsExtra();
    AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
@@ -530,17 +666,19 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeRangeAbsMaxOp(const std::string &type,
                            const framework::VariableNameMap &inputs,
                            const framework::VariableNameMap &outputs,
                            const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(
        ctx->HasOutput("Out"), "Output", "Out", "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "FakeQuantizeRangeAbsMax");
    if (ctx->HasOutput("OutScales")) {
      int window_size = ctx->Attrs().Get<int>("window_size");
@@ -553,7 +691,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
@@ -574,8 +712,9 @@ class FakeQuantizeRangeAbsMaxOpMaker
        .SetDefault(10000);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                            true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
@@ -583,18 +722,22 @@ class FakeQuantizeRangeAbsMaxOpMaker
        });
    AddAttr<int>(
        "round_type",
        "(int, default 1) The round type of fp32 to int."
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(2.5)=3")
        .SetDefault(1)
        .AddCustomChecker([](const int &round_type) {
          PADDLE_ENFORCE_EQ(
              round_type == 0 || round_type == 1,
              true,
              platform::errors::InvalidArgument(
                  "'round_type' should be 0 or 1, 0 rounding to "
                  "nearest ties to even and 1 is rounding to nearest "
                  "ties away from zero.but the received is %d",
                  round_type));
        })
        .AsExtra();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
@@ -614,17 +757,24 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
    : public framework::OperatorWithKernel {
 public:
  FakeQuantOrWithDequantMovingAverageAbsMaxOp(
      const std::string &type,
      const framework::VariableNameMap &inputs,
      const framework::VariableNameMap &outputs,
      const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"),
                   "Input",
                   "X",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"),
                   "Output",
                   "Out",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    if (ctx->HasOutput("OutState")) {
      ctx->SetOutputDim("OutState", {1});
@@ -639,7 +789,7 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
@@ -662,8 +812,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
        .SetDefault(0.9);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                            true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
@@ -671,18 +822,22 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
        });
    AddAttr<int>(
        "round_type",
        "(int, default 1) The round type of fp32 to int."
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(2.5)=3")
        .SetDefault(1)
        .AddCustomChecker([](const int &round_type) {
          PADDLE_ENFORCE_EQ(
              round_type == 0 || round_type == 1,
              true,
              platform::errors::InvalidArgument(
                  "'round_type' should be 0 or 1, 0 rounding to "
                  "nearest ties to even and 1 is rounding to nearest "
                  "ties away from zero.but the received is %d",
                  round_type));
        })
        .AsExtra();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
@@ -709,10 +864,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(
        ctx->HasInput("X"), "Input", "X", "MovingAverageAbsMaxScale");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "MovingAverageAbsMaxScale");
    if (ctx->HasOutput("OutState")) {
@@ -730,7 +887,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
@@ -770,19 +927,23 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext *ctx) const override {
    auto out_grad_name = framework::GradVarName("Out");
    auto x_grad_name = framework::GradVarName("X");
    OP_INOUT_CHECK(ctx->HasInput(out_grad_name),
                   "Input",
                   out_grad_name,
                   "StrightThroughEstimatorGradOp");
    OP_INOUT_CHECK(ctx->HasOutput(x_grad_name),
                   "Output",
                   x_grad_name,
                   "StrightThroughEstimatorGradOp");
    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
  }
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
@@ -810,7 +971,8 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
    fake_quantize_abs_max,
    ops::FakeQuantOrWithDequantAbsMaxOp,
    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -818,7 +980,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(
    fake_quantize_dequantize_abs_max,
    ops::FakeQuantOrWithDequantAbsMaxOp,
    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
@@ -826,7 +989,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                       ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(
    fake_quantize_range_abs_max,
    ops::FakeQuantizeRangeAbsMaxOp,
    ops::FakeQuantizeRangeAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -853,7 +1017,8 @@ REGISTER_OP_CPU_KERNEL(
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(
    fake_channel_wise_quantize_abs_max,
    ops::FakeChannelWiseQuantizeAbsMaxOp,
    ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
@@ -861,7 +1026,8 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
                       ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(
    moving_average_abs_max_scale,
    ops::MovingAverageAbsMaxScaleOp,
    ops::MovingAverageAbsMaxScaleOpMaker,
    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
...
@@ -36,12 +36,12 @@ struct QuantizeDataType<paddle::platform::float16> {
};
template <typename T>
__global__ void FindAbsMaxKernel(const T *in, const int n, T *out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;
  extern __shared__ char *shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);
  if (gridDim.x > 1) {
    T local_max_data = T(0);
    for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
@@ -73,14 +73,16 @@ __global__ void FindAbsMaxKernel(const T *in, const int n, T *out) {
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext &ctx,
                  const T *in,
                  const int num,
                  T *out) {
    int block = 1024;
    int grid = (block - 1 + num) / block;
    grid = (grid > block) ? block : grid;
    framework::Tensor max;
    T *max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace());
    FindAbsMaxKernel<T>
        <<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(in, num, max_data);
    FindAbsMaxKernel<T>
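The functor above launches FindAbsMaxKernel twice: a first pass that reduces the input to one partial maximum per block, and a second pass that reduces those partials to the final value. A host-side sketch of this two-pass structure (not the kernel's exact grid-stride access pattern, and with a hypothetical helper name):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Two-pass abs-max reduction: pass 1 produces one partial maximum per "block",
// pass 2 reduces the partials into the final result.
float TwoPassAbsMax(const std::vector<float> &in, int num_blocks) {
  std::vector<float> partial(num_blocks, 0.0f);
  const size_t chunk = (in.size() + num_blocks - 1) / num_blocks;
  for (int b = 0; b < num_blocks; ++b) {                 // "pass 1"
    const size_t begin = b * chunk;
    const size_t end = std::min(in.size(), begin + chunk);
    for (size_t i = begin; i < end; ++i) {
      partial[b] = std::max(partial[b], std::abs(in[i]));
    }
  }
  float result = 0.0f;
  for (float p : partial) result = std::max(result, p);  // "pass 2"
  return result;
}
```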
@@ -93,13 +95,15 @@ template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
                                  paddle::platform::float16>;
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T *in,
                                                  const int n,
                                                  const int c,
                                                  T *out) {
  int tid = threadIdx.x;
  int channel_size = n / c;
  const T *in_c = in + blockIdx.x * channel_size;
  extern __shared__ char *shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);
  T local_max_data = T(0);
  for (int i = tid; i < channel_size; i += blockDim.x) {
    T tmp = static_cast<T>(
@@ -122,17 +126,16 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T *in, const int n,
}
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(
    const T *in, const int n, const int cin, const int cout, T *out) {
  extern __shared__ char *shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);
  int cout_wh_size = n / cin;
  int wh_size = n / (cin * cout);
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  const T *in_current = in + tid * cout_wh_size + bid * wh_size;
  T local_max_data = T(0);
  for (int i = 0; i < wh_size; i++) {
    T tmp = static_cast<T>(
@@ -162,24 +165,26 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T *in, const int n,
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext &ctx,
                  const framework::Tensor &in_tensor,
                  const int quant_axis,
                  T *out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T *in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      FindChannelAbsMaxKernelQuantAxis0<T>
          <<<grid, block, block * sizeof(T), ctx.stream()>>>(
              in_data, num, cout, out_abs_max);
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
@@ -213,9 +218,12 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ClipAndQuantKernel(const T *in,
                                   const T *scale,
                                   const int bin_cnt,
                                   const int round_type,
                                   const int n,
                                   T *out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;
...@@ -227,25 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, ...@@ -227,25 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
ComputeDataType x = static_cast<ComputeDataType>(in[i]); ComputeDataType x = static_cast<ComputeDataType>(in[i]);
x = bin_cnt_t * inv_s * x;
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
} else {
x = round(x);
}
ComputeDataType max_bound = bin_cnt_t; ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1); ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out[i] = static_cast<T>(x); out[i] = static_cast<T>(x);
} else {
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(round(v));
}
} }
} }
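The kernel above is the core of this change: with round_type == 0 the value is scaled, rounded with ties to even, and clipped to [-bin_cnt - 1, bin_cnt]; with round_type == 1 it is first clipped to [-scale, scale] and then scaled and rounded with ties away from zero. A host-side sketch of that per-element math (the function name fake_quant and the use of std::nearbyint for ties-to-even are illustration-only assumptions, not code from the patch):

#include <algorithm>
#include <cmath>

float fake_quant(float x, float s, int bin_cnt, int round_type) {
  float bins = static_cast<float>(bin_cnt);
  if (round_type == 0) {
    float v = std::nearbyint(bins * x / s);             // ties to even under the default FE_TONEAREST mode
    return std::min(std::max(v, -bins - 1.0f), bins);   // clip to [-bin_cnt - 1, bin_cnt]
  }
  float v = std::min(std::max(x, -s), s);               // clip the raw value to [-s, s]
  return std::round(bins * v / s);                      // ties away from zero
}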
template <typename T> template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, __global__ void ClipAndQuantDequantKernel(const T *in,
const T *scale,
const int bin_cnt, const int bin_cnt,
const int round_type, const int n, const int round_type,
T* out) { const int n,
T *out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x; int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x; int tid = threadIdx.x;
...@@ -257,33 +270,39 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, ...@@ -257,33 +270,39 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
ComputeDataType x = static_cast<ComputeDataType>(in[i]); ComputeDataType x = static_cast<ComputeDataType>(in[i]);
x = bin_cnt_t * inv_s * x;
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
} else {
x = round(x);
}
ComputeDataType max_bound = bin_cnt_t; ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1); ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out[i] = static_cast<T>((x * s) / bin_cnt_t); out[i] = static_cast<T>((x * s) / bin_cnt_t);
} else {
x = x > s ? s : x;
x = x < -s ? -s : x;
x = bin_cnt_t * inv_s * x;
x = round(x);
out[i] = static_cast<T>((x * s) / bin_cnt_t);
}
} }
} }
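The quant-dequant kernel writes back x * s / bin_cnt after rounding and clipping, so the output stays in the input's range but is snapped onto a grid of s / bin_cnt. A tiny self-contained sketch of that round trip for the ties-to-even path (the helper name and nearbyint-based rounding are assumptions for illustration):

#include <algorithm>
#include <cmath>

float fake_quant_dequant_even(float x, float s, int bin_cnt) {
  float bins = static_cast<float>(bin_cnt);
  float q = std::nearbyint(bins * x / s);               // ties to even (default mode)
  q = std::min(std::max(q, -bins - 1.0f), bins);        // clip to [-bin_cnt - 1, bin_cnt]
  return q * s / bins;                                  // snap back onto the float grid
}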
template <typename T> template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, const int round_type, const framework::Tensor &scale,
framework::Tensor* out) { const int bin_cnt,
const int round_type,
framework::Tensor *out) {
int num = in.numel(); int num = in.numel();
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>( ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, round_type, num, out_data); in_data, scale_data, bin_cnt, round_type, num, out_data);
...@@ -294,17 +313,19 @@ template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>; ...@@ -294,17 +313,19 @@ template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, const int round_type, const framework::Tensor &scale,
framework::Tensor* out) { const int bin_cnt,
const int round_type,
framework::Tensor *out) {
int num = in.numel(); int num = in.numel();
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>( ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, round_type, num, out_data); in_data, scale_data, bin_cnt, round_type, num, out_data);
...@@ -313,16 +334,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { ...@@ -313,16 +334,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
// ChannelClipAndQuantKernel for quant_axis is 0 // ChannelClipAndQuantKernel for quant_axis is 0
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, __global__ void ChannelClipAndQuantKernelQuantAxis0(const T *in,
const T *scale,
const int bin_cnt, const int bin_cnt,
const int round_type, const int round_type,
const int64_t n, const int64_t n,
const int c, T* out) { const int c,
T *out) {
int tid = threadIdx.x; int tid = threadIdx.x;
int64_t channel_size = n / c; int64_t channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size; const T *in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size; T *out_c = out + blockIdx.x * channel_size;
using ComputeDataType = typename QuantizeDataType<T>::type; using ComputeDataType = typename QuantizeDataType<T>::type;
...@@ -332,25 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, ...@@ -332,25 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
for (int64_t i = tid; i < channel_size; i += blockDim.x) { for (int64_t i = tid; i < channel_size; i += blockDim.x) {
ComputeDataType x = static_cast<ComputeDataType>(in_c[i]); ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
x = bin_cnt_t * inv_s * x;
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
} else {
x = round(x);
}
ComputeDataType max_bound = bin_cnt_t; ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1); ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out_c[i] = static_cast<T>(x); out_c[i] = static_cast<T>(x);
} else {
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v;
out_c[i] = static_cast<T>(round(v));
}
} }
} }
// ChannelClipAndQuantKernel for quant_axis is N // ChannelClipAndQuantKernel for quant_axis is N
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN( __global__ void ChannelClipAndQuantKernelQuantAxisN(const T *in,
const T* in, const T* scale, const int bin_cnt, const int round_type, const T *scale,
const int64_t n, const int nScale, const int quant_stride, T* out) { const int bin_cnt,
const int round_type,
const int64_t n,
const int nScale,
const int quant_stride,
T *out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using ComputeDataType = typename QuantizeDataType<T>::type; using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt); ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
...@@ -359,37 +390,44 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN( ...@@ -359,37 +390,44 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]); static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
ComputeDataType inv_s = inverse(s); ComputeDataType inv_s = inverse(s);
ComputeDataType x = static_cast<ComputeDataType>(in[i]); ComputeDataType x = static_cast<ComputeDataType>(in[i]);
x = bin_cnt_t * inv_s * x;
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
} else {
x = round(x);
}
ComputeDataType max_bound = bin_cnt_t; ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1); ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out[i] = static_cast<T>(x); out[i] = static_cast<T>(x);
} else {
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(round(v));
}
} }
} }
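For a general quant_axis the kernel picks the scale with (i / quant_stride) % nScale, where quant_stride is the number of contiguous elements that share one channel. A small sketch of that index arithmetic (the NCHW shape below is an assumed example, not something taken from the patch):

#include <cstdint>

int64_t channel_of(int64_t i, int64_t quant_stride, int64_t n_scale) {
  return (i / quant_stride) % n_scale;  // channel index for flattened element i
}
// For an NCHW weight quantized along axis 1 with N=2, C=3, H=W=4:
// quant_stride = H * W = 16 and n_scale = C = 3, so elements 0..15 use
// scale[0], 16..31 use scale[1], 32..47 use scale[2], and 48..63 wrap to scale[0].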
template <typename T> template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, const int round_type, const int quant_axis, const framework::Tensor &scale,
framework::Tensor* out) { const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true, quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
int64_t num = in.numel(); int64_t num = in.numel();
auto in_dims = in.dims(); auto in_dims = in.dims();
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) { if (quant_axis == 0) {
int grid = in_dims[0]; int grid = in_dims[0];
...@@ -411,9 +449,15 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { ...@@ -411,9 +449,15 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
const int64_t grid_size = const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size); std::min(max_blocks, (num + block_size - 1) / block_size);
ChannelClipAndQuantKernelQuantAxisN<T><<<grid_size, block_size>>>( ChannelClipAndQuantKernelQuantAxisN<T>
in_data, scale_data, bin_cnt, round_type, num, in_dims[quant_axis], <<<grid_size, block_size>>>(in_data,
quant_stride, out_data); scale_data,
bin_cnt,
round_type,
num,
in_dims[quant_axis],
quant_stride,
out_data);
} }
} }
}; };
...@@ -422,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, ...@@ -422,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
float>; float>;
template <typename T> template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, __global__ void FindRangeAbsMaxAndFillArray(const T *cur_scale,
const T* last_scale, const T *last_scale,
const int64_t* iter, const int64_t *iter,
const int window_size, T* scale_arr, const int window_size,
T* out_scale, int* need_find_max, T *scale_arr,
int* out_size) { T *out_scale,
int *need_find_max,
int *out_size) {
int it = iter[0]; int it = iter[0];
int idx = it % window_size; int idx = it % window_size;
T removed = scale_arr[idx]; T removed = scale_arr[idx];
...@@ -446,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, ...@@ -446,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
template <typename T> template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& cur_scale, const framework::Tensor &cur_scale,
const framework::Tensor& last_scale, const framework::Tensor &last_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor &iter,
framework::Tensor* scales_arr, framework::Tensor* out_scale) { const int window_size,
framework::Tensor *scales_arr,
framework::Tensor *out_scale) {
const auto gpu_place = ctx.GetPlace(); const auto gpu_place = ctx.GetPlace();
T* scale_arr = scales_arr->mutable_data<T>(gpu_place); T *scale_arr = scales_arr->mutable_data<T>(gpu_place);
T* out_scale_data = out_scale->mutable_data<T>(gpu_place); T *out_scale_data = out_scale->mutable_data<T>(gpu_place);
framework::Tensor need_find_max, out_size; framework::Tensor need_find_max, out_size;
int* find_max = need_find_max.mutable_data<int>({1}, gpu_place); int *find_max = need_find_max.mutable_data<int>({1}, gpu_place);
int* out_size_data = out_size.mutable_data<int>({1}, gpu_place); int *out_size_data = out_size.mutable_data<int>({1}, gpu_place);
FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>( FindRangeAbsMaxAndFillArray<T>
cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(), <<<1, 1, 0, ctx.stream()>>>(cur_scale.data<T>(),
window_size, scale_arr, out_scale_data, find_max, out_size_data); last_scale.data<T>(),
iter.data<int64_t>(),
window_size,
scale_arr,
out_scale_data,
find_max,
out_size_data);
int g_find_max; int g_find_max;
memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, memory::Copy(platform::CPUPlace(),
sizeof(int), ctx.stream()); &g_find_max,
gpu_place,
find_max,
sizeof(int),
ctx.stream());
ctx.Wait(); ctx.Wait();
if (g_find_max) { if (g_find_max) {
int len; int len;
memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, memory::Copy(platform::CPUPlace(),
sizeof(int), ctx.stream()); &len,
gpu_place,
out_size_data,
sizeof(int),
ctx.stream());
ctx.Wait(); ctx.Wait();
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len, FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(
out_scale_data); ctx, scale_arr, len, out_scale_data);
} }
} }
}; };
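FindRangeAbsMax keeps a circular buffer of the last window_size scales; only the slot it % window_size changes each step, and a full rescan is triggered only when the overwritten entry was the previous maximum. A CPU paraphrase of that update (the helper name and the equality test are assumptions, since part of the CUDA kernel falls outside this hunk):

#include <algorithm>
#include <cstdint>
#include <vector>

float update_range(std::vector<float>& window, int64_t it, float cur,
                   float last, bool* need_rescan) {
  size_t idx = static_cast<size_t>(it) % window.size();
  float removed = window[idx];
  window[idx] = cur;                 // overwrite the oldest slot
  *need_rescan = (removed == last);  // the dropped entry was the running max
  return std::max(last, cur);        // caller rescans `window` when need_rescan is set
}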
template <typename T> template <typename T>
__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, __global__ void FindMovingAverageAbsMaxKernel(const T *in_state,
const T* in_accum, const T *in_accum,
const T* cur_scale, const T rate, const T *cur_scale,
T* out_state, T* out_accum, const T rate,
T* out_scale) { T *out_state,
T *out_accum,
T *out_scale) {
T state = rate * (*in_state) + T(1.0f); T state = rate * (*in_state) + T(1.0f);
T accum = rate * (*in_accum) + (*cur_scale); T accum = rate * (*in_accum) + (*cur_scale);
*out_state = state; *out_state = state;
...@@ -496,92 +560,119 @@ template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>; ...@@ -496,92 +560,119 @@ template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in_accum, const framework::Tensor &in_accum,
const framework::Tensor& in_state, const T* cur_scale, const framework::Tensor &in_state,
const float rate, framework::Tensor* out_state, const T *cur_scale,
framework::Tensor* out_accum, framework::Tensor* out_scale) { const float rate,
framework::Tensor *out_state,
framework::Tensor *out_accum,
framework::Tensor *out_scale) {
const auto gpu_place = ctx.GetPlace(); const auto gpu_place = ctx.GetPlace();
T rate_t = static_cast<T>(rate); T rate_t = static_cast<T>(rate);
T* out_state_data = out_state->mutable_data<T>(gpu_place); T *out_state_data = out_state->mutable_data<T>(gpu_place);
T* out_accum_data = out_accum->mutable_data<T>(gpu_place); T *out_accum_data = out_accum->mutable_data<T>(gpu_place);
T* out_scale_data = out_scale->mutable_data<T>(gpu_place); T *out_scale_data = out_scale->mutable_data<T>(gpu_place);
FindMovingAverageAbsMaxKernel<T><<<1, 1, 0, ctx.stream()>>>( FindMovingAverageAbsMaxKernel<T>
in_state.data<T>(), in_accum.data<T>(), cur_scale, rate_t, <<<1, 1, 0, ctx.stream()>>>(in_state.data<T>(),
out_state_data, out_accum_data, out_scale_data); in_accum.data<T>(),
cur_scale,
rate_t,
out_state_data,
out_accum_data,
out_scale_data);
} }
}; };
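The moving-average functor maintains an exponential moving average of the per-batch abs-max: state counts the decayed number of updates and accum their decayed sum. A standalone sketch of the update (the final division by state happens in the part of the kernel outside this hunk, so treat it as an assumption):

struct MovingAvgScale {
  float state = 0.f;  // decayed sample count
  float accum = 0.f;  // decayed sum of abs-max values
};

float moving_average_abs_max(MovingAvgScale* m, float cur_abs_max, float rate) {
  m->state = rate * m->state + 1.0f;
  m->accum = rate * m->accum + cur_abs_max;
  return m->accum / m->state;  // debiased moving-average scale
}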
// ChannelClipAndQuantDequantKernel for quant_axis is 0 // ChannelClipAndQuantDequantKernel for quant_axis is 0
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0( __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
const T* in, const T* scale, const int bin_cnt, const int round_type, const T *scale,
const int n, const int c, T* out) { const int bin_cnt,
const int round_type,
const int n,
const int c,
T *out) {
int tid = threadIdx.x; int tid = threadIdx.x;
int channel_size = n / c; int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size; const T *in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size; T *out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x]; T s = scale[blockIdx.x];
T inv_s = inverse(s); T inv_s = inverse(s);
for (int i = tid; i < channel_size; i += blockDim.x) { for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i]; T x = in_c[i];
x = bin_cnt * inv_s * x;
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
} else {
x = round(x);
}
T max_bound = bin_cnt; T max_bound = bin_cnt;
T min_bound = -bin_cnt - static_cast<T>(1); T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt; out_c[i] = (x * s) / bin_cnt;
} else {
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
}
} }
} }
// ChannelClipAndQuantDequantKernel for quant_axis is 1 // ChannelClipAndQuantDequantKernel for quant_axis is 1
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1( __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
const T* in, const T* scale, const int bin_cnt, const int round_type, const T *scale,
const int n, const int cin, const int cout, T* out) { const int bin_cnt,
const int round_type,
const int n,
const int cin,
const int cout,
T *out) {
T s = scale[blockIdx.x % cout]; T s = scale[blockIdx.x % cout];
T inv_s = inverse(s); T inv_s = inverse(s);
int wh_size = n / (cin * cout); int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size; const T *in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size; T *out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i]; T x = in_c[i];
x = bin_cnt * inv_s * x;
if (round_type == 0) { if (round_type == 0) {
x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x); x = roundWithTiesToEven(x);
} else {
x = round(x);
}
T max_bound = bin_cnt; T max_bound = bin_cnt;
T min_bound = -bin_cnt - static_cast<T>(1); T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x; x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x; x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt; out_c[i] = (x * s) / bin_cnt;
} else {
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt;
}
} }
} }
template <typename T> template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, const int round_type, const int quant_axis, const framework::Tensor &scale,
framework::Tensor* out) { const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d // At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul // conv2d_transpose and mul
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true, quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
...@@ -589,25 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { ...@@ -589,25 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
int num = in.numel(); int num = in.numel();
auto in_dims = in.dims(); auto in_dims = in.dims();
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) { if (quant_axis == 0) {
int grid = in_dims[0]; int grid = in_dims[0];
int block = 1024; int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis0<T> ChannelClipAndQuantDequantKernelQuantAxis0<T>
<<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, <<<grid, block, 0, ctx.stream()>>>(in_data,
round_type, num, in_dims[0], scale_data,
bin_cnt,
round_type,
num,
in_dims[0],
out_data); out_data);
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1]; int grid = in_dims[0] * in_dims[1];
int block = 1024; int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis1<T> ChannelClipAndQuantDequantKernelQuantAxis1<T>
<<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, <<<grid, block, 0, ctx.stream()>>>(in_data,
round_type, num, in_dims[0], scale_data,
in_dims[1], out_data); bin_cnt,
round_type,
num,
in_dims[0],
in_dims[1],
out_data);
} }
} }
}; };
......
...@@ -51,16 +51,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) { ...@@ -51,16 +51,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) {
template <typename T> template <typename T>
class QuantTensorFunctor { class QuantTensorFunctor {
public: public:
explicit QuantTensorFunctor(const T bin_cnt, const int round_type, explicit QuantTensorFunctor(const T bin_cnt, const T inv_s)
const T inv_s) : bin_cnt_(bin_cnt), inv_s_(inv_s) {}
: bin_cnt_(bin_cnt), round_type_(round_type), inv_s_(inv_s) {}
HOSTDEVICE T operator()(const T x) const { HOSTDEVICE T operator()(const T x) const {
T out = bin_cnt_ * inv_s_ * x; T out = bin_cnt_ * inv_s_ * x;
if (round_type_ == 0) {
out = roundWithTiesToEven(out); out = roundWithTiesToEven(out);
} else if (round_type_ == 1) {
out = std::round(out);
}
T max_bound = bin_cnt_; T max_bound = bin_cnt_;
T min_bound = -bin_cnt_ - static_cast<T>(1); T min_bound = -bin_cnt_ - static_cast<T>(1);
out = out > max_bound ? max_bound : out; out = out > max_bound ? max_bound : out;
...@@ -70,82 +65,101 @@ class QuantTensorFunctor { ...@@ -70,82 +65,101 @@ class QuantTensorFunctor {
private: private:
T bin_cnt_; T bin_cnt_;
int round_type_;
T inv_s_; T inv_s_;
}; };
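With round_type removed from QuantTensorFunctor, this header-side functor now always rounds with ties to even; the away-from-zero variant is handled inside the kernels shown earlier. A quick standalone illustration of how the two rules differ on ties (plain C++, not part of the patch):

#include <cassert>
#include <cmath>

int main() {
  assert(std::nearbyint(2.5) == 2.0);  // ties to even (default FE_TONEAREST mode)
  assert(std::nearbyint(3.5) == 4.0);
  assert(std::round(2.5) == 3.0);      // ties away from zero
  assert(std::round(-2.5) == -3.0);
  return 0;
}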
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor { struct FindAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); void operator()(const DeviceContext &ctx, const T *in, const int num, T *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor { struct ClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
const int round_type, framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ClipAndFakeQuantDequantFunctor { struct ClipAndFakeQuantDequantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
int round_type, framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
int round_type,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor { struct FindRangeAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale, void operator()(const DeviceContext &ctx,
const framework::Tensor& last_scale, const framework::Tensor &cur_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor &last_scale,
framework::Tensor* scales_arr, framework::Tensor* out_scale); const framework::Tensor &iter,
const int window_size,
framework::Tensor *scales_arr,
framework::Tensor *out_scale);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor { struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, void operator()(const DeviceContext &ctx,
const int quant_axis, T* out_abs_max); const framework::Tensor &in_tensor,
const int quant_axis,
T *out_abs_max);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor { struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
const int round_type, const int quant_axis, const framework::Tensor &scale,
framework::Tensor* out); const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor { struct ChannelClipFakeQuantDequantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
int round_type, const int quant_axis, framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
int round_type,
const int quant_axis,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor { struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum, void operator()(const DeviceContext &ctx,
const framework::Tensor& in_state, const framework::Tensor &in_accum,
const framework::Tensor& cur_scale, const framework::Tensor &in_state,
framework::Tensor* out_state, framework::Tensor* out_accum, const framework::Tensor &cur_scale,
framework::Tensor* out_scale); framework::Tensor *out_state,
framework::Tensor *out_accum,
framework::Tensor *out_scale);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeAbsMaxKernelBase : public framework::OpKernel<T> { class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace()); T *out_s = out_scale->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type"); int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
const T* in_data = in->data<T>(); const T *in_data = in->data<T>();
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s); FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out); RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} }
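As a worked example of the bin_cnt arithmetic used throughout this file: the default bit_length of 8 gives bin_cnt = 2^7 - 1 = 127, and the clip bounds [-bin_cnt - 1, bin_cnt] are exactly the signed int8 range [-128, 127]. A minimal check of that arithmetic:

#include <cmath>
#include <cstdio>

int main() {
  const int bit_length = 8;  // default attribute value
  const int bin_cnt = static_cast<int>(std::pow(2, bit_length - 1)) - 1;
  std::printf("bin_cnt = %d, clip range = [%d, %d]\n", bin_cnt, -bin_cnt - 1, bin_cnt);
  return 0;  // prints: bin_cnt = 127, clip range = [-128, 127]
}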
...@@ -153,20 +167,25 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> { ...@@ -153,20 +167,25 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
virtual ~FakeAbsMaxKernelBase() = default; virtual ~FakeAbsMaxKernelBase() = default;
protected: protected:
virtual void RunClipFunctor(const DeviceContext& dev_ctx, virtual void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in, const framework::Tensor &in,
const framework::Tensor& scale, int bin_cnt, const framework::Tensor &scale,
int round_type, framework::Tensor* out) const = 0; int bin_cnt,
int round_type,
framework::Tensor *out) const = 0;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> { class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& scale, int bin_cnt, const framework::Tensor &in,
int round_type, framework::Tensor* out) const override { const framework::Tensor &scale,
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, scale, bin_cnt, int bin_cnt,
round_type, out); int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, in, scale, bin_cnt, round_type, out);
} }
}; };
...@@ -174,9 +193,12 @@ template <typename DeviceContext, typename T> ...@@ -174,9 +193,12 @@ template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeAbsMaxKernel class FakeQuantizeDequantizeAbsMaxKernel
: public FakeAbsMaxKernelBase<DeviceContext, T> { : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& scale, int bin_cnt, const framework::Tensor &in,
int round_type, framework::Tensor* out) const override { const framework::Tensor &scale,
int bin_cnt,
int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()( ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, in, scale, bin_cnt, round_type, out); dev_ctx, in, scale, bin_cnt, round_type, out);
} }
...@@ -185,11 +207,11 @@ class FakeQuantizeDequantizeAbsMaxKernel ...@@ -185,11 +207,11 @@ class FakeQuantizeDequantizeAbsMaxKernel
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
...@@ -198,11 +220,11 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { ...@@ -198,11 +220,11 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
int quant_axis = context.Attr<int>("quant_axis"); int quant_axis = context.Attr<int>("quant_axis");
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
if (!is_test) { if (!is_test) {
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace()); T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis, FindChannelAbsMaxFunctor<DeviceContext, T>()(
out_scale_data); dev_ctx, *in, quant_axis, out_scale_data);
} }
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()( ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
...@@ -213,12 +235,12 @@ template <typename DeviceContext, typename T> ...@@ -213,12 +235,12 @@ template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace()); T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace()); out->mutable_data<T>(dev_ctx.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
...@@ -226,8 +248,8 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel ...@@ -226,8 +248,8 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis"); int quant_axis = context.Attr<int>("quant_axis");
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis, FindChannelAbsMaxFunctor<DeviceContext, T>()(
out_scale_data); dev_ctx, *in, quant_axis, out_scale_data);
ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()( ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
...@@ -237,60 +259,64 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel ...@@ -237,60 +259,64 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> { class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("InScale"); auto *in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type"); int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
// testing // testing
if (is_test) { if (is_test) {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, round_type, out); dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
return; return;
} }
// training // training
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
auto* out_scales = context.Output<framework::Tensor>("OutScales"); auto *out_scales = context.Output<framework::Tensor>("OutScales");
auto* iter = context.Input<framework::Tensor>("Iter"); auto *iter = context.Input<framework::Tensor>("Iter");
int window_size = context.Attr<int>("window_size"); int window_size = context.Attr<int>("window_size");
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
framework::Tensor cur_scale; framework::Tensor cur_scale;
T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace()); T *cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(
cur_scale_data); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx, cur_scale, *in_scale, FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
*iter, window_size, out_scales, cur_scale,
*in_scale,
*iter,
window_size,
out_scales,
out_scale); out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, round_type, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} }
}; };
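FakeQuantizeRangeAbsMax therefore has two paths: at inference it reuses InScale unchanged, while in training the batch abs-max first refreshes the sliding-window range and the refreshed scale is what the clip-and-quantize step consumes. A deliberately simplified CPU sketch of that control flow (it rescans the whole window every step, which the real functor avoids; all names are local to this illustration):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

float range_abs_max_step(const std::vector<float>& x, bool is_test,
                         float in_scale, std::vector<float>& window,
                         int64_t iter) {
  if (is_test) return in_scale;                           // testing: fixed scale
  float cur = 0.f;
  for (float v : x) cur = std::max(cur, std::fabs(v));    // FindAbsMax over the batch
  window[static_cast<size_t>(iter) % window.size()] = cur;
  return *std::max_element(window.begin(), window.end()); // range over the window
}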
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> { class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("InScale"); auto *in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type"); int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
// testing // testing
if (is_test) { if (is_test) {
...@@ -299,25 +325,30 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> { ...@@ -299,25 +325,30 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
} }
// training // training
auto* in_accum = context.Input<framework::Tensor>("InAccum"); auto *in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState"); auto *in_state = context.Input<framework::Tensor>("InState");
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr()); T *cur_scale_data = static_cast<T *>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(
cur_scale_data); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState"); auto *out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum"); auto *out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace()); out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace()); out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate"); float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()( FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, *in_accum,
out_accum, out_scale); *in_state,
cur_scale_data,
moving_rate,
out_state,
out_accum,
out_scale);
RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out); RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} }
...@@ -325,21 +356,26 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> { ...@@ -325,21 +356,26 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
virtual ~FakeMovingAverageAbsMaxKernelBase() = default; virtual ~FakeMovingAverageAbsMaxKernelBase() = default;
protected: protected:
virtual void RunClipFunctor(const DeviceContext& dev_ctx, virtual void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in, const framework::Tensor &in,
const framework::Tensor& in_scale, int bin_cnt, const framework::Tensor &in_scale,
int round_type, framework::Tensor* out) const = 0; int bin_cnt,
int round_type,
framework::Tensor *out) const = 0;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel class FakeQuantizeMovingAverageAbsMaxKernel
: public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> { : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in_scale, int bin_cnt, const framework::Tensor &in,
int round_type, framework::Tensor* out) const override { const framework::Tensor &in_scale,
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, bin_cnt, int bin_cnt,
round_type, out); int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, in, in_scale, bin_cnt, round_type, out);
} }
}; };
...@@ -347,9 +383,12 @@ template <typename DeviceContext, typename T> ...@@ -347,9 +383,12 @@ template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
: public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> { : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in_scale, int bin_cnt, const framework::Tensor &in,
int round_type, framework::Tensor* out) const override { const framework::Tensor &in_scale,
int bin_cnt,
int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()( ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, in, in_scale, bin_cnt, round_type, out); dev_ctx, in, in_scale, bin_cnt, round_type, out);
} }
...@@ -358,12 +397,12 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel ...@@ -358,12 +397,12 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
if (context.HasOutput("Out")) { if (context.HasOutput("Out")) {
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
} }
...@@ -375,37 +414,43 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { ...@@ -375,37 +414,43 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
} }
// training // training
auto* in_accum = context.Input<framework::Tensor>("InAccum"); auto *in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState"); auto *in_state = context.Input<framework::Tensor>("InState");
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr()); T *cur_scale_data = static_cast<T *>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(
cur_scale_data); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState"); auto *out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum"); auto *out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace()); out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace()); out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate"); float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()( FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, *in_accum,
out_accum, out_scale); *in_state,
cur_scale_data,
moving_rate,
out_state,
out_accum,
out_scale);
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> { class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* d_out = auto *d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out")); context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto* d_x = context.Output<framework::LoDTensor>(x_grad_name); auto *d_x = context.Output<framework::LoDTensor>(x_grad_name);
PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(d_x,
platform::errors::PreconditionNotMet(
"StrightThroughEstimatorGradKernel " "StrightThroughEstimatorGradKernel "
"doesn't have the output named %s.", "doesn't have the output named %s.",
x_grad_name)); x_grad_name));
......
...@@ -26,14 +26,17 @@ namespace operators { ...@@ -26,14 +26,17 @@ namespace operators {
template <typename T> template <typename T>
struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx, void operator()(const platform::CPUDeviceContext &dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale, const framework::Tensor *in,
T max_range, const int quant_axis, framework::Tensor* out) { const framework::Tensor *scale,
T max_range,
const int quant_axis,
framework::Tensor *out) {
// Dequant op is before quantized op // Dequant op is before quantized op
// Dequantize the weight of quantized op // Dequantize the weight of quantized op
auto in_dims = in->dims(); auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis]; const int64_t channel = in_dims[quant_axis];
const T* scale_factor = scale->data<T>(); const T *scale_factor = scale->data<T>();
if (quant_axis == 0) { if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) { for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i]; T s = scale_factor[i];
...@@ -41,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -41,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
framework::Tensor one_channel_out = out->Slice(i, i + 1); framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in); auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out); auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device(); auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range; out_e.device(dev) = in_e * s / max_range;
} }
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
...@@ -51,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -51,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
} }
int64_t step_i = in->numel() / out_iter; int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel); int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>(); auto *in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); auto *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) { for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) { for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j; auto *cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j; auto *cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j]; T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) { for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range; *cur_out = (*cur_in) * s / max_range;
...@@ -75,11 +78,11 @@ template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>; ...@@ -75,11 +78,11 @@ template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;
class QuantizeLinearOp : public framework::OperatorWithKernel { class QuantizeLinearOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", OP_INOUT_CHECK(
"QuantizeLinear"); ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
ctx->SetOutputDim("Y", ctx->GetInputDim("X")); ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
int quant_axis = ctx->Attrs().Get<int>("quant_axis"); int quant_axis = ctx->Attrs().Get<int>("quant_axis");
...@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { ...@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
...@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"For conv2d, depthwise_conv2d, conv2d_transpose " "For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.") "and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0) .SetDefault(0)
.AddCustomChecker([](const int& quant_axis) { .AddCustomChecker([](const int &quant_axis) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true, quant_axis == 0 || quant_axis == 1 || quant_axis == -1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but " "'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
...@@ -126,8 +130,9 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -126,8 +130,9 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
}); });
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but " "'bit_length' should be between 1 and 16, but "
"the received is %d", "the received is %d",
...@@ -140,13 +145,17 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -140,13 +145,17 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3") "round(2.5)=3")
.SetDefault(0) .SetDefault(0)
.AddCustomChecker([](const int& round_type) { .AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(round_type >= 0 && round_type <= 1, true, PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'round_type' should be between 0 and 1, but " "'round_type' should be 0 or 1, 0 rounding to "
"the received is %d", "nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type)); round_type));
}); })
.AsExtra();
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
...@@ -170,14 +179,18 @@ namespace ops = paddle::operators; ...@@ -170,14 +179,18 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, quantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>); REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, dequantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
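As an illustration only, not part of the patch: the round_type attribute above distinguishes two tie-breaking rules, 0 for rounding to nearest ties to even and 1 for rounding to nearest ties away from zero. A minimal numpy sketch of the difference, with round_ties_away_from_zero as a hypothetical helper:

    import numpy as np

    def round_ties_away_from_zero(x):
        # round_type = 1: halves move away from zero, e.g. 1.5 -> 2, 2.5 -> 3
        return np.sign(x) * np.floor(np.abs(x) + 0.5)

    x = np.array([0.5, 1.5, 2.5, -1.5])
    print(np.round(x))                   # round_type = 0 (ties to even):  [ 0.  2.  2. -2.]
    print(round_ties_away_from_zero(x))  # round_type = 1 (ties away):     [ 1.  2.  3. -2.]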
...@@ -121,8 +121,7 @@ class PostTrainingQuantization(object): ...@@ -121,8 +121,7 @@ class PostTrainingQuantization(object):
algo="KL", algo="KL",
hist_percent=0.99999, hist_percent=0.99999,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_round_algo='round', round_type='round',
round_type='TiesToEven',
learning_rate=0.001, learning_rate=0.001,
is_full_quantize=False, is_full_quantize=False,
bias_correction=False, bias_correction=False,
...@@ -181,14 +180,10 @@ class PostTrainingQuantization(object): ...@@ -181,14 +180,10 @@ class PostTrainingQuantization(object):
quantizable_op_type(list[str], optional): List the type of ops quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d", that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"]. "mul"].
weight_round_algo(str, optional): The method of converting the quantized weights round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods. value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which rounds to the nearest integer. Default is `round`, which rounds to the nearest integer.
'adaround' refers to https://arxiv.org/abs/2004.10568. 'adaround' refers to https://arxiv.org/abs/2004.10568.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
learning_rate(float, optional): The learning rate of adaround method. learning_rate(float, optional): The learning rate of adaround method.
is_full_quantize(bool, optional): If set is_full_quantize as True, is_full_quantize(bool, optional): If set is_full_quantize as True,
apply quantization to all supported quantizable op types. If set apply quantization to all supported quantizable op types. If set
...@@ -269,10 +264,8 @@ class PostTrainingQuantization(object): ...@@ -269,10 +264,8 @@ class PostTrainingQuantization(object):
self._support_algo_type = [ self._support_algo_type = [
'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max' 'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
] ]
assert round_type in ['TiesToEven', 'TiesAwayFromZero'] assert round_type in ['adaround', 'round']
self._round_type = round_type self._round_type = round_type
assert weight_round_algo in ['adaround', 'round']
self._weight_round_algo = weight_round_algo
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._dynamic_quantize_op_type = ['lstm'] self._dynamic_quantize_op_type = ['lstm']
self._support_quantize_op_type = \ self._support_quantize_op_type = \
...@@ -414,7 +407,7 @@ class PostTrainingQuantization(object): ...@@ -414,7 +407,7 @@ class PostTrainingQuantization(object):
if self._algo in ["KL", "hist"]: if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold() self._calculate_kl_hist_threshold()
if self._weight_round_algo == 'adaround': if self._round_type == 'adaround':
self._adaround_apply() self._adaround_apply()
self._reset_activation_persistable() self._reset_activation_persistable()
...@@ -651,7 +644,6 @@ class PostTrainingQuantization(object): ...@@ -651,7 +644,6 @@ class PostTrainingQuantization(object):
float(np.max(np.abs(var_tensor[i])))) float(np.max(np.abs(var_tensor[i]))))
self._quantized_threshold[var_name] = abs_max_value self._quantized_threshold[var_name] = abs_max_value
_logger.info("MSE searching stage ...") _logger.info("MSE searching stage ...")
distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten() var_tensor = var_tensor.flatten()
...@@ -664,9 +656,14 @@ class PostTrainingQuantization(object): ...@@ -664,9 +656,14 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value scale = s * abs_max_value
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
if self._onnx_format:
quant_var = np.clip(distribution(var_tensor / scale * bins), quant_var = np.clip(np.round(var_tensor / scale * bins),
-bins - 1, bins) -bins - 1, bins)
quant_dequant_var = quant_var / bins * scale quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
mse_loss = ((var_tensor - quant_dequant_var)**2).mean() mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
if mse_loss <= self._best_calibration_loss[var_name]: if mse_loss <= self._best_calibration_loss[var_name]:
self._best_calibration_loss[var_name] = mse_loss self._best_calibration_loss[var_name] = mse_loss
...@@ -691,7 +688,6 @@ class PostTrainingQuantization(object): ...@@ -691,7 +688,6 @@ class PostTrainingQuantization(object):
float(np.max(np.abs(var_tensor[i])))) float(np.max(np.abs(var_tensor[i]))))
self._quantized_threshold[var_name] = abs_max_value self._quantized_threshold[var_name] = abs_max_value
_logger.info("EMD searching stage ...") _logger.info("EMD searching stage ...")
distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c
for var_name in self._quantized_act_var_name: for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name) var_tensor = utils.load_variable_data(self._scope, var_name)
var_tensor = var_tensor.flatten() var_tensor = var_tensor.flatten()
...@@ -704,9 +700,14 @@ class PostTrainingQuantization(object): ...@@ -704,9 +700,14 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value scale = s * abs_max_value
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
if self._onnx_format:
quant_var = np.clip(distribution(var_tensor / scale * bins), quant_var = np.clip(np.round(var_tensor / scale * bins),
-bins - 1, bins) -bins - 1, bins)
quant_dequant_var = quant_var / bins * scale quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
emd_loss = np.abs( emd_loss = np.abs(
np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs( np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
np.std(var_tensor) - np.std(quant_dequant_var)) np.std(var_tensor) - np.std(quant_dequant_var))
...@@ -918,8 +919,7 @@ class PostTrainingQuantization(object): ...@@ -918,8 +919,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits, activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type, activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type, weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types, quantizable_op_type=major_quantizable_op_types)
round_type=self._round_type)
else: else:
transform_pass = QuantizationTransformPassV2( transform_pass = QuantizationTransformPassV2(
scope=self._scope, scope=self._scope,
...@@ -928,8 +928,7 @@ class PostTrainingQuantization(object): ...@@ -928,8 +928,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits, activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type, activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type, weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types, quantizable_op_type=major_quantizable_op_types)
round_type=self._round_type)
for sub_graph in graph.all_sub_graphs(): for sub_graph in graph.all_sub_graphs():
# Insert fake_quant/fake_dequantize op must in test graph, so # Insert fake_quant/fake_dequantize op must in test graph, so
...@@ -946,15 +945,13 @@ class PostTrainingQuantization(object): ...@@ -946,15 +945,13 @@ class PostTrainingQuantization(object):
add_quant_dequant_pass = AddQuantDequantPass( add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope, scope=self._scope,
place=self._place, place=self._place,
quantizable_op_type=minor_quantizable_op_types, quantizable_op_type=minor_quantizable_op_types)
round_type=self._round_type)
else: else:
add_quant_dequant_pass = AddQuantDequantPassV2( add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope, scope=self._scope,
place=self._place, place=self._place,
quantizable_op_type=minor_quantizable_op_types, quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=self._is_full_quantize, is_full_quantized=self._is_full_quantize)
round_type=self._round_type)
for sub_graph in graph.all_sub_graphs(): for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True sub_graph._for_test = True
...@@ -979,7 +976,6 @@ class PostTrainingQuantization(object): ...@@ -979,7 +976,6 @@ class PostTrainingQuantization(object):
place=self._place, place=self._place,
bias_correction=self._bias_correction, bias_correction=self._bias_correction,
weight_bits=self._weight_bits, weight_bits=self._weight_bits,
weight_round_algo=self._weight_round_algo,
round_type=self._round_type, round_type=self._round_type,
activation_bits=self._activation_bits, activation_bits=self._activation_bits,
weight_quantize_type=self._weight_quantize_type, weight_quantize_type=self._weight_quantize_type,
......
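As an illustration only, not part of the patch: after this change, PostTrainingQuantization takes a single round_type argument ('round' or 'adaround') for weight rounding, while the clip-then-round handling of tensors follows onnx_format. A hypothetical minimal usage sketch; the model directory and sample_generator are placeholders, and the constructor arguments mirror those used by the tests later in this diff:

    import numpy as np
    import paddle
    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

    paddle.enable_static()

    def sample_generator():
        # placeholder calibration data; a real model needs its own reader
        for _ in range(10):
            yield [np.random.random((1, 1, 28, 28)).astype("float32")]

    ptq = PostTrainingQuantization(
        executor=paddle.static.Executor(paddle.CPUPlace()),
        model_dir="./fp32_model",          # assumed FP32 inference model directory
        sample_generator=sample_generator,
        batch_nums=10,
        algo="mse",
        round_type="adaround",             # 'round' (default) or 'adaround'
        onnx_format=False)
    ptq.quantize()
    ptq.save_quantized_model("./int8_model")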
...@@ -119,7 +119,6 @@ class QuantizationTransformPass(object): ...@@ -119,7 +119,6 @@ class QuantizationTransformPass(object):
moving_rate=0.9, moving_rate=0.9,
skip_pattern=['skip_quant'], skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'], quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
round_type='TiesToEven',
weight_quantize_func=None, weight_quantize_func=None,
act_quantize_func=None, act_quantize_func=None,
weight_preprocess_func=None, weight_preprocess_func=None,
...@@ -157,10 +156,6 @@ class QuantizationTransformPass(object): ...@@ -157,10 +156,6 @@ class QuantizationTransformPass(object):
quantizable_op_type(list[str]): List the type of ops that will be quantized. quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this. QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_func(function): Function that defines how to quantize weight. weight_quantize_func(function): Function that defines how to quantize weight.
Using this can quickly test if user's quantization method works or not. Using this can quickly test if user's quantization method works or not.
In this function, user should both define quantization function and In this function, user should both define quantization function and
...@@ -211,7 +206,6 @@ class QuantizationTransformPass(object): ...@@ -211,7 +206,6 @@ class QuantizationTransformPass(object):
self._weight_bits = weight_bits self._weight_bits = weight_bits
self._activation_bits = activation_bits self._activation_bits = activation_bits
self._skip_pattern = skip_pattern self._skip_pattern = skip_pattern
self._round_type = round_type
self._weight_quantize_func = weight_quantize_func self._weight_quantize_func = weight_quantize_func
self._act_quantize_func = act_quantize_func self._act_quantize_func = act_quantize_func
self._weight_preprocess_func = weight_preprocess_func self._weight_preprocess_func = weight_preprocess_func
...@@ -465,12 +459,10 @@ class QuantizationTransformPass(object): ...@@ -465,12 +459,10 @@ class QuantizationTransformPass(object):
_init_var_node(scale_var_node, _init_var_node(scale_var_node,
np.zeros(scale_var_node.shape(), dtype=data_type), np.zeros(scale_var_node.shape(), dtype=data_type),
self._scope, self._place) self._scope, self._place)
round_type = 0 if self._round_type == 'TiesToEven' else 1
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
attrs={ attrs={
'bit_length': quant_bits, 'bit_length': quant_bits,
'round_type': round_type,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={'X': var_node}, inputs={'X': var_node},
...@@ -525,11 +517,9 @@ class QuantizationTransformPass(object): ...@@ -525,11 +517,9 @@ class QuantizationTransformPass(object):
inputs['Iter'] = self._global_step inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node outputs['OutScales'] = scales_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = { attrs = {
'window_size': self._window_size, 'window_size': self._window_size,
'bit_length': quant_bits, 'bit_length': quant_bits,
'round_type': round_type,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
} }
...@@ -600,10 +590,8 @@ class QuantizationTransformPass(object): ...@@ -600,10 +590,8 @@ class QuantizationTransformPass(object):
outs['OutState'] = state_out_node outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node outs['OutAccum'] = accum_out_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = { attrs = {
'bit_length': quant_bits, 'bit_length': quant_bits,
'round_type': round_type,
'moving_rate': self._moving_rate, 'moving_rate': self._moving_rate,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
...@@ -650,12 +638,10 @@ class QuantizationTransformPass(object): ...@@ -650,12 +638,10 @@ class QuantizationTransformPass(object):
_init_var_node(scale_var_node, _init_var_node(scale_var_node,
np.zeros(scale_var_node.shape(), dtype=data_type), np.zeros(scale_var_node.shape(), dtype=data_type),
self._scope, self._place) self._scope, self._place)
round_type = 0 if self._round_type == 'TiesToEven' else 1
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_channel_wise_quantize_abs_max', op_type='fake_channel_wise_quantize_abs_max',
attrs={ attrs={
'bit_length': quant_bits, 'bit_length': quant_bits,
'round_type': round_type,
'quant_axis': quant_axis, 'quant_axis': quant_axis,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
...@@ -949,8 +935,7 @@ class QuantizationFreezePass(object): ...@@ -949,8 +935,7 @@ class QuantizationFreezePass(object):
bias_correction=False, bias_correction=False,
weight_bits=8, weight_bits=8,
activation_bits=8, activation_bits=8,
weight_round_algo='round', round_type='round',
round_type='TiesToEven',
weight_quantize_type='abs_max', weight_quantize_type='abs_max',
quantizable_op_type=None): quantizable_op_type=None):
""" """
...@@ -968,14 +953,10 @@ class QuantizationFreezePass(object): ...@@ -968,14 +953,10 @@ class QuantizationFreezePass(object):
https://arxiv.org/abs/1810.05723. https://arxiv.org/abs/1810.05723.
weight_bits(int): quantization bit number for weights. weight_bits(int): quantization bit number for weights.
activation_bits(int): quantization bit number for activation. activation_bits(int): quantization bit number for activation.
weight_round_algo(str, optional): The method of converting the quantized weights round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods. value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which rounds to the nearest integer. Default is `round`, which rounds to the nearest integer.
'adaround' refers to https://arxiv.org/abs/2004.10568. 'adaround' refers to https://arxiv.org/abs/2004.10568.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_type(str): quantization type for weights, support 'abs_max' and weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained. since weights are fixed once the model is well trained.
...@@ -991,7 +972,6 @@ class QuantizationFreezePass(object): ...@@ -991,7 +972,6 @@ class QuantizationFreezePass(object):
self._place = _get_paddle_place(place) self._place = _get_paddle_place(place)
self._weight_bits = weight_bits self._weight_bits = weight_bits
self._activation_bits = activation_bits self._activation_bits = activation_bits
self._weight_round_algo = weight_round_algo
self._round_type = round_type self._round_type = round_type
self._weight_quantize_type = weight_quantize_type self._weight_quantize_type = weight_quantize_type
self._fake_quant_op_names = _fake_quant_op_list self._fake_quant_op_names = _fake_quant_op_list
...@@ -1039,7 +1019,7 @@ class QuantizationFreezePass(object): ...@@ -1039,7 +1019,7 @@ class QuantizationFreezePass(object):
scale_v = scale_v.tolist() scale_v = scale_v.tolist()
self._quant_var_scale_map[input_arg_name] = scale_v self._quant_var_scale_map[input_arg_name] = scale_v
# Quantize weight and restore # Quantize weight and restore
if self._weight_round_algo == 'round': if self._round_type == 'round':
param_v = self._load_var(input_arg_name) param_v = self._load_var(input_arg_name)
if any( if any(
_check_grandchild_op_node(op_node, op) _check_grandchild_op_node(op_node, op)
...@@ -1049,7 +1029,8 @@ class QuantizationFreezePass(object): ...@@ -1049,7 +1029,8 @@ class QuantizationFreezePass(object):
quant_axis = 0 quant_axis = 0
quantized_param_v = utils.quant_tensor( quantized_param_v = utils.quant_tensor(
param_v.copy(), scale_v, quant_axis, param_v.copy(), scale_v, quant_axis,
self._weight_bits, self._round_type) self._weight_bits)
quantized_param_v = np.round(quantized_param_v)
# Weight bias correction # Weight bias correction
if self._bias_correction == True: if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w( quantized_param_v = utils.bias_correction_w(
...@@ -1058,6 +1039,7 @@ class QuantizationFreezePass(object): ...@@ -1058,6 +1039,7 @@ class QuantizationFreezePass(object):
scale_v, scale_v,
quant_axis, quant_axis,
weight_bits=self._weight_bits) weight_bits=self._weight_bits)
quantized_param_v = np.round(quantized_param_v)
self._restore_var(input_arg_name, quantized_param_v) self._restore_var(input_arg_name, quantized_param_v)
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
...@@ -1600,8 +1582,7 @@ class AddQuantDequantPass(object): ...@@ -1600,8 +1582,7 @@ class AddQuantDequantPass(object):
quant_bits=8, quant_bits=8,
skip_pattern=["skip_quant"], skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"], quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False, is_full_quantized=False):
round_type='TiesToEven'):
""" """
Constructor. Constructor.
...@@ -1623,10 +1604,6 @@ class AddQuantDequantPass(object): ...@@ -1623,10 +1604,6 @@ class AddQuantDequantPass(object):
quantization to all supported quantizable op type. If set is_full_quantized quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input as False, only apply quantization to the op type according to the input
quantizable_op_type. quantizable_op_type.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
""" """
self._scope = scope self._scope = scope
self._place = _get_paddle_place(place) self._place = _get_paddle_place(place)
...@@ -1634,7 +1611,6 @@ class AddQuantDequantPass(object): ...@@ -1634,7 +1611,6 @@ class AddQuantDequantPass(object):
self._quant_bits = quant_bits self._quant_bits = quant_bits
self._is_test = None self._is_test = None
self._skip_pattern = skip_pattern self._skip_pattern = skip_pattern
self._round_type = round_type
if is_full_quantized: if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type self._quantizable_op_type = utils._act_supported_quantizable_op_type
...@@ -1769,10 +1745,8 @@ class AddQuantDequantPass(object): ...@@ -1769,10 +1745,8 @@ class AddQuantDequantPass(object):
outs['OutState'] = state_out_node outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node outs['OutAccum'] = accum_out_node
round_type = 0 if self._round_type == 'TiesToEven' else 1
attrs = { attrs = {
'bit_length': quant_bits, 'bit_length': quant_bits,
'round_type': round_type,
'moving_rate': self._moving_rate, 'moving_rate': self._moving_rate,
'is_test': self._is_test, 'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
...@@ -1812,10 +1786,6 @@ class InsertQuantizeLinear(object): ...@@ -1812,10 +1786,6 @@ class InsertQuantizeLinear(object):
Default is -1. Default is -1.
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False. channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
is_test(bool, optional): Whether quantization with training or not. Default is True. is_test(bool, optional): Whether quantization with training or not. Default is True.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
""" """
def __init__(self, def __init__(self,
...@@ -1824,15 +1794,13 @@ class InsertQuantizeLinear(object): ...@@ -1824,15 +1794,13 @@ class InsertQuantizeLinear(object):
quant_bits=8, quant_bits=8,
quant_axis=-1, quant_axis=-1,
channel_wise=False, channel_wise=False,
is_test=True, is_test=True):
round_type='TiesToEven'):
self._place = place self._place = place
self._scope = scope self._scope = scope
self.quant_bits = quant_bits self.quant_bits = quant_bits
self.quant_axis = quant_axis self.quant_axis = quant_axis
self.channel_wise = channel_wise self.channel_wise = channel_wise
self._is_test = is_test self._is_test = is_test
self._round_type = round_type
def insert_quant_op(self, graph, var_node): def insert_quant_op(self, graph, var_node):
assert var_node.is_var(), '{} is not a var'.format(var_node.name()) assert var_node.is_var(), '{} is not a var'.format(var_node.name())
...@@ -1875,12 +1843,7 @@ class InsertQuantizeLinear(object): ...@@ -1875,12 +1843,7 @@ class InsertQuantizeLinear(object):
if zero_point_node is not None: if zero_point_node is not None:
inputs["ZeroPoint"] = zero_point_node inputs["ZeroPoint"] = zero_point_node
round_type = 0 if self._round_type == 'TiesToEven' else 1 attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
attrs = {
"quant_axis": self.quant_axis,
"bit_length": self.quant_bits,
"round_type": round_type
}
outputs = {"Y": quant_var_node} outputs = {"Y": quant_var_node}
if not self._is_test: if not self._is_test:
attrs["is_test"] = self._is_test attrs["is_test"] = self._is_test
...@@ -1985,7 +1948,6 @@ class QuantizationTransformPassV2(object): ...@@ -1985,7 +1948,6 @@ class QuantizationTransformPassV2(object):
moving_rate=0.9, moving_rate=0.9,
skip_pattern=['skip_quant'], skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'], quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
round_type='TiesToEven',
weight_quantize_func=None, weight_quantize_func=None,
act_quantize_func=None, act_quantize_func=None,
weight_preprocess_func=None, weight_preprocess_func=None,
...@@ -2021,10 +1983,6 @@ class QuantizationTransformPassV2(object): ...@@ -2021,10 +1983,6 @@ class QuantizationTransformPassV2(object):
quantizable_op_type(list[str]): List the type of ops that will be quantized. quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this. QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_func(function): Function that defines how to quantize weight. weight_quantize_func(function): Function that defines how to quantize weight.
Using this can quickly test if user's quantization method works or not. Using this can quickly test if user's quantization method works or not.
In this function, user should both define quantization function and In this function, user should both define quantization function and
...@@ -2074,7 +2032,6 @@ class QuantizationTransformPassV2(object): ...@@ -2074,7 +2032,6 @@ class QuantizationTransformPassV2(object):
self._weight_bits = weight_bits self._weight_bits = weight_bits
self._activation_bits = activation_bits self._activation_bits = activation_bits
self._skip_pattern = skip_pattern self._skip_pattern = skip_pattern
self._round_type = round_type
self._weight_quantize_func = weight_quantize_func self._weight_quantize_func = weight_quantize_func
self._act_quantize_func = act_quantize_func self._act_quantize_func = act_quantize_func
self._weight_preprocess_func = weight_preprocess_func self._weight_preprocess_func = weight_preprocess_func
...@@ -2198,8 +2155,7 @@ class QuantizationTransformPassV2(object): ...@@ -2198,8 +2155,7 @@ class QuantizationTransformPassV2(object):
quant_bits=quant_bits, quant_bits=quant_bits,
quant_axis=quant_axis, quant_axis=quant_axis,
channel_wise=channel_wise, channel_wise=channel_wise,
is_test=self._is_test, is_test=self._is_test)
round_type=self._round_type)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, var_node) graph, var_node)
dequant_var_node = insert_quant_pass.insert_dequant_op( dequant_var_node = insert_quant_pass.insert_dequant_op(
...@@ -2307,8 +2263,7 @@ class AddQuantDequantPassV2(object): ...@@ -2307,8 +2263,7 @@ class AddQuantDequantPassV2(object):
quant_bits=8, quant_bits=8,
skip_pattern=["skip_quant"], skip_pattern=["skip_quant"],
quantizable_op_type=["elementwise_add", "pool2d"], quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False, is_full_quantized=False):
round_type='TiesToEven'):
""" """
Args: Args:
scope(paddle.Scope): The scope is used to initialize these new parameters. scope(paddle.Scope): The scope is used to initialize these new parameters.
...@@ -2328,10 +2283,6 @@ class AddQuantDequantPassV2(object): ...@@ -2328,10 +2283,6 @@ class AddQuantDequantPassV2(object):
quantization to all supported quantizable op type. If set is_full_quantized quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input as False, only apply quantization to the op type according to the input
quantizable_op_type. quantizable_op_type.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -2354,7 +2305,6 @@ class AddQuantDequantPassV2(object): ...@@ -2354,7 +2305,6 @@ class AddQuantDequantPassV2(object):
self._quant_bits = quant_bits self._quant_bits = quant_bits
self._is_test = None self._is_test = None
self._skip_pattern = skip_pattern self._skip_pattern = skip_pattern
self._round_type = round_type
if is_full_quantized: if is_full_quantized:
self._quantizable_op_type = utils._act_supported_quantizable_op_type self._quantizable_op_type = utils._act_supported_quantizable_op_type
...@@ -2427,8 +2377,7 @@ class AddQuantDequantPassV2(object): ...@@ -2427,8 +2377,7 @@ class AddQuantDequantPassV2(object):
quant_bits=self._quant_bits, quant_bits=self._quant_bits,
quant_axis=-1, quant_axis=-1,
channel_wise=False, channel_wise=False,
is_test=self._is_test, is_test=self._is_test)
round_type=self._round_type)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node) graph, in_node)
dequant_var_node = insert_quant_pass.insert_dequant_op( dequant_var_node = insert_quant_pass.insert_dequant_op(
...@@ -2511,8 +2460,6 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2511,8 +2460,6 @@ class ReplaceFakeQuantDequantPass(object):
"quant_axis") else -1 "quant_axis") else -1
bit_length = op.op().attr("bit_length") if op.op().has_attr( bit_length = op.op().attr("bit_length") if op.op().has_attr(
"bit_length") else 8 "bit_length") else 8
round_type = op.op().attr("round_type") if op.op().has_attr(
"round_type") else 0
zero_point_node = None zero_point_node = None
quanted_node = x_node quanted_node = x_node
...@@ -2534,8 +2481,7 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2534,8 +2481,7 @@ class ReplaceFakeQuantDequantPass(object):
quant_op_node = graph.create_op_node(op_type="quantize_linear", quant_op_node = graph.create_op_node(op_type="quantize_linear",
attrs={ attrs={
"quant_axis": quant_axis, "quant_axis": quant_axis,
"bit_length": bit_length, "bit_length": bit_length
"round_type": round_type
}, },
inputs={ inputs={
"X": x_node, "X": x_node,
...@@ -2654,11 +2600,11 @@ class QuantWeightPass(object): ...@@ -2654,11 +2600,11 @@ class QuantWeightPass(object):
param_v = self._load_var(x_node.name()) param_v = self._load_var(x_node.name())
quant_axis = _op.op().attr("quant_axis") quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length") bits_length = _op.op().attr("bit_length")
round_type = _op.op().attr("round_type") if _op.op().has_attr( quantized_param_v = utils.quant_tensor(param_v.copy(),
"round_type") else 0 scale_v,
quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v, quant_axis,
quant_axis, bits_length, bits_length,
round_type) onnx_format=True)
if self._bias_correction == True: if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w( quantized_param_v = utils.bias_correction_w(
param_v, param_v,
......
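As an illustration only, not part of the patch: after this change QuantizationFreezePass quantizes a weight by calling utils.quant_tensor with no round_type and then rounding explicitly, i.e. clip to [-scale, scale], rescale, then np.round. A small sketch with hypothetical names:

    import numpy as np

    def freeze_weight_sketch(param, scale, weight_bits=8):
        # mirrors utils.quant_tensor(..., onnx_format=False) followed by np.round
        bnt = (1 << (weight_bits - 1)) - 1        # 127 for 8-bit weights
        clipped = np.clip(param, -scale, scale)   # clip to [-scale, scale] first
        return np.round(clipped / scale * bnt)    # then rescale and round to integers

    w = np.array([0.02, -0.3, 0.7], dtype=np.float32)
    print(freeze_weight_sketch(w, scale=0.5))     # -> [  5. -76. 127.]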
...@@ -321,39 +321,41 @@ def set_variable_data(scope, place, var_name, np_value): ...@@ -321,39 +321,41 @@ def set_variable_data(scope, place, var_name, np_value):
tensor.set(np_value, place) tensor.set(np_value, place)
def round_c_single_element(val): def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
dtype = type(val) # symmetry quant
if val >= 0: def _clip(x, scale):
return dtype(np.floor(val + 0.5)) x[x > scale] = scale
return dtype(np.ceil(val - 0.5)) x[x < -scale] = -scale
return x
# rounding to nearest ties away from zero
round_c = np.vectorize(round_c_single_element)
def quant_tensor(x,
scale,
quant_axis=0,
weight_bits=8,
round_type='TiesToEven'):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
distribution = np.round if round_type == 'TiesToEven' else round_c
bnt = (1 << (weight_bits - 1)) - 1 bnt = (1 << (weight_bits - 1)) - 1
if isinstance(scale, list): if isinstance(scale, list):
for i, s in enumerate(scale): for i, s in enumerate(scale):
if s == 0.0: if s == 0.0:
s = 1e-8 s = 1e-8
if quant_axis == 0: if quant_axis == 0:
x[i] = distribution(x[i] / s * bnt) if onnx_format:
x[i] = np.round(x[i] / s * bnt)
x[i] = np.clip(x[i], -bnt - 1, bnt) x[i] = np.clip(x[i], -bnt - 1, bnt)
else: else:
x[:, i] = distribution(x[:, i] / s * bnt) x[i] = _clip(x[i], s)
x[i] = x[i] / s * bnt
else:
if onnx_format:
x[:, i] = np.round(x[:, i] / s * bnt)
x[:, i] = np.clip(x[:, i], -bnt - 1, bnt) x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
else:
x[:, i] = _clip(x[:, i], s)
x[:, i] = x[:, i] / s * bnt
else: else:
scale = 1e-8 if scale == 0.0 else scale scale = 1e-8 if scale == 0.0 else scale
x = distribution(x / scale * bnt) if onnx_format:
x = np.round(x / scale * bnt)
x = np.clip(x, -bnt - 1, bnt) x = np.clip(x, -bnt - 1, bnt)
else:
x = _clip(x, scale)
x = x / scale * bnt
return x return x
......
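As an illustration only, not part of the patch: the rewritten quant_tensor above has two flavours selected by onnx_format. A per-tensor numpy sketch (the per-channel branches do the same thing per slice); quant_tensor_sketch is a hypothetical name:

    import numpy as np

    def quant_tensor_sketch(x, scale, weight_bits=8, onnx_format=False):
        bnt = (1 << (weight_bits - 1)) - 1
        scale = 1e-8 if scale == 0.0 else scale
        if onnx_format:
            # scale, round, then clip to the representable range [-bnt - 1, bnt]
            return np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)
        # default path: clip the float tensor to [-scale, scale], then rescale;
        # rounding is left to the caller (e.g. QuantizationFreezePass)
        return np.clip(x, -scale, scale) / scale * bnt

    x = np.array([0.4, 1.2, -0.9])
    print(quant_tensor_sketch(x, scale=1.0, onnx_format=True))   # [  51. 127. -114.]
    print(quant_tensor_sketch(x, scale=1.0, onnx_format=False))  # [  50.8 127.  -114.3]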
...@@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path, model_path,
data_path, data_path,
algo="KL", algo="KL",
weight_round_algo="round", round_type="round",
quantizable_op_type=["conv2d"], quantizable_op_type=["conv2d"],
is_full_quantize=False, is_full_quantize=False,
is_use_cache_file=False, is_use_cache_file=False,
...@@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=batch_nums, batch_nums=batch_nums,
algo=algo, algo=algo,
quantizable_op_type=quantizable_op_type, quantizable_op_type=quantizable_op_type,
weight_round_algo=weight_round_algo, round_type=round_type,
is_full_quantize=is_full_quantize, is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model, optimize_model=is_optimize_model,
onnx_format=onnx_format, onnx_format=onnx_format,
...@@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
...@@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start post training quantization for {0} on {1} samples ...". print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations)) format(model_name, quant_iterations))
self.generate_quantized_model(fp32_model_path, data_path, algo, self.generate_quantized_model(fp32_model_path, data_path, algo,
weight_round_algo, quantizable_op_type, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_full_quantize, is_use_cache_file,
is_optimize_model, quant_iterations, is_optimize_model, quant_iterations,
onnx_format) onnx_format)
...@@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): ...@@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7" data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "avg" algo = "avg"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["mul", "lstm"] quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): ...@@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
infer_iterations = 100 infer_iterations = 100
quant_iterations = 10 quant_iterations = 10
self.run_test(model_name, model_url, model_md5, data_name, data_url, self.run_test(model_name, model_url, model_md5, data_name, data_url,
data_md5, algo, weight_round_algo, quantizable_op_type, data_md5, algo, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model, is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold, infer_iterations, quant_iterations) diff_threshold, infer_iterations, quant_iterations)
...@@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): ...@@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz" data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5 = "add84c754e9b792fea1fbd728d134ab7" data_md5 = "add84c754e9b792fea1fbd728d134ab7"
algo = "avg" algo = "avg"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["mul", "lstm"] quantizable_op_type = ["mul", "lstm"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): ...@@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
......
...@@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(self, def generate_quantized_model(self,
model_path, model_path,
algo="KL", algo="KL",
weight_round_algo="round", round_type="round",
quantizable_op_type=["conv2d"], quantizable_op_type=["conv2d"],
is_full_quantize=False, is_full_quantize=False,
is_use_cache_file=False, is_use_cache_file=False,
...@@ -130,7 +130,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -130,7 +130,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums=batch_nums, batch_nums=batch_nums,
algo=algo, algo=algo,
quantizable_op_type=quantizable_op_type, quantizable_op_type=quantizable_op_type,
weight_round_algo=weight_round_algo, round_type=round_type,
is_full_quantize=is_full_quantize, is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model, optimize_model=is_optimize_model,
bias_correction=bias_correction, bias_correction=bias_correction,
...@@ -145,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -145,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
...@@ -169,11 +169,10 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -169,11 +169,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size)) format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(origin_model_path, algo, self.generate_quantized_model(origin_model_path, algo, round_type,
weight_round_algo, quantizable_op_type, quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file, is_use_cache_file, is_optimize_model,
is_optimize_model, batch_size, batch_size, quant_iterations, onnx_format,
quant_iterations, onnx_format,
skip_tensor_list, bias_correction) skip_tensor_list, bias_correction)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
...@@ -204,7 +203,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): ...@@ -204,7 +203,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "KL" algo = "KL"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -213,7 +212,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): ...@@ -213,7 +212,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -226,7 +225,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): ...@@ -226,7 +225,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "hist" algo = "hist"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -235,7 +234,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): ...@@ -235,7 +234,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -248,7 +247,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): ...@@ -248,7 +247,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse" algo = "mse"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -257,7 +256,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): ...@@ -257,7 +256,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -270,7 +269,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): ...@@ -270,7 +269,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "emd" algo = "emd"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -279,7 +278,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): ...@@ -279,7 +278,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -292,7 +291,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): ...@@ -292,7 +291,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg" algo = "avg"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -301,7 +300,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): ...@@ -301,7 +300,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -314,7 +313,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): ...@@ -314,7 +313,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "abs_max" algo = "abs_max"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "mul"] quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = True is_full_quantize = True
is_use_cache_file = False is_use_cache_file = False
...@@ -323,7 +322,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): ...@@ -323,7 +322,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 10 quant_iterations = 10
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -336,7 +335,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): ...@@ -336,7 +335,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse" algo = "mse"
weight_round_algo = "adaround" round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -350,7 +349,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): ...@@ -350,7 +349,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
...@@ -369,7 +368,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): ...@@ -369,7 +368,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "KL" algo = "KL"
weight_round_algo = "adaround" round_type = "adaround"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -378,7 +377,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): ...@@ -378,7 +377,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, weight_round_algo, self.run_test(model_name, data_url, data_md5, algo, round_type,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold, batch_size, is_optimize_model, diff_threshold, batch_size,
infer_iterations, quant_iterations) infer_iterations, quant_iterations)
...@@ -391,7 +390,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): ...@@ -391,7 +390,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse" algo = "mse"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -405,7 +404,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): ...@@ -405,7 +404,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
...@@ -425,7 +424,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant( ...@@ -425,7 +424,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "mse" algo = "mse"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = True is_full_quantize = True
is_use_cache_file = False is_use_cache_file = False
...@@ -439,7 +438,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant( ...@@ -439,7 +438,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
...@@ -458,7 +457,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): ...@@ -458,7 +457,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5 = "be71d3997ec35ac2a65ae8a145e2887c" data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
algo = "avg" algo = "avg"
weight_round_algo = "round" round_type = "round"
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"] quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
...@@ -472,7 +471,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): ...@@ -472,7 +471,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
data_url, data_url,
data_md5, data_md5,
algo, algo,
weight_round_algo, round_type,
quantizable_op_type, quantizable_op_type,
is_full_quantize, is_full_quantize,
is_use_cache_file, is_use_cache_file,
......
...@@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path, model_path,
quantizable_op_type, quantizable_op_type,
algo="KL", algo="KL",
weight_round_algo="round", round_type="round",
is_full_quantize=False, is_full_quantize=False,
is_use_cache_file=False, is_use_cache_file=False,
is_optimize_model=False, is_optimize_model=False,
...@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_dir=model_path, model_dir=model_path,
algo=algo, algo=algo,
quantizable_op_type=quantizable_op_type, quantizable_op_type=quantizable_op_type,
weight_round_algo=weight_round_algo, round_type=round_type,
is_full_quantize=is_full_quantize, is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model, optimize_model=is_optimize_model,
onnx_format=onnx_format, onnx_format=onnx_format,
...@@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def run_test(self, def run_test(self,
model, model,
algo, algo,
weight_round_algo, round_type,
data_urls, data_urls,
data_md5s, data_md5s,
quantizable_op_type, quantizable_op_type,
...@@ -299,10 +299,9 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -299,10 +299,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size)) format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model", self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo, quantizable_op_type, algo, round_type,
weight_round_algo, is_full_quantize, is_full_quantize, is_use_cache_file,
is_use_cache_file, is_optimize_model, is_optimize_model, onnx_format)
onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size)) model, infer_iterations * batch_size))
...@@ -330,7 +329,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): ...@@ -330,7 +329,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self): def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "KL" algo = "KL"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
] ]
...@@ -345,7 +344,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): ...@@ -345,7 +344,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False is_use_cache_file = False
is_optimize_model = True is_optimize_model = True
diff_threshold = 0.025 diff_threshold = 0.025
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold) is_optimize_model, diff_threshold)
...@@ -355,7 +354,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): ...@@ -355,7 +354,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self): def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "avg" algo = "avg"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
] ]
...@@ -369,7 +368,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): ...@@ -369,7 +368,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False is_use_cache_file = False
is_optimize_model = True is_optimize_model = True
diff_threshold = 0.025 diff_threshold = 0.025
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold) is_optimize_model, diff_threshold)
...@@ -379,7 +378,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): ...@@ -379,7 +378,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_hist_mobilenetv1(self): def test_post_training_hist_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "hist" algo = "hist"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
] ]
...@@ -393,7 +392,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): ...@@ -393,7 +392,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False is_use_cache_file = False
is_optimize_model = True is_optimize_model = True
diff_threshold = 0.03 diff_threshold = 0.03
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold) is_optimize_model, diff_threshold)
...@@ -403,7 +402,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): ...@@ -403,7 +402,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self): def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "abs_max" algo = "abs_max"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
] ]
...@@ -417,7 +416,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): ...@@ -417,7 +416,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model = False is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) maybe bigger # The accuracy diff of post-training quantization (abs_max) maybe bigger
diff_threshold = 0.05 diff_threshold = 0.05
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold) is_optimize_model, diff_threshold)
...@@ -427,7 +426,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): ...@@ -427,7 +426,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_onnx_format_mobilenetv1(self): def test_post_training_onnx_format_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "avg" algo = "avg"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
] ]
...@@ -444,7 +443,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): ...@@ -444,7 +443,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
diff_threshold = 0.05 diff_threshold = 0.05
self.run_test(model, self.run_test(model,
algo, algo,
weight_round_algo, round_type,
data_urls, data_urls,
data_md5s, data_md5s,
quantizable_op_type, quantizable_op_type,
......
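For context, the post-training test classes above only toggle a single round_type switch ("round") and forward it to generate_quantized_model. Below is a minimal sketch of how such a calibration pass might be driven; model_dir, algo, quantizable_op_type, round_type, is_full_quantize, optimize_model and onnx_format mirror the keyword arguments visible in the hunks, while the executor, sample_generator and the quantize()/save_quantized_model() calls are assumptions made for illustration rather than content of this diff.

    import numpy as np
    import paddle
    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())  # assumed executor for calibration

    def sample_generator():
        # assumed calibration reader: yields single samples with a placeholder input shape
        for _ in range(10):
            yield [np.random.random((1, 28, 28)).astype("float32")]

    ptq = PostTrainingQuantization(
        executor=exe,                      # assumption, not shown in the diff
        sample_generator=sample_generator, # assumption, not shown in the diff
        model_dir="mnist_model",           # path to the FP32 inference model
        algo="KL",                         # activation scale algorithm (KL/avg/hist/mse/min_max)
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        round_type="round",                # weight rounding strategy exercised by these tests
        is_full_quantize=False,
        optimize_model=True,
        onnx_format=False)
    ptq.quantize()
    ptq.save_quantized_model("quantized_model")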
...@@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): ...@@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "min_max" algo = "min_max"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
] ]
...@@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): ...@@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
is_use_cache_file = False is_use_cache_file = False
is_optimize_model = False is_optimize_model = False
diff_threshold = 0.025 diff_threshold = 0.025
self.run_test(model, algo, weight_round_algo, data_urls, data_md5s, self.run_test(model, algo, round_type, data_urls, data_md5s,
quantizable_op_type, is_full_quantize, is_use_cache_file, quantizable_op_type, is_full_quantize, is_use_cache_file,
is_optimize_model, diff_threshold) is_optimize_model, diff_threshold)
...@@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): ...@@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "min_max" algo = "min_max"
weight_round_algo = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
] ]
...@@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): ...@@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
onnx_format = True onnx_format = True
self.run_test(model, self.run_test(model,
algo, algo,
weight_round_algo, round_type,
data_urls, data_urls,
data_md5s, data_md5s,
quantizable_op_type, quantizable_op_type,
......
...@@ -49,7 +49,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest): ...@@ -49,7 +49,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
dtype, dtype,
input_shape, input_shape,
distribution, distribution,
round_type='TiesToEven'): round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
scale = np.max(np.abs(input_data)) scale = np.max(np.abs(input_data))
...@@ -58,12 +58,12 @@ class TestFakeQuantizeAbsMaxOp(OpTest): ...@@ -58,12 +58,12 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
if round_type == 'TiesToEven': if round_type == 'TiesToEven':
round_out = np.round( round_out = np.round(
input_data.astype(compute_type) * inv_scale * bnt) input_data.astype(compute_type) * inv_scale * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0 self.attrs['round_type'] = 0
else: else:
round_out = round_c( output_data = round_c(
input_data.astype(compute_type) * inv_scale * bnt) input_data.astype(compute_type) * inv_scale * bnt)
self.attrs['round_type'] = 1 self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt)
self.inputs = {'X': input_data} self.inputs = {'X': input_data}
self.outputs = {'Out': output_data, 'OutScale': scale} self.outputs = {'Out': output_data, 'OutScale': scale}
self.dtype = dtype self.dtype = dtype
...@@ -75,7 +75,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest): ...@@ -75,7 +75,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
def test_fake_quantize_abs_max_round1(self): def test_fake_quantize_abs_max_round1(self):
self._fake_quantize_abs_max(np.float32, (124, 240), self._fake_quantize_abs_max(np.float32, (124, 240),
np.random.random, np.random.random,
round_type='TiesAwayFromZero') round_type='TiesToEven')
def test_fake_quantize_abs_max_float16(self): def test_fake_quantize_abs_max_float16(self):
self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random) self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)
...@@ -110,12 +110,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): ...@@ -110,12 +110,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
if round_type == 'TiesToEven': if round_type == 'TiesToEven':
round_out = np.round( round_out = np.round(
input_data.astype(compute_type) / scale_broadcast * bnt) input_data.astype(compute_type) / scale_broadcast * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0 self.attrs['round_type'] = 0
else: else:
round_out = round_c( output_data = round_c(bnt * input_data.astype(compute_type) /
input_data.astype(compute_type) / scale_broadcast * bnt) scale_broadcast)
self.attrs['round_type'] = 1 self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt)
if quant_axis == 1: if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast, scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis) (1, ) + compute_axis)
...@@ -169,11 +169,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): ...@@ -169,11 +169,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
round_out = np.round( round_out = np.round(
input_data.astype(compute_type) / out_scale[0] * bnt) input_data.astype(compute_type) / out_scale[0] * bnt)
self.attrs['round_type'] = 0 self.attrs['round_type'] = 0
output_data = np.clip(round_out, -bnt - 1, bnt)
else: else:
round_out = round_c( if is_test:
input_data.astype(compute_type) / out_scale[0] * bnt) clip_data = np.clip(input_data, -in_scale, in_scale)
else:
clip_data = input_data
output_data = round_c(
clip_data.astype(compute_type) / out_scale[0] * bnt)
self.attrs['round_type'] = 1 self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt)
self.inputs = { self.inputs = {
'X': input_data, 'X': input_data,
'Iter': np.zeros(1).astype(np.int64), 'Iter': np.zeros(1).astype(np.int64),
...@@ -250,7 +254,7 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): ...@@ -250,7 +254,7 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
distribution, distribution,
dequantize=False, dequantize=False,
with_gradient=False, with_gradient=False,
round_type='TiesToEven'): round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
...@@ -267,12 +271,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): ...@@ -267,12 +271,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
if round_type == 'TiesToEven': if round_type == 'TiesToEven':
round_out = np.round( round_out = np.round(
input_data.astype(compute_type) / out_scale * bnt) input_data.astype(compute_type) / out_scale * bnt)
quant_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0 self.attrs['round_type'] = 0
else: else:
round_out = round_c( quant_data = round_c(
input_data.astype(compute_type) / out_scale * bnt) input_data.astype(compute_type) / out_scale * bnt)
self.attrs['round_type'] = 1 self.attrs['round_type'] = 1
quant_data = np.clip(round_out, -bnt - 1, bnt)
if dequantize: if dequantize:
output_data = (quant_data * out_scale / bnt).astype(dtype) output_data = (quant_data * out_scale / bnt).astype(dtype)
self.op_type = 'fake_quantize_dequantize_moving_average_abs_max' self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
...@@ -307,10 +311,9 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): ...@@ -307,10 +311,9 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
np.random.random) np.random.random)
def test_fake_quantize_moving_average_abs_max_round1(self): def test_fake_quantize_moving_average_abs_max_round1(self):
self._fake_quantize_moving_average_abs_max( self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
np.float32, (8, 16, 7, 7),
np.random.random, np.random.random,
round_type='TiesAwayFromZero') round_type='TiesToEven')
def test_fake_quantize_dequantize_moving_average_abs_max(self): def test_fake_quantize_dequantize_moving_average_abs_max(self):
self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7), self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
...@@ -329,17 +332,17 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): ...@@ -329,17 +332,17 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
dtype, dtype,
input_shape, input_shape,
distribution, distribution,
round_type='TiesToEven'): round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
scale = np.max(np.abs(input_data)).astype(dtype) scale = np.max(np.abs(input_data)).astype(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
if round_type == 'TiesToEven': if round_type == 'TiesToEven':
round_out = np.round(input_data / scale * bnt) round_out = np.round(input_data / scale * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
self.attrs['round_type'] = 0 self.attrs['round_type'] = 0
else: else:
round_out = round_c(input_data / scale * bnt) output_data = round_c(input_data / scale * bnt) * scale / bnt
self.attrs['round_type'] = 1 self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
self.inputs = {'X': input_data} self.inputs = {'X': input_data}
self.outputs = { self.outputs = {
'Out': output_data, 'Out': output_data,
...@@ -357,7 +360,7 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): ...@@ -357,7 +360,7 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
def test_fake_quantize_dequantize_abs_max_round1(self): def test_fake_quantize_dequantize_abs_max_round1(self):
self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
np.random.random, np.random.random,
round_type='TiesAwayFromZero') round_type='TiesToEven')
class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
...@@ -382,11 +385,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): ...@@ -382,11 +385,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
if round_type == 'TiesToEven': if round_type == 'TiesToEven':
round_out = np.round(bnt * output_data / scale_broadcast) round_out = np.round(bnt * output_data / scale_broadcast)
output_data = np.clip(round_out, -bnt - 1,
bnt) * scale_broadcast / bnt
self.attrs['round_type'] = 0 self.attrs['round_type'] = 0
else: else:
round_out = round_c(bnt * output_data / scale_broadcast) output_data = round_c(
bnt * output_data / scale_broadcast) * scale_broadcast / bnt
self.attrs['round_type'] = 1 self.attrs['round_type'] = 1
output_data = np.clip(round_out, -bnt - 1, bnt) * scale_broadcast / bnt
if quant_axis == 1: if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast, scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis) (1, ) + compute_axis)
......
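The op tests above select the reference rounding through the round_type attribute: 0 corresponds to ties-to-even (np.round) followed by an explicit clip to [-bnt - 1, bnt], and 1 corresponds to ties-away-from-zero via the round_c helper; in the range_abs_max case, inference mode additionally clips the input to [-in_scale, in_scale] before scaling. A minimal NumPy sketch of the two modes, assuming round_c rounds halves away from zero, is:

    import numpy as np

    bit_length = 8
    bnt = (1 << (bit_length - 1)) - 1  # 127 for 8-bit quantization

    def round_ties_away_from_zero(x):
        # assumed behaviour of the round_c helper in the tests:
        # halves are rounded away from zero instead of to the nearest even value
        return np.sign(x) * np.floor(np.abs(x) + 0.5)

    x = np.array([0.5, 1.5, 2.5, -0.5, -2.5], dtype=np.float32)
    scale = np.max(np.abs(x))

    # round_type = 0: ties to even, then clip to the representable integer range
    q_even = np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)

    # round_type = 1: ties away from zero
    q_away = round_ties_away_from_zero(x / scale * bnt)

    print(q_even, q_away)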