From 70fc4e4a971affb31850d8ead3222a219abe73f6 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Fri, 8 Sep 2023 17:36:41 -0700 Subject: [PATCH] Fix proto bytes extension for windows. PiperOrigin-RevId: 563892946 --- tensorflow/python/framework/tensor_util.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 26d82b9a726..b494f52410e 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -73,14 +73,11 @@ def FastAppendBFloat16ArrayToTensorProto(tensor_proto, proto_values): proto_values, dtype=dtypes.bfloat16.as_numpy_dtype).view(np.uint16)) -def ExtractBitsFromFloat8e5m2(x): - return np.asarray( - x, dtype=dtypes.float8_e5m2.as_numpy_dtype).view(np.uint8).item() - - def SlowAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.float8_val.extend( - [ExtractBitsFromFloat8e5m2(x) for x in proto_values] + tensor_proto.float8_val += ( + np.asarray(proto_values, dtype=dtypes.float8_e5m2.as_numpy_dtype) + .view(np.uint8) + .tobytes() ) @@ -91,14 +88,11 @@ def FastAppendFloat8e5m2ArrayToTensorProto(tensor_proto, proto_values): dtype=dtypes.float8_e5m2.as_numpy_dtype).view(np.uint8)) -def ExtractBitsFromFloat8e4m3fn(x): - return np.asarray( - x, dtype=dtypes.float8_e4m3fn.as_numpy_dtype).view(np.uint8).item() - - def SlowAppendFloat8e4m3fnArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.float8_val.extend( - [ExtractBitsFromFloat8e4m3fn(x) for x in proto_values] + tensor_proto.float8_val += ( + np.asarray(proto_values, dtype=dtypes.float8_e4m3fn.as_numpy_dtype) + .view(np.uint8) + .tobytes() ) -- GitLab