diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 26d82b9a726433865e5b23bb9efcce368a2f5b26..b494f52410ee2a8ba5b54a3662b634db64d32cc9 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -73,14 +73,11 @@ def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)) -def ExtractBitsFromFloat8e5m2(x): - return np.asarray( - x, dtype=dtypes.float8_e5m2.as_numpy_dtype).view(np.uint8).item() - - def SlowAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.float8_val.extend( - [ExtractBitsFromFloat8e5m2(x) for x in proto_values] + tensor_proto.float8_val += ( + np.asarray(proto_values, dtype=dtypes.float8_e5m2.as_numpy_dtype) + .view(np.uint8) + .tobytes() ) @@ -91,14 +88,11 @@ def FastAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values): dtype=dtypes.float8_e5m2.as_numpy_dtype).view(np.uint8)) -def ExtractBitsFromFloat8e4m3fn(x): - return np.asarray( - x, dtype=dtypes.float8_e4m3fn.as_numpy_dtype).view(np.uint8).item() - - def SlowAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.float8_val.extend( - [ExtractBitsFromFloat8e4m3fn(x) for x in proto_values] + tensor_proto.float8_val += ( + np.asarray(proto_values, dtype=dtypes.float8_e4m3fn.as_numpy_dtype) + .view(np.uint8) + .tobytes() )