提交 1b59fa8f 编写于 作者: E Eugene Zhulenev 提交者: TensorFlower Gardener

[xla] Define primitive type bit and byte width in a header file

BEFORE

---------------------------------------------------------------
Benchmark                     Time             CPU   Iterations
---------------------------------------------------------------
BM_FlatMemrefX12All        81.3 ns         81.3 ns      8617874
BM_FlatMemrefX12None       69.1 ns         69.1 ns     10129964

AFTER

---------------------------------------------------------------
Benchmark                     Time             CPU   Iterations
---------------------------------------------------------------
BM_FlatMemrefX12All        54.8 ns         54.8 ns     12683098
BM_FlatMemrefX12None       47.7 ns         47.7 ns     14667555

PiperOrigin-RevId: 481192231
上级 51b971cc
......@@ -94,48 +94,6 @@ bool IsIntegralType(PrimitiveType type) {
return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
}
int BitWidth(PrimitiveType type) {
switch (type) {
case PRED:
return 1;
case S8:
case U8:
return 8;
case S16:
case U16:
case F16:
case BF16:
return 16;
case U32:
case S32:
case F32:
return 32;
case U64:
case S64:
case F64:
case C64:
return 64;
case C128:
return 128;
case TUPLE:
LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
case OPAQUE_TYPE:
LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
default:
LOG(FATAL) << "Unhandled primitive type " << type;
}
}
int ByteWidth(PrimitiveType type) { return CeilOfRatio(BitWidth(type), 8); }
xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth) {
switch (src_bitwidth) {
case 8:
......
......@@ -22,6 +22,7 @@ limitations under the License.
#include <tuple>
#include <type_traits>
#include "absl/base/attributes.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
......@@ -154,10 +155,86 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) {
}
// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);
ABSL_ATTRIBUTE_ALWAYS_INLINE inline int BitWidth(PrimitiveType type) {
switch (type) {
case PRED:
return 1;
case S8:
case U8:
return 8;
case S16:
case U16:
case F16:
case BF16:
return 16;
case U32:
case S32:
case F32:
return 32;
case U64:
case S64:
case F64:
case C64:
return 64;
case C128:
return 128;
case TUPLE:
LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
case OPAQUE_TYPE:
LOG(FATAL) << "OPAQUE_TYPE is an invalid type for BitWidth";
default:
LOG(FATAL) << "Unhandled primitive type " << type;
}
}
// Returns the number of bytes in the representation for a given type.
int ByteWidth(PrimitiveType type);
ABSL_ATTRIBUTE_ALWAYS_INLINE inline int ByteWidth(PrimitiveType type) {
switch (type) {
case PRED:
return 1;
case S8:
case U8:
return 1;
case S16:
case U16:
case F16:
case BF16:
return 2;
case U32:
case S32:
case F32:
return 4;
case U64:
case S64:
case F64:
case C64:
return 8;
case C128:
return 16;
case TUPLE:
LOG(FATAL) << "TUPLE is an invalid type for ByteWidth";
case OPAQUE_TYPE:
LOG(FATAL) << "OPAQUE_TYPE is an invalid type for ByteWidth";
default:
LOG(FATAL) << "Unhandled primitive type " << type;
}
}
PrimitiveType UnsignedIntegralTypeForBitWidth(int64_t src_bitwidth);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册