Unverified commit 3f816bc8, authored by cc, committed by GitHub

[Quantization] Conv2d_transpose and mul support channelwise quantization (#25639)

* Conv2d_transpose and mul support channelwise quantization, test=develop
* Skip collecting out threshold for output tensor of which the type is not fp32 or fp64, test=develop
* Fix error in test_user_defined_quantization, test=develop
* Add depthwise_conv_bn_fuse, test=develop
* Add conv_transpose_bn_fuse_pass for post_training_quant, test=develop
Parent 2101dfd2
@@ -368,3 +368,7 @@ REGISTER_PASS(conv_transpose_bn_fuse_pass,
paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
@@ -56,6 +56,16 @@ class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
std::string conv_type() const { return "conv2d_transpose"; }
};
class DepthwiseConvBNFusePass : public ConvBNFusePass {
public:
std::string conv_type() const { return "depthwise_conv2d"; }
};
class DepthwiseConvEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public:
std::string conv_type() const { return "depthwise_conv2d"; }
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -37,11 +37,16 @@ template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
if (scale_num == 1) {
// Dequant op is before quantized op
// Dequantize the weight of quantized op
auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis];
const T* scale_factor = scales[0]->data<T>();
if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i];
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
@@ -50,7 +55,31 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range;
}
} else if (quant_axis == 1) {
int64_t out_iter = 1;
for (int i = 0; i < quant_axis; i++) {
out_iter *= in_dims[i];
}
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range;
++cur_in;
++cur_out;
}
}
}
}
} else if (scale_num == 2) {
// Dequant op is after quantized op
// Dequantize the output tensor of quantized op
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
@@ -157,6 +186,18 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
"Quantization bit numbers in quantization stage. "
"The size of `quant_bits` should be equal to the size of `Scales`.")
.SetDefault({8});
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
});
AddComment(R"DOC( AddComment(R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator. FakeChannelWiseDequantizeMaxAbsOp operator.
......
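For orientation, a minimal NumPy sketch of what the CPU ChannelDequantizeFunctor computes in the scale_num == 1 (weight) case; the function name and layout are illustrative, not part of the patch.

import numpy as np

def channel_dequantize(x, scales, max_range, quant_axis):
    # One scale per slice along quant_axis: out = in * scale / max_range.
    y = x.astype("float32").copy()
    if quant_axis == 0:
        for i in range(x.shape[0]):
            y[i] = x[i] * scales[i] / max_range
    else:  # quant_axis == 1
        for j in range(x.shape[1]):
            y[:, j] = x[:, j] * scales[j] / max_range
    return y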
@@ -45,8 +45,9 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
};
template <typename T>
__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale,
T max_range, int num, int channel,
T* out) {
int tid = threadIdx.x;
int channel_size = num / channel;
const T* in_c = in + blockIdx.x * channel_size;
@@ -56,6 +57,23 @@ __global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
}
}
template <typename T>
__global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
T max_range, const int num,
const int cin, const int cout,
T* out) {
int cout_wh_size = num / cin;
int wh_size = cout_wh_size / cout;
T s = scale[blockIdx.x];
const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size;
for (int i = 0; i < wh_size; i++) {
out_current[i] = in_current[i] * s / max_range;
}
}
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
@@ -74,18 +92,29 @@ template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, const int quant_axis,
framework::Tensor* out) {
auto in_dims = in->dims();
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) {
int num = in->numel();
const T* scale_factor = scales[0]->data<T>();
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
DequantizeOneScaleQuantAxis0<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
// Dequantize weight of Cin * Cout * W * H
int grid = in_dims[1];
int block = in_dims[0];
DequantizeOneScaleQuantAxis1<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, in_dims[0], in_dims[1],
out_data);
}
} else if (scale_num == 2) {
// No need to consider quant_axis
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
......
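The quant_axis == 1 launch above uses one CUDA block per output channel (grid = Cout) and one thread per input channel (block = Cin), which implicitly assumes Cin fits within the per-block thread limit. A hedged NumPy emulation of that (blockIdx, threadIdx) to offset mapping, purely for illustration:

import numpy as np

cin, cout, h, w = 4, 3, 2, 2
x = np.random.randn(cin, cout, h, w).astype("float32")
scale = np.abs(x).reshape(cin, cout, -1).max(axis=(0, 2))  # one scale per Cout
max_range = 127.0

flat_in, flat_out = x.ravel(), np.empty(x.size, dtype="float32")
cout_wh_size = x.size // cin        # elements covered by one threadIdx.x
wh_size = cout_wh_size // cout      # contiguous run per (block, thread) pair
for bid in range(cout):             # blockIdx.x walks the Cout channels
    for tid in range(cin):          # threadIdx.x walks the Cin slices
        start = tid * cout_wh_size + bid * wh_size
        flat_out[start:start + wh_size] = \
            flat_in[start:start + wh_size] * scale[bid] / max_range

assert np.allclose(flat_out.reshape(x.shape),
                   x * scale[None, :, None, None] / max_range)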
@@ -33,7 +33,7 @@ template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
@@ -63,6 +63,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::Tensor>("Out");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
auto quant_axis = ctx.Attr<int>("quant_axis");
int max_range = 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
@@ -70,12 +71,12 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
int scale_num = scales.size();
if (scale_num == 1) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[quant_axis],
platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with "
"quant_axis dimension value of Input(X) when the `Scales` has "
"only one element, but %ld != %ld here.",
scales[0]->numel(), in->dims()[quant_axis]));
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
@@ -94,7 +95,8 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
(std::pow(2, quant_bits[1] - 1) - 1);
}
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range),
quant_axis, out);
}
};
......
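The max_range handed to ChannelDequantizeFunctor folds the bit widths of the quantization stages together; a small sketch of that arithmetic with the default 8-bit setting (the helper name is illustrative):

def dequant_max_range(quant_bits):
    # one entry: weight dequantization; two entries: weight and activation
    r = 2 ** (quant_bits[0] - 1) - 1
    if len(quant_bits) == 2:
        r *= 2 ** (quant_bits[1] - 1) - 1
    return r

assert dequant_max_range([8]) == 127
assert dequant_max_range([8, 8]) == 127 * 127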
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
@@ -39,13 +40,41 @@ template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in_tensor, const int quant_axis,
T* out_abs_max) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* in_data = in_tensor.data<T>();
auto in_dims = in_tensor.dims();
const int64_t channel = in_dims[quant_axis];
if (quant_axis == 0) {
const int64_t channel_size = in_tensor.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
out_abs_max[i] =
std::abs(*(std::max_element(start, end, Compare<T>())));
}
} else if (quant_axis == 1) {
for (int64_t i = 0; i < channel; i++) {
out_abs_max[i] = 0;
}
const int64_t step_i = in_tensor.numel() / in_dims[0];
const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]);
for (int64_t i = 0; i < in_dims[0]; i++) {
for (int64_t j = 0; j < in_dims[1]; j++) {
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
T abs_max = std::abs(*(std::max_element(start, end, Compare<T>())));
out_abs_max[j] = std::max(out_abs_max[j], abs_max);
}
}
}
}
};
@@ -92,27 +121,54 @@ template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
if (quant_axis == 0) {
const int64_t channel_size = in.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
T inv_s = inverse(s);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
for (int i = 0; i < in_dims[0]; i++) {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
}
}
}
}
}
};
@@ -247,8 +303,9 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
"FakeChannelWiseQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
"FakeChannelWiseQuantizeAbsMax");
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
ctx->ShareLoD("X", /*->*/ "Out");
}
@@ -269,6 +326,18 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type.");
AddOutput("OutScale", "(Tensor) Current channel wise scale");
AddAttr<int>("quant_axis",
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int& quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
});
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int& bit_length) {
......
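Taken together, the two CPU functors above compute a per-channel abs-max scale and then clip-and-round against it. A hedged NumPy sketch of that pipeline (the helper names and sample weight shape are illustrative):

import numpy as np

def channel_wise_abs_max(x, quant_axis):
    # FindChannelAbsMaxFunctor: abs-max over every axis except quant_axis
    axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    return np.abs(x).max(axis=axes)

def channel_clip_and_fake_quant(x, scales, bin_cnt, quant_axis):
    # ChannelClipAndFakeQuantFunctor: clip to [-s, s], then round(x * bin_cnt / s)
    shape = [1] * x.ndim
    shape[quant_axis] = -1
    s = scales.reshape(shape)
    return np.round(np.clip(x, -s, s) * bin_cnt / s)

w = np.random.randn(4, 16, 3, 3).astype("float32")  # e.g. a Cin x Cout x H x W weight
scales = channel_wise_abs_max(w, quant_axis=1)
q = channel_clip_and_fake_quant(w, scales, bin_cnt=127, quant_axis=1)
assert np.abs(q).max() <= 127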
@@ -75,8 +75,8 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
const int c, T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
@@ -100,14 +100,69 @@ __global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
}
}
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
const int cin, const int cout,
T* out) {
extern __shared__ T shared_max_data[];
int cout_wh_size = n / cin;
int wh_size = n / (cin * cout);
int tid = threadIdx.x;
int bid = blockIdx.x;
const T* in_current = in + tid * cout_wh_size + bid * wh_size;
shared_max_data[tid] = T(0);
for (int i = 0; i < wh_size; i++) {
T tmp = fabs(in_current[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
}
}
__syncthreads();
int len = blockDim.x;
for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) {
if (tid < i && tid + i < len &&
shared_max_data[tid] < shared_max_data[tid + i]) {
shared_max_data[tid] = shared_max_data[tid + i];
}
if (i == 1) {
i = 0; // break the loop
}
__syncthreads();
}
if (tid == 0) {
out[bid] = shared_max_data[0];
}
}
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in_tensor, const int quant_axis,
T* out_abs_max) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
const int num = in_tensor.numel();
auto in_dims = in_tensor.dims();
int channel = in_dims[quant_axis];
const T* in_data = in_tensor.data<T>();
if (quant_axis == 0) {
int grid = channel;
int block = 1024;
FindChannelAbsMaxKernelQuantAxis0<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, channel, out_abs_max);
} else if (quant_axis == 1) {
int grid = in_dims[1];
int block = in_dims[0];
FindChannelAbsMaxKernelQuantAxis1<
T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, in_dims[0], in_dims[1], out_abs_max);
}
}
};
@@ -189,10 +244,12 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
float>;
// ChannelClipAndQuantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const int bin_cnt,
const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
@@ -211,22 +268,57 @@ __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
}
}
// ChannelClipAndQuantKernel for quant_axis is 1
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale,
const int bin_cnt,
const int n, const int cin,
const int cout, T* out) {
T s = scale[blockIdx.x % cout];
T inv_s = inverse(s);
int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt * inv_s * v;
out_c[i] = round(v);
}
}
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
int num = in.numel();
auto in_dims = in.dims();
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) {
int grid = in_dims[0];
int block = 1024;
ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
} else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1];
int block = 1024;
ChannelClipAndQuantKernelQuantAxis1<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
}
}
};
......
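FindChannelAbsMaxKernelQuantAxis1 reduces the per-thread partial maxima with a shared-memory halving loop that also copes with block sizes that are not powers of two. A hedged Python emulation of just that loop, for illustration only:

def block_max(vals):
    # Mirrors the `for (i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2)` reduction.
    vals = list(vals)
    length = len(vals)
    i = (length + 1) // 2
    while i > 0:
        for tid in range(i):
            if tid + i < length and vals[tid] < vals[tid + i]:
                vals[tid] = vals[tid + i]
        if i == 1:
            break
        length, i = i, (i + 1) // 2
    return vals[0]

assert block_max([3.0, 7.0, 1.0, 9.0, 4.0]) == 9.0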
@@ -61,15 +61,15 @@ struct FindRangeAbsMaxFunctor {
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor,
const int quant_axis, T* out_abs_max);
};
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int quant_axis, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
@@ -144,12 +144,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis");
auto& dev_ctx = context.template device_context<DeviceContext>();
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
}
};
......
@@ -29,6 +29,7 @@ from .quantization_pass import _out_scale_op_list
from .quantization_pass import _get_op_input_var_names
from .quantization_pass import _get_op_output_var_names
from .quantization_pass import _get_output_name_index
from .quantization_pass import _channelwise_quant_axis1_ops
__all__ = ['PostTrainingQuantization', 'WeightQuantization']
@@ -316,6 +317,7 @@ class PostTrainingQuantization(object):
self._out_scale_op_list = _out_scale_op_list
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self.weight_op_pairs = {}
self._sampling_data = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_min = {}
@@ -436,6 +438,8 @@ class PostTrainingQuantization(object):
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
graph = _remove_ctrl_vars(graph)
graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass')
graph = _apply_pass(self._scope, graph, 'depthwise_conv_bn_fuse_pass')
graph = _apply_pass(self._scope, graph, 'conv_transpose_bn_fuse_pass')
self._program = graph.to_program()
def _collect_target_varnames(self):
@@ -446,10 +450,11 @@ class PostTrainingQuantization(object):
# TODO(juncaipeng), consider the name_scope of skip_quant
_logger.info("Collect quantized variable names ...")
def collect_var_name(var_name_list, persistable_var_names, op_type):
for var_name in var_name_list:
if var_name in persistable_var_names:
self._quantized_weight_var_name.add(var_name)
self.weight_op_pairs[var_name] = op_type
else:
self._quantized_act_var_name.add(var_name)
@@ -462,13 +467,15 @@ class PostTrainingQuantization(object):
# For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type:
collect_var_name(
_get_op_input_var_names(op), persistable_var_names, op_type)
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
# For other op, only sample output scale
elif op_type in self._out_scale_op_list:
collect_var_name(
_get_op_output_var_names(op), persistable_var_names,
op_type)
def _set_activation_persistable(self):
'''
@@ -492,35 +499,65 @@ class PostTrainingQuantization(object):
Sample the input threshold (min, max, or abs_max) in every iteration.
'''
assert self._algo in ["abs_max", "min_max"], \
"The algo should be abs_max or min_max for _sample_threshold."
if self._algo == "abs_max":
self._sample_threshold_abs_max()
elif self._algo == "min_max":
self._sample_threshold_min_max()
def _sample_threshold_abs_max(self):
assert self._algo == "abs_max", \
"The algo should be abs_max for _sample_threshold_abs_max."
# Only calculate abs_max value for weight for once
if self._quantized_var_abs_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
abs_max_value = []
if self.weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[:, i]))))
else:
for i in range(var_tensor.shape[0]):
abs_max_value.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_var_abs_max[var_name] = abs_max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_var_abs_max) or \
(abs_max_value > self._quantized_var_abs_max[var_name]):
self._quantized_var_abs_max[var_name] = abs_max_value
elif self._algo == "min_max":
def _sample_threshold_min_max(self):
assert self._algo == "min_max", \
"The algo should be min_max for _sample_threshold_min_max."
if self._quantized_var_min == {} and self._quantized_var_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
if self._weight_quantize_type == "abs_max":
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
elif self._weight_quantize_type == "channel_wise_abs_max":
min_value = []
max_value = []
if self.weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
min_value.append(float(np.min(var_tensor[:, i])))
max_value.append(float(np.max(var_tensor[:, i])))
else:
for i in range(var_tensor.shape[0]):
min_value.append(float(np.min(var_tensor[i])))
max_value.append(float(np.max(var_tensor[i])))
self._quantized_var_min[var_name] = min_value
self._quantized_var_max[var_name] = max_value
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
min_value = float(np.min(var_tensor))
@@ -554,11 +591,6 @@ class PostTrainingQuantization(object):
applied in every iteration.
'''
assert self._algo == "KL", "The algo should be KL to sample data."
for var_name in self._quantized_weight_var_name:
if var_name not in self._sampling_data:
var_tensor = _load_variable_data(self._scope, var_name)
self._sampling_data[var_name] = var_tensor
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
@@ -584,15 +616,20 @@ class PostTrainingQuantization(object):
# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
weight_data = _load_variable_data(self._scope, var_name)
weight_threshold = None
if self._weight_quantize_type == "abs_max":
weight_threshold = float(np.max(np.abs(weight_data)))
elif self._weight_quantize_type == "channel_wise_abs_max":
weight_threshold = []
if self.weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(weight_data.shape[1]):
weight_threshold.append(
float(np.max(np.abs(weight_data[:, i]))))
else:
for i in range(weight_data.shape[0]):
weight_threshold.append(
float(np.max(np.abs(weight_data[i]))))
self._quantized_var_kl_threshold[var_name] = weight_threshold
# KL threshold for activations
......
@@ -111,6 +111,10 @@ _op_real_in_out_name = {
"scale": [["X"], ["Out"]],
}
_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
_channelwise_quant_axis1_ops = ['conv2d_transpose', 'mul']
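The axis choice follows the weight layout of each op: conv2d and depthwise_conv2d keep output channels on axis 0 of the filter, while conv2d_transpose ([Cin, Cout, H, W]) and mul ([in_features, out_features]) keep them on axis 1. A trivial sketch of that mapping (the helper name is illustrative):

def choose_quant_axis(op_type):
    # axis that indexes output channels in the op's weight tensor
    return 1 if op_type in _channelwise_quant_axis1_ops else 0

assert choose_quant_axis('conv2d') == 0
assert choose_quant_axis('conv2d_transpose') == 1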
def _get_op_input_var_names(op):
""" """
@@ -185,10 +189,24 @@ def _is_input_all_not_persistable(graph, op_node):
return is_input_all_not_persistable
def _check_grandchild_op_node(op_node, grandchild_op_name):
'''
Check whether the fake_quant node has a grandchild op node named
grandchild_op_name.
'''
for out1_var_node in op_node.outputs:
for out1_op_node in out1_var_node.outputs:
for out2_var_node in out1_op_node.outputs:
for out2_op_node in out2_var_node.outputs:
if out2_op_node.name() == grandchild_op_name:
return True
return False
class QuantizationTransformPass(object):
"""
Quantize the ops that have weights. Add quant and dequant ops for
the quantized ops' inputs.
"""
_supported_quantizable_op_type = [
'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul'
@@ -311,8 +329,8 @@ class QuantizationTransformPass(object):
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' "
"or 'moving_average_abs_max'." % (str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
@@ -323,7 +341,6 @@ class QuantizationTransformPass(object):
for op in self._quantizable_ops:
assert op in QuantizationTransformPass._supported_quantizable_op_type, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
@@ -356,10 +373,12 @@ class QuantizationTransformPass(object):
user_skipped = False
if isinstance(self._skip_pattern, list):
user_skipped = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") \
for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
user_skipped = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(
self._skip_pattern) != -1
if user_skipped:
op_node.op()._set_attr("skip_quant", True)
@@ -373,15 +392,11 @@ class QuantizationTransformPass(object):
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
name = var_node.name()
if name in processed_vars:
continue
is_weight = True if var_node.name() in persistable_vars \
else False
is_weight = True
else:
is_weight = False
# if var node is weight and weight_preprocess_func is not None,
# will insert weight preprocess func
@@ -415,20 +430,14 @@ class QuantizationTransformPass(object):
else self._activation_bits
quant_type = self._weight_quantize_type if is_weight \
else self._activation_quantize_type
if quant_type == 'channel_wise_abs_max':  # Weight quantization
quant_axis = 1 if op.name() in \
_channelwise_quant_axis1_ops else 0
quant_var_node, scale_var_node = self._insert_channel_quant_op(
graph, var_node, name, quant_bits, quant_axis)
dequant_var_node = self._insert_channel_dequant_op(
graph, quant_var_node, [scale_var_node],
[quant_bits], quant_axis)
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, name, quant_bits, 'abs_max')
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node,
quant_bits)
else:
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, name, quant_bits, quant_type)
@@ -529,11 +538,19 @@ class QuantizationTransformPass(object):
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_var_node = graph.create_persistable_node(
name=self._quantized_scale_name(name),
var_type=var_node.type(),
shape=[1],
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(
scale_var_node,
np.zeros(
scale_var_node.shape(), dtype=data_type),
self._scope,
self._place)
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={
@@ -706,7 +723,8 @@ class QuantizationTransformPass(object):
return quant_var_node, scale_out_node
def _insert_channel_quant_op(self, graph, var_node, name, quant_bits,
quant_axis):
"""
Insert fake_channel_wise_quantize_abs_max op in the graph.
"""
@@ -717,15 +735,24 @@ class QuantizationTransformPass(object):
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_var_node = graph.create_persistable_node(
name=self._quantized_scale_name(name),
var_type=var_node.type(),
shape=[var_node.shape()[quant_axis]],
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(
scale_var_node,
np.zeros(
scale_var_node.shape(), dtype=data_type),
self._scope,
self._place)
quant_op_node = graph.create_op_node(
op_type='fake_channel_wise_quantize_abs_max',
attrs={
'bit_length': quant_bits,
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node},
@@ -763,7 +790,7 @@ class QuantizationTransformPass(object):
return dequant_var_node
def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
quant_bits, quant_axis):
"""
Insert fake_channel_wise_dequantize_max_abs in the graph.
"""
@@ -778,6 +805,7 @@ class QuantizationTransformPass(object):
op_type='fake_channel_wise_dequantize_max_abs',
attrs={
'quant_bits': quant_bits,
'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node,
@@ -1036,7 +1064,6 @@ class QuantizationFreezePass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose']
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
self._op_input_rename_map = collections.OrderedDict()
@@ -1063,34 +1090,37 @@ class QuantizationFreezePass(object):
if input_arg_name in graph.out_node_mapping_table.keys():
input_arg_name = graph.out_node_mapping_table[
input_arg_name]
if input_arg_name not in persistable_vars:
scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0])
self._quant_var_scale_map[input_arg_name] = scale_v
elif self._weight_quantize_type == 'channel_wise_abs_max':
param = self._load_var(input_arg_name)
if len(param.shape) == 4: # conv2d or depthwise_conv2d
scale_v = []
for i in range(param.shape[0]):
scale_v.append(np.max(np.abs(param[i])))
else:
# Obtain scale from OutScale var node
scale_v = self._load_var(op_node.output('OutScale')[0])
assert scale_v.ndim in [
1, 2
], "the dim of scale_v should be 1 or 2"
if scale_v.ndim == 2:
scale_v = scale_v[0]
if scale_v.size == 1:
scale_v = scale_v[0]
else:
scale_v = scale_v.tolist()
self._quant_var_scale_map[input_arg_name] = scale_v
# Quantize weight and restore
param_v = self._load_var(input_arg_name)
if isinstance(scale_v, list) and \
any(_check_grandchild_op_node(op_node, op)
for op in _channelwise_quant_axis1_ops):
quant_axis = 1
else:
quant_axis = 0
quantized_param_v = self._quant(
param_v, scale_v, self._weight_bits, quant_axis)
self._restore_var(input_arg_name, quantized_param_v)
self._remove_fake_quant_and_dequant_op(graph, op_node)
# Remove all fake dequant op
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
@@ -1103,8 +1133,7 @@ class QuantizationFreezePass(object):
op_node_desc = op_node.op()
if op_node_desc.has_attr("quantization_type") and \
op_node_desc.attr("quantization_type") == "qat_with_weight":
if self._weight_quantize_type == 'channel_wise_abs_max':
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
@@ -1295,10 +1324,15 @@ class QuantizationFreezePass(object):
return isinstance(v, float) or isinstance(v, np.float32) \
or isinstance(v, np.float64)
def _quant(self, x, scale, num_bits, quant_axis):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
if isinstance(scale, list):
for i, s in enumerate(scale):
if quant_axis == 0:
x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
else:
x[:, i] = np.round(x[:, i] / s * (
(1 << (num_bits - 1)) - 1))
return x
else:
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
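A hedged round-trip check of the per-channel weight quantization performed by _quant above: quantizing along quant_axis and rescaling with the same per-channel scales should recover the weight up to rounding error (the array shape here is illustrative):

import numpy as np

w = np.random.randn(3, 5).astype("float32")                # e.g. a mul weight [in, out]
scales = [float(np.max(np.abs(w[:, i]))) for i in range(w.shape[1])]
q = w.copy()
for i, s in enumerate(scales):                             # quant_axis == 1 branch
    q[:, i] = np.round(q[:, i] / s * ((1 << 7) - 1))
w_back = q * np.array(scales, dtype="float32") / 127.0     # per-column dequantize
assert np.allclose(w, w_back, atol=float(np.max(np.abs(w))) / 127.0)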
@@ -1468,6 +1502,10 @@ class OutScaleForTrainingPass(object):
for op in target_ops:
for output_var_name in _get_op_output_var_names(op):
in_node = graph._find_node_by_name(op.outputs, output_var_name)
if in_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
@@ -1570,17 +1608,26 @@ class OutScaleForInferencePass(object):
if op_node.name() in self._teller_set:
var_names = _get_op_output_var_names(op_node)
for var_name in var_names:
in_node = graph._find_node_by_name(op_node.outputs,
var_name)
if in_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
scale_name = self._scale_name(var_name)
scale_var = self._scope.find_var(scale_name)
assert scale_var is not None, \
"Can not find {} variable in the scope".format(scale_name)
scale_value = np.array(scale_var.get_tensor())[0]
# For compatibility, we save output threshold by two methods.
op_node.op()._set_attr("out_threshold", float(scale_value))
argname_index = _get_output_name_index(op_node, var_name)
assert argname_index is not None, \
var_name + " is not the output of the op"
op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \
+ "_threshold", float(scale_value))
graph.resolve_hazard()
return graph
......
@@ -33,34 +33,29 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu")
hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
def pact(x, name=None):
@@ -102,7 +97,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
img.stop_gradient = False
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.SGD(learning_rate=0.0001)
opt.minimize(loss)
......
@@ -31,45 +31,45 @@ def dequantize_max_abs(x, scale, max_range):
     return y
 
-def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False):
+def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
+    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
     scales = []
-    if not use_second_dim:
-        for i in range(x.shape[0]):
-            scales.append(np.max(np.abs(x[i])).astype("float32"))
-        y = x.copy()
-        max_range = math.pow(2, quant_bit - 1) - 1
-        for i, scale in enumerate(scales):
-            y[i] = np.round(x[i] / scale * max_range)
-    else:
-        for i in range(x.shape[0]):
-            s = []
-            for j in range(x.shape[1]):
-                s.append(np.max(np.abs(x[i][j])).astype("float32"))
-            scales.append(s)
-        scales = np.amax(np.array(scales), axis=0)
-        y = x.copy()
-        max_range = math.pow(2, quant_bit - 1) - 1
-        for i in range(x.shape[0]):
-            for j, scale in enumerate(scales):
-                y[i][j] = np.round(x[i][j] / scale * max_range)
+    y = x.copy()
+    max_range = math.pow(2, quant_bit - 1) - 1
+    if quant_axis == 0:
+        for i in range(x.shape[0]):
+            scale = np.max(np.abs(x[i])).astype("float32")
+            scales.append(scale)
+            y[i] = np.round(x[i] * max_range / scale)
+    elif quant_axis == 1:
+        for i in range(x.shape[1]):
+            scale = np.max(np.abs(x[:, i])).astype("float32")
+            scales.append(scale)
+            y[:, i] = np.round(x[:, i] * max_range / scale)
     return y, scales
 
 def channel_wise_dequantize_max_abs(x,
                                     scales,
                                     quant_bits,
+                                    quant_axis,
                                     activation_scale=None):
-    if activation_scale is None:
-        y = x.copy()
-        for i in range(x.shape[0]):
-            y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i]
-    else:
-        y = x.copy()
-        for i in range(x.shape[0]):
-            for j in range(x.shape[1]):
-                y[i][j] = (scales[j] /
-                           (math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j]
-        y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
+    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
+
+    if isinstance(quant_bits, list):
+        max_range = math.pow(2, quant_bits[0] - 1) - 1
+    else:
+        max_range = math.pow(2, quant_bits - 1) - 1
+    y = x.copy()
+    if quant_axis == 0:
+        for i in range(x.shape[0]):
+            y[i] = x[i] * scales[i] / max_range
+    elif quant_axis == 1:
+        for i in range(x.shape[1]):
+            y[:, i] = x[:, i] * scales[i] / max_range
+    if activation_scale is not None:
+        y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
     return y
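As a sanity check on the two reference helpers above, the following self-contained NumPy sketch applies the same per-channel abs-max scheme in vectorized form and verifies that an int8 round trip reproduces the weight within one quantization step. `quant_axis=0` corresponds to conv2d/depthwise_conv2d weights (output channel first); `quant_axis=1` corresponds to conv2d_transpose and mul weights, whose output channels sit on the second dimension. The function name and shapes are illustrative only, not part of the test file.

```python
import numpy as np

def quant_dequant_channel_wise(x, quant_axis=0, bits=8):
    max_range = (1 << (bits - 1)) - 1
    reduce_axes = tuple(a for a in range(x.ndim) if a != quant_axis)
    scales = np.abs(x).max(axis=reduce_axes, keepdims=True)  # one scale per channel
    q = np.round(x / scales * max_range)                     # fake-quantize
    return q * scales / max_range, scales.reshape(-1)        # dequantize back

if __name__ == "__main__":
    np.random.seed(0)
    w = np.random.randn(8, 4, 3, 3).astype("float32")
    for axis in (0, 1):
        w_dq, scales = quant_dequant_channel_wise(w, quant_axis=axis)
        # int8 round-trip error is bounded by half a quantization step per channel
        step = scales.max() / ((1 << 7) - 1)
        assert np.abs(w - w_dq).max() <= 0.5 * step + 1e-7
        print("quant_axis", axis, "num scales", scales.size,
              "max abs error", float(np.abs(w - w_dq).max()))
```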
@@ -83,9 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        yq, scales = channel_wise_quantize_max_abs(
-            x, self.quant_bits[0], use_second_dim=True)
-        ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
+        yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1)
+        ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1,
                                               self.activation_scale)
 
         self.inputs = {
@@ -105,25 +104,39 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
     def set_args(self):
         self.quant_bits = [8]
         self.data_type = "float32"
+        self.quant_axis = 0
 
     def setUp(self):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0])
-        ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits)
+        yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0],
+                                                   self.quant_axis)
+        ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
+                                              self.quant_axis)
 
         self.inputs = {
             'X': yq,
             'Scales': [("scales0", np.array(scales).astype(self.data_type))]
         }
-        self.attrs = {'quant_bits': self.quant_bits}
+        self.attrs = {
+            'quant_bits': self.quant_bits,
+            'quant_axis': self.quant_axis
+        }
         self.outputs = {'Out': ydq}
 
     def test_check_output(self):
         self.check_output()
 
+
+class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1(
+        TestFakeChannelWiseDequantizeMaxAbsOpOneScale):
+    def set_args(self):
+        self.quant_bits = [8]
+        self.data_type = "float32"
+        self.quant_axis = 1
+
+
 class TestFakeDequantizeMaxAbsOp(OpTest):
     def set_args(self):
         self.num_bits = 8
......
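The two-scale test above exercises the case where both the channel-wise weight scales and a single whole-tensor activation scale are folded back during dequantization. A hedged, NumPy-only illustration of that path (the helper, its name, and the shapes are illustrative, not Paddle APIs):

```python
import math
import numpy as np

def dequantize_two_scales(xq, channel_scales, activation_scale,
                          quant_bits=(8, 8), quant_axis=1):
    weight_range = math.pow(2, quant_bits[0] - 1) - 1
    act_range = math.pow(2, quant_bits[1] - 1) - 1
    shape = [1] * xq.ndim
    shape[quant_axis] = -1                        # broadcast scales along quant_axis
    s = np.asarray(channel_scales, dtype="float32").reshape(shape)
    y = xq.astype("float32") * s / weight_range   # per-channel (weight) rescale
    return y * activation_scale / act_range       # whole-tensor (activation) rescale

if __name__ == "__main__":
    xq = np.random.randint(-127, 128, size=(4, 3, 64, 64)).astype("float32")
    scales = np.array([0.5, 1.0, 2.0], dtype="float32")  # one scale per channel on axis 1
    out = dequantize_two_scales(xq, scales, activation_scale=7.861)
    print(out.shape, out.dtype)
```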
@@ -72,28 +72,62 @@ class TestFakeQuantizeOp2(OpTest):
 
 class TestFakeChannelWiseQuantizeOp(OpTest):
     def setUp(self):
+        self.set_arg()
+        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
         self.op_type = "fake_channel_wise_quantize_abs_max"
-        self.attrs = {'bit_length': 8}
-        self.inputs = {
-            'X': np.random.random((4, 3, 64, 64)).astype("float32"),
-        }
+        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}
         scales = []
-        for i in range(self.inputs['X'].shape[0]):
-            scales.append(np.max(np.abs(self.inputs['X'][i])).astype("float32"))
         outputs = self.inputs['X'].copy()
-        for i, scale in enumerate(scales):
-            outputs[i] = np.round(outputs[i] / scale * (
-                (1 << (self.attrs['bit_length'] - 1)) - 1))
+        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
+        if self.quant_axis == 0:
+            for i in range(self.inputs['X'].shape[0]):
+                scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
+                scales.append(scale_v)
+                outputs[i] = np.round(outputs[i] / scale_v * bnt)
+        elif self.quant_axis == 1:
+            for i in range(self.inputs['X'].shape[1]):
+                scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
+                    "float32")
+                scales.append(scale_v)
+                outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt)
 
         self.outputs = {
             'Out': outputs,
             'OutScale': np.array(scales).astype("float32"),
         }
 
+    def set_arg(self):
+        self.quant_axis = 0
+        self.inputs = {
+            'X': np.random.random((20, 15, 6, 6)).astype("float32"),
+        }
+
     def test_check_output(self):
         self.check_output()
+
+
+class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
+    def set_arg(self):
+        self.quant_axis = 1
+        self.inputs = {
+            'X': np.random.random((15, 20, 5, 5)).astype("float32"),
+        }
+
+
+class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
+    def set_arg(self):
+        self.quant_axis = 0
+        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
+
+
+class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
+    def set_arg(self):
+        self.quant_axis = 1
+        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
+
 class TestFakeQuantizeRangeAbsMaxOp(OpTest):
     def setUp(self):
         self.op_type = "fake_quantize_range_abs_max"
......
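The new 2-D cases (shape `(30, 15)`) mirror the weights of a quantized `mul` op, where the per-channel scale runs along `quant_axis=1`. A small stand-alone NumPy sketch of the expected-output computation for that case (the function name is illustrative, not part of the test):

```python
import numpy as np

def fake_channel_wise_quantize_abs_max_2d(w, bit_length=8):
    bnt = (1 << (bit_length - 1)) - 1
    out = w.copy()
    scales = []
    for i in range(w.shape[1]):                  # quant_axis == 1: per output column
        scale_v = np.max(np.abs(w[:, i])).astype("float32")
        scales.append(scale_v)
        out[:, i] = np.round(w[:, i] / scale_v * bnt)
    return out, np.array(scales, dtype="float32")

if __name__ == "__main__":
    w = np.random.random((30, 15)).astype("float32")
    wq, out_scales = fake_channel_wise_quantize_abs_max_2d(w)
    print(wq.shape, out_scales.shape)            # (30, 15) (15,)
```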