提交 8965819f 编写于 作者: Z Zhen Wang

rewrite the cuda kernels of channel_wise_quant_op and channe_wise_dequant_op. test=develop

上级 ec88b6cc
......@@ -33,8 +33,51 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
}
};
template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, framework::Tensor* out) {
if (scale_num == 1) {
const int channel = in->dims()[0];
const T* scale_factor = scales[0]->data<T>();
for (int i = 0; i < channel; i++) {
T s = scale_factor[i];
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (s / max_range) * in_e;
}
} else if (scale_num == 2) {
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
for (int i = 0; i < batch_size; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int j = 0; j < channel; j++) {
T s = scale_one[j];
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (s * scale_two[0] / max_range) * in_e;
}
}
}
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CPUDeviceContext, double>;
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
......
......@@ -44,8 +44,66 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
}
};
template <typename T>
__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range,
int num, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale[blockIdx.x] / max_range;
}
}
template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one,
const T* scale_two, T max_range, int num,
int batch_size, int channel, T* out) {
int tid = threadIdx.x;
int channel_size = num / (batch_size * channel);
int scale_index = blockIdx.x % channel;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
for (int i = tid; i < channel_size; i += blockDim.x) {
out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
}
}
template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor** scales,
const int scale_num, T max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
if (scale_num == 1) {
int num = in->numel();
int channel = in->dims()[0];
const T* scale_factor = scales[0]->data<T>();
int block = 1024;
int grid = channel;
DequantizeOneScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, channel, out_data);
} else if (scale_num == 2) {
int num = in->numel();
int batch_size = in->dims()[0];
int channel = in->dims()[1];
const T* scale_one = scales[0]->data<T>();
const T* scale_two = scales[1]->data<T>();
int block = 1024;
int grid = batch_size * channel;
DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_one, scale_two, max_range, num, batch_size, channel,
out_data);
}
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;
} // namespace operators
} // namespace paddle
......
......@@ -29,6 +29,13 @@ struct DequantizeFunctor {
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
......@@ -56,50 +63,32 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto* out = ctx.Output<framework::Tensor>("Out");
auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
int max_range = std::pow(2, quant_bits[0] - 1) - 1;
int max_range = 1;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace());
auto dequant = DequantizeFunctor<DeviceContext, T>();
if (scales.size() == 1) {
int scale_num = scales.size();
if (scale_num == 1) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only one "
"element.");
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(max_range), &one_channel_out);
}
} else if (scales.size() == 2) {
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements.");
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
framework::slice_ddim(in->dims(), 1, in->dims().size()));
framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
framework::slice_ddim(out->dims(), 1, out->dims().size()));
for (int64_t j = 0; j < in->dims()[1]; j++) {
framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
dequant(dev_ctx, &one_channel_in, &one_channel_scale,
static_cast<T>(max_range), &one_channel_out);
}
}
PADDLE_ENFORCE_EQ(
scales[1]->numel(), 1,
"The second scale tensor should only have one value at now.");
max_range = std::pow(2, quant_bits[1] - 1) - 1;
dequant(dev_ctx, out, scales[1], static_cast<T>(max_range), out);
max_range *= (std::pow(2, quant_bits[0] - 1) - 1) *
(std::pow(2, quant_bits[1] - 1) - 1);
}
ChannelDequantizeFunctor<DeviceContext, T>()(
dev_ctx, in, scales.data(), scale_num, static_cast<T>(max_range), out);
}
};
......
......@@ -37,6 +37,21 @@ struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* in,
const int num, const int channel, T* out) {
const int channel_size = num / channel;
for (int i = 0; i < channel; i++) {
auto* start = in + i * channel_size;
auto* end = in + (i + 1) * channel_size;
out[i] = std::abs(*(std::max_element(start, end, Compare<T>())));
}
}
};
template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
......@@ -53,6 +68,36 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int channel,
framework::Tensor* out) {
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
const int channel_size = in.numel() / channel;
platform::Transform<platform::CPUDeviceContext> trans;
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
}
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
......
......@@ -74,6 +74,45 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
extern __shared__ T shared_max_data[];
shared_max_data[tid] = T(0);
for (int i = tid; i < channel_size; i += blockDim.x) {
T tmp = fabs(in_c[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
shared_max_data[tid] = shared_max_data[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[blockIdx.x] = shared_max_data[0];
}
}
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* in,
const int num, const int channel, T* out) {
int block = 1024;
int grid = channel;
FindChannelAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
in, num, channel, out);
}
};
template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n, T* out) {
......@@ -82,14 +121,76 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
T s = scale[0];
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[bid];
T x = in[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out[bid] = round(v);
out[i] = round(v);
}
}
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
template <typename T>
__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
const int bin_cnt, const int n,
const int c, T* out) {
int tid = threadIdx.x;
int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt / s * v;
out_c[i] = round(v);
}
}
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int channel,
framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = channel;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ChannelClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, channel, out_data);
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
float>;
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
const T* last_scale,
......@@ -182,26 +283,6 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
int num = in.numel();
int block = 1024;
int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>();
const T* scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data);
}
};
template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
} // namespace operators
} // namespace paddle
......
......@@ -42,6 +42,19 @@ struct FindRangeAbsMaxFunctor {
framework::Tensor* scales_arr, framework::Tensor* out_scale);
};
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num,
const int channel, T* out);
};
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int channel, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum,
......@@ -86,21 +99,10 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>();
auto find_abs_max = FindAbsMaxFunctor<DeviceContext, T>();
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel = in->Slice(i, i + 1);
const T* one_channel_data = one_channel.data<T>();
find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
&out_scale_data[i]);
}
auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
for (int64_t i = 0; i < in->dims()[0]; i++) {
framework::Tensor one_channel_in = in->Slice(i, i + 1);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
framework::Tensor one_channel_scale = out_scale->Slice(i, i + 1);
clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
&one_channel_out);
}
FindChannelAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, in->data<T>(), in->numel(), in->dims()[0], out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out);
}
};
......
......@@ -576,8 +576,6 @@ class QuantizationFreezePass(object):
elif self._weight_quantize_type == 'channel_wise_abs_max':
param = self._load_var(input_arg_name)
if len(param.shape) == 4: # conv2d or depthwise_conv2d
print('DEBUG**************************: %s' %
input_arg_name)
scale_v = []
for i in range(param.shape[0]):
scale_v.append(np.max(np.abs(param[i])))
......
......@@ -127,7 +127,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type, for_ci=False):
def linear_fc_quant(self, activation_quant_type, for_ci=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -140,14 +140,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
place=place,
activation_quantize_type=quant_type)
activation_quantize_type=activation_quant_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
graph.draw('.', 'quantize_fc_' + activation_quant_type,
marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
......@@ -156,7 +157,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
val_graph.draw('.', 'val_fc_' + activation_quant_type,
val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.linear_fc_quant('abs_max', for_ci=True)
......@@ -167,7 +169,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def test_linear_fc_quant_moving_average_abs_max(self):
self.linear_fc_quant('moving_average_abs_max', for_ci=True)
def residual_block_quant(self, quant_type, for_ci=False):
def residual_block_quant(self, activation_quant_type, for_ci=False):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -180,14 +182,15 @@ class TestQuantizationTransformPass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
place=place,
activation_quantize_type=quant_type)
activation_quantize_type=activation_quant_type)
transform_pass.apply(graph)
if not for_ci:
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
graph.draw('.', 'quantize_residual_' + activation_quant_type,
marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
......@@ -196,7 +199,8 @@ class TestQuantizationTransformPass(unittest.TestCase):
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
val_graph.draw('.', 'val_residual_' + activation_quant_type,
val_marked_nodes)
def test_residual_block_abs_max(self):
self.residual_block_quant('abs_max', for_ci=True)
......@@ -209,7 +213,12 @@ class TestQuantizationTransformPass(unittest.TestCase):
class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type, for_ci=False):
def freeze_graph(self,
use_cuda,
seed,
activation_quant_type,
weight_quant_type='abs_max',
for_ci=False):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
......@@ -245,10 +254,10 @@ class TestQuantizationFreezePass(unittest.TestCase):
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
activation_quantize_type=quant_type,
weight_quantize_type='channel_wise_abs_max')
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=quant_type)
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
......@@ -257,12 +266,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
main_graph.draw('.', 'main' + dev_name + activation_quant_type + '_'
+ weight_quant_type, marked_nodes)
marked_nodes = set()
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
test_graph.draw('.', 'test' + dev_name + activation_quant_type + '_'
+ weight_quant_type, marked_nodes)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
......@@ -287,8 +298,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(data),
fetch_list=[loss])
if not for_ci:
print('{}: {}'.format('loss' + dev_name + quant_type,
loss_v))
print('{}: {}'.format('loss' + dev_name +
activation_quant_type + '_' +
weight_quant_type, loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
......@@ -302,9 +314,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type='channel_wise_abs_max')
scope=scope, place=place, weight_quantize_type=weight_quant_type)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
if not for_ci:
......@@ -312,7 +322,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
test_graph.draw('.', 'test_freeze' + dev_name +
activation_quant_type + '_' + weight_quant_type,
marked_nodes)
server_program = test_graph.to_program()
......@@ -322,18 +333,20 @@ class TestQuantizationFreezePass(unittest.TestCase):
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
if not for_ci:
print('{}: {}'.format('test_loss1' + dev_name + quant_type,
test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type,
test_loss2))
print(
'{}: {}'.format('test_loss1' + dev_name + activation_quant_type
+ '_' + weight_quant_type, test_loss1))
print(
'{}: {}'.format('test_loss2' + dev_name + activation_quant_type
+ '_' + weight_quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
if not for_ci:
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type,
np.sum(w_quant)))
print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
+ '_' + weight_quant_type, np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + activation_quant_type +
'_' + weight_quant_type, np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
......@@ -343,26 +356,28 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type,
marked_nodes)
test_graph.draw('.', 'test_int8' + dev_name + activation_quant_type
+ '_' + weight_quant_type, marked_nodes)
server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file.
with fluid.scope_guard(scope):
fluid.io.save_inference_model('server_int8' + dev_name + quant_type,
['image', 'label'], [loss], exe,
server_program_int8)
fluid.io.save_inference_model(
'server_int8' + dev_name + activation_quant_type + '_' +
weight_quant_type, ['image', 'label'], [loss], exe,
server_program_int8)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model(
'server_int8' + dev_name + quant_type, exe)
'server_int8' + dev_name + activation_quant_type + '_' +
weight_quant_type, exe)
# Check the loaded 8-bit weight.
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
if not for_ci:
print('{}: {}'.format('w_8bit' + dev_name + quant_type,
np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
print('{}: {}'.format('w_8bit' + dev_name + activation_quant_type +
'_' + weight_quant_type, np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + activation_quant_type
+ '_' + weight_quant_type, np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
......@@ -371,45 +386,103 @@ class TestQuantizationFreezePass(unittest.TestCase):
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
test_graph.draw('.', 'test_mobile' + dev_name +
activation_quant_type + '_' + weight_quant_type,
marked_nodes)
mobile_program = test_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type,
['image', 'label'], [loss], exe,
mobile_program)
fluid.io.save_inference_model(
'mobile_int8' + dev_name + activation_quant_type + '_' +
weight_quant_type, ['image', 'label'], [loss], exe,
mobile_program)
def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(
True, seed=1, quant_type='abs_max', for_ci=False)
True,
seed=1,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=True)
with fluid.unique_name.guard():
self.freeze_graph(
True,
seed=1,
activation_quant_type='abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=False)
self.freeze_graph(
False,
seed=2,
activation_quant_type='abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(
True, seed=1, quant_type='range_abs_max', for_ci=False)
True,
seed=1,
activation_quant_type='range_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
True,
seed=1,
activation_quant_type='moving_average_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
True,
seed=1,
quant_type='moving_average_abs_max',
for_ci=False)
activation_quant_type='range_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
self.freeze_graph(
True,
seed=1,
activation_quant_type='moving_average_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(
False, seed=2, quant_type='range_abs_max', for_ci=False)
False,
seed=2,
activation_quant_type='range_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='moving_average_abs_max',
weight_quant_type='abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
activation_quant_type='range_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
self.freeze_graph(
False,
seed=2,
quant_type='moving_average_abs_max',
for_ci=False)
activation_quant_type='moving_average_abs_max',
weight_quant_type='channel_wise_abs_max',
for_ci=True)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册