Commit 2b442889 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Add transpose_output option to Dequantize Op in XLA.

PiperOrigin-RevId: 225396120
Parent 4ba2816b
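For orientation, here is a minimal sketch of how the new option is meant to be called, mirroring the tests added in this change. The include paths, the wrapper function name, and the standalone setup are illustrative assumptions, not part of the commit:

```cpp
#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical wrapper, for illustration only.
void DequantizeTransposedExample() {
  xla::XlaBuilder builder("dequantize_example");
  // One row of four uint8 values packed into a single uint32, i.e. an
  // input of shape [1, 1] with unpack_size = 4.
  std::vector<uint8_t> row = {0, 1, 2, 3};
  auto x = xla::ConstantR2<uint32_t>(
      &builder, {{xla::PackToUint32<uint8_t>(row)[0]}});
  xla::QuantizedRange range(0, 255.0f);
  // The default output shape would be [1, 4]; with transpose_output=true the
  // result is [4, 1], matching the shape comment added to Dequantize below.
  xla::Dequantize<uint8_t>(x, range, "MIN_COMBINED",
                           /*transpose_output=*/true);
}
```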
@@ -74,9 +74,14 @@ inline std::vector<uint32> PackToUint32(absl::Span<const T> input) {
// Only uint8 or uint16 is supported for the original unpacked input.
// Returns a tensor of shape [d0,..., dn * unpack_size] if
// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T).
// If transpose_output is true, returns a tensor of shape
// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when
// the input's rank is higher than 1. The input needs to be transposed to
// use the transpose_output feature.
template <typename T>
inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
absl::string_view mode_string = "MIN_COMBINED") {
absl::string_view mode_string = "MIN_COMBINED",
bool transpose_output = false) {
XlaBuilder* const builder = input.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
float half_range =
@@ -94,14 +99,9 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
"Only U32 is supported for input type of xla::Dequantize Op.");
}
auto broadcast_size = shape.dimensions();
broadcast_size.push_back(unpack_size);
std::vector<int64> broadcast_dimensions(shape.dimensions_size());
std::iota(broadcast_dimensions.begin(), broadcast_dimensions.end(), 0);
// Broadcast the input to [d0, ..., dn, unpack_size] if input size is
// Broadcast the input to [unpack_size, d0, ..., dn] if input size is
// [d0, ..., dn].
auto broadcast_input =
BroadcastInDim(input, broadcast_size, broadcast_dimensions);
auto broadcast_input = Broadcast(input, {unpack_size});
XlaOp iota_r1 = Iota(builder, U32, unpack_size);
// Higher significant bytes need to be shifted by more bytes than lower
@@ -110,8 +110,9 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
xla::ConstantR0<uint32>(builder, unpack_size - 1) - iota_r1;
const int bytes_of_type = sizeof(T) / sizeof(uint8);
XlaOp shift_bits = shift_bytes * xla::ConstantR0<uint32>(
builder, kBitsOfByte * bytes_of_type);
std::vector<uint32> shift_vec(unpack_size, kBitsOfByte * bytes_of_type);
XlaOp shift_bits =
shift_bytes * xla::ConstantR1<uint32>(builder, shift_vec);
// Make bit_mask for different data type T.
uint32 bit_mask = 0x00000000;
@@ -120,9 +121,16 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
bit_mask |= 0x000000ff;
}
std::vector<int64> shift_transpose_dimensions(shape.dimensions_size());
std::iota(shift_transpose_dimensions.begin(),
shift_transpose_dimensions.end(), 0);
shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1,
shape.dimensions_size());
// Shift the input by sizeof(T) bytes and apply bit_mask to unpack.
XlaOp shifted_input = ShiftRightLogical(
broadcast_input, Broadcast(shift_bits, shape.dimensions()));
broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()),
shift_transpose_dimensions));
XlaOp unpack_input =
And(shifted_input, xla::ConstantR0<uint32>(builder, bit_mask));
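The shift-and-mask sequence above is the XLA analogue of ordinary integer unpacking. A standalone C++ sketch of the byte layout it assumes (element 0 in the most significant byte, which is what the `unpack_size - 1 - iota` shift implies); illustrative code, not part of the commit:

```cpp
#include <cstdint>
#include <vector>

// Sketch: unpack four uint8 values from one packed uint32, mirroring the
// ShiftRightLogical + And steps applied to the broadcast input above.
std::vector<uint8_t> UnpackUint32(uint32_t packed) {
  constexpr int kUnpackSize = sizeof(uint32_t) / sizeof(uint8_t);  // 4
  constexpr int kBitsOfByte = 8;
  std::vector<uint8_t> out(kUnpackSize);
  for (int i = 0; i < kUnpackSize; ++i) {
    // Smaller output indices need larger shifts -- the same expression as
    // shift_bytes = (unpack_size - 1) - iota in the XLA code.
    const uint32_t shift_bits = (kUnpackSize - 1 - i) * kBitsOfByte;
    out[i] = static_cast<uint8_t>((packed >> shift_bits) & 0xffu);
  }
  return out;
}
```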
@@ -148,12 +156,28 @@ inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range,
"Only MIN_COMBINED mode is supported in xla::Dequantize Op.");
}
// Reshape the result to [d0,..., dn * unpack_size] if
// input shape is [d0, ..., dn].
std::vector<int64> result_shape(shape.dimensions());
result_shape[shape.dimensions_size() - 1] =
shape.dimensions(shape.dimensions_size() - 1) * unpack_size;
return Reshape(result, result_shape);
std::vector<int64> transpose_dimensions(shape.dimensions_size());
std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1);
std::reverse(transpose_dimensions.begin(), transpose_dimensions.end());
transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0);
// Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0].
XlaOp transposed_result = Transpose(result, transpose_dimensions);
// Reshape to be [dn * unpack_size, dn-1, ..., d1, d0].
XlaOp reshaped_result = Collapse(transposed_result, {0, 1});
// Return the transposed result if transpose_output is true.
if (transpose_output) {
return reshaped_result;
}
// Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size].
std::vector<int64> result_dimensions(shape.dimensions_size());
std::iota(result_dimensions.begin(), result_dimensions.end(), 0);
std::reverse(result_dimensions.begin(), result_dimensions.end());
return Transpose(reshaped_result, result_dimensions);
});
}
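To make the permutation bookkeeping above concrete, here is a standalone sketch (not committed code) that evaluates the same `transpose_dimensions` and `result_dimensions` for a rank-2 input of shape [d0, d1], where the broadcast result has shape [unpack_size, d0, d1]:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Sketch: dimension bookkeeping from Dequantize for a rank-2 input.
//   broadcast result:      [unpack_size, d0, d1]
//   Transpose({2, 0, 1}):  [d1, unpack_size, d0]
//   Collapse({0, 1}):      [d1 * unpack_size, d0]  (returned if transpose_output)
//   Transpose({1, 0}):     [d0, d1 * unpack_size]  (returned otherwise)
void ShapeWalkthrough() {
  const int rank = 2;  // shape.dimensions_size() for a rank-2 input.

  std::vector<int64_t> transpose_dimensions(rank);
  std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1);
  std::reverse(transpose_dimensions.begin(), transpose_dimensions.end());
  transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0);
  // transpose_dimensions == {2, 0, 1}

  std::vector<int64_t> result_dimensions(rank);
  std::iota(result_dimensions.begin(), result_dimensions.end(), 0);
  std::reverse(result_dimensions.begin(), result_dimensions.end());
  // result_dimensions == {1, 0}
}
```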
@@ -77,13 +77,25 @@ Array2D<uint32> PackLargeInput(Array2D<NativeT> &input) {
template <typename NativeT>
Array2D<bfloat16> GenerateLargeSizeMinCombinedOutput(
Array2D<NativeT> &input, const QuantizedRange &range) {
Array2D<NativeT> &input, const QuantizedRange &range,
bool transpose_output = false) {
const int64 size_per_pack = sizeof(uint32) / sizeof(NativeT);
int64 width = input.width();
int64 padded_output_width = CeilOfRatio(width, size_per_pack) * size_per_pack;
Array2D<bfloat16> output(input.height(), padded_output_width, bfloat16(0.0));
int64 output_height;
int64 output_width;
if (transpose_output) {
output_height = padded_output_width;
output_width = input.height();
} else {
output_height = input.height();
output_width = padded_output_width;
}
Array2D<bfloat16> output(output_height, output_width, bfloat16(0.0));
float half_range =
!std::is_signed<NativeT>::value
@@ -102,7 +114,11 @@ Array2D<bfloat16> GenerateLargeSizeMinCombinedOutput(
bfloat16 result =
static_cast<bfloat16>(input(h, w) + half_range) * scale_factor +
range.min;
output(h, w) = result;
if (transpose_output) {
output(w, h) = result;
} else {
output(h, w) = result;
}
}
}
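For reference, the helper above follows the MIN_COMBINED formula `(input + half_range) * scale_factor + min`. With the `[0, 255]` range used by the uint8 tests the scale factor works out to 1, so the dequantized values equal the raw inputs, which is why the expected arrays below simply mirror (and, where requested, transpose) the inputs. A minimal numeric sketch, assuming the scale factor is `(max - min) / (T_max - T_min)` as the test values imply; not part of the commit:

```cpp
#include <cstdint>
#include <limits>

// Sketch: MIN_COMBINED dequantization for an unsigned input type.
float MinCombinedDequantize(uint8_t input, float range_min, float range_max) {
  const float half_range = 0.0f;  // 0 for unsigned input types.
  const float scale_factor =
      (range_max - range_min) /
      (static_cast<float>(std::numeric_limits<uint8_t>::max()) -
       static_cast<float>(std::numeric_limits<uint8_t>::min()));
  return (static_cast<float>(input) + half_range) * scale_factor + range_min;
}
// Example: MinCombinedDequantize(13, 0.0f, 255.0f) == 13.0f, so the expected
// bfloat16 outputs in these tests equal the original uint8 inputs.
```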
@@ -206,6 +222,29 @@ XLA_TEST_F(DequantizeTest, MinCombinedUint8R2) {
ComputeAndCompareR2<bfloat16>(&builder, expected, {});
}
XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TransposeOutput) {
XlaBuilder builder(TestName());
std::vector<std::vector<uint8>> input = {
{0, 1, 2, 3},
{4, 5, 6, 7},
{8, 9, 10, 11},
{12, 13, 16, 15},
};
auto x = ConstantR2<uint32>(&builder, {{PackToUint32<uint8>(input[0])[0]},
{PackToUint32<uint8>(input[1])[0]},
{PackToUint32<uint8>(input[2])[0]},
{PackToUint32<uint8>(input[3])[0]}});
QuantizedRange range(0, 255.0f);
xla::Dequantize<uint8>(x, range, "MIN_COMBINED", /*transpose_output=*/true);
const Array2D<bfloat16> expected = {
{bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)},
{bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)},
{bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)},
{bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)},
};
ComputeAndCompareR2<bfloat16>(&builder, expected, {});
}
XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) {
XlaBuilder builder(TestName());
std::vector<std::vector<uint8>> input = {
@@ -236,6 +275,36 @@ XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZero) {
ComputeAndCompareR2<bfloat16>(&builder, expected, {});
}
XLA_TEST_F(DequantizeTest, MinCombinedUint8R2TailingZeroTransposeOutput) {
XlaBuilder builder(TestName());
std::vector<std::vector<uint8>> input = {
{0, 1, 2, 3, 16},
{4, 5, 6, 7, 17},
{8, 9, 10, 11, 18},
{12, 13, 16, 15, 19},
};
auto x = ConstantR2<uint32>(
&builder,
{{PackToUint32<uint8>(input[0])[0], PackToUint32<uint8>(input[0])[1]},
{PackToUint32<uint8>(input[1])[0], PackToUint32<uint8>(input[1])[1]},
{PackToUint32<uint8>(input[2])[0], PackToUint32<uint8>(input[2])[1]},
{PackToUint32<uint8>(input[3])[0], PackToUint32<uint8>(input[3])[1]}});
QuantizedRange range(0, 255.0f);
xla::Dequantize<uint8>(x, range, "MIN_COMBINED", /*transpose_output=*/true);
const Array2D<bfloat16> expected = {
{bfloat16(0.0), bfloat16(4.0), bfloat16(8.0), bfloat16(12.0)},
{bfloat16(1.0), bfloat16(5.0), bfloat16(9.0), bfloat16(13.0)},
{bfloat16(2.0), bfloat16(6.0), bfloat16(10.0), bfloat16(16.0)},
{bfloat16(3.0), bfloat16(7.0), bfloat16(11.0), bfloat16(15.0)},
{bfloat16(16.0), bfloat16(17.0), bfloat16(18.0), bfloat16(19.0)},
{bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(0.0), bfloat16(0.0), bfloat16(0.0)},
};
ComputeAndCompareR2<bfloat16>(&builder, expected, {});
}
XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) {
XlaBuilder builder(TestName());
Array2D<uint8> input = GenerateLargeSizeInput<uint8>(500, 3547);
@@ -250,5 +319,19 @@ XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTest) {
ComputeAndCompareR2<bfloat16>(&builder, expected, {});
}
XLA_TEST_F(DequantizeTest, MinCombinedUint8LargeSizeTestTransposeOutput) {
XlaBuilder builder(TestName());
Array2D<uint8> input = GenerateLargeSizeInput<uint8>(500, 3547);
Array2D<uint32> input_packed = PackLargeInput<uint8>(input);
auto x = ConstantR2FromArray2D<uint32>(&builder, input_packed);
QuantizedRange range(0, 255.0f);
xla::Dequantize<uint8>(x, range, "MIN_COMBINED", /*transpose_output=*/true);
const Array2D<bfloat16> expected = GenerateLargeSizeMinCombinedOutput<uint8>(
input, range, /*transpose_output=*/true);
ComputeAndCompareR2<bfloat16>(&builder, expected, {});
}
} // namespace
} // namespace xla