diff --git a/mace/python/tools/tensor_util.py b/mace/python/tools/tensor_util.py index 62d33052982fe19784385cf2d32cd3faad146233..61a5e90cf0a8a73c73bd092dfc340760e493ab18 100644 --- a/mace/python/tools/tensor_util.py +++ b/mace/python/tools/tensor_util.py @@ -165,6 +165,7 @@ def del_tensor_data(net_def, runtime, gpu_data_type): elif t.data_type == mace_pb2.DT_UINT8: del t.int32_data[:] + def update_tensor_data_type(net_def, runtime, gpu_data_type): for t in net_def.tensors: if t.data_type == mace_pb2.DT_FLOAT and runtime == 'gpu' \