未验证 提交 591744eb 编写于 作者: L lokoppakmsft 提交者: GitHub

Support N-dimension input in quantization kernel (#2575)

* Add support for inputs > 2D

* use vec

* Add N-Dim support to Dequant kernel

* merge master and fix format

* Bug Fix

* fix num_bits

* Fix dequant
Co-authored-by: NConnor Holmes <connorholmes@microsoft.com>
上级 18d55e54
......@@ -89,10 +89,10 @@ std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int K = input_vals.size(1) / (numBits == 8 ? 1 : 2);
int M = input_vals.size(0);
auto output = torch::empty({M, K}, output_options);
auto output_sizes = input_vals.sizes().vec();
output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
auto output = torch::empty(output_sizes, output_options);
const int elems_per_group = at::numel(input_vals) / groups;
......@@ -113,18 +113,6 @@ std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
return {output, params};
}
int num_decompressed_elems(at::Tensor& quantized_data, int num_bits)
{
if (num_bits == 8) {
return quantized_data.size(-1);
} else if (num_bits == 4) {
return quantized_data.size(-1) * 2;
} else {
assert(false);
return 0;
}
}
at::Tensor dequantize(at::Tensor& quantized_data,
at::Tensor& params,
int groups,
......@@ -136,10 +124,12 @@ at::Tensor dequantize(at::Tensor& quantized_data,
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int final_dim_size = num_decompressed_elems(quantized_data, num_bits);
auto output = torch::empty({quantized_data.size(0), final_dim_size}, output_options);
const int total_elems = quantized_data.size(0) * final_dim_size;
auto output_sizes = quantized_data.sizes().vec();
output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
auto output = torch::empty(output_sizes, output_options);
const int total_elems = at::numel(output);
const int elems_per_group = total_elems / groups;
launch_dequantize_kernel((__half*)output.data_ptr(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册