From 4d43bd6c5a9a531bac8c021b196608daa0a88ecb Mon Sep 17 00:00:00 2001 From: wangshuide2020 <7511764+wangshuide2020@user.noreply.gitee.com> Date: Fri, 28 Aug 2020 17:22:36 +0800 Subject: [PATCH] Store data with default datatype in TensorContainer and remove limitation of datatype. --- mindinsight/datavisual/common/exceptions.py | 9 --------- .../datavisual/data_transform/tensor_container.py | 10 ++-------- mindinsight/utils/constant.py | 1 - 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/mindinsight/datavisual/common/exceptions.py b/mindinsight/datavisual/common/exceptions.py index f8fa401..47322d4 100644 --- a/mindinsight/datavisual/common/exceptions.py +++ b/mindinsight/datavisual/common/exceptions.py @@ -188,15 +188,6 @@ class ResponseDataExceedMaxValueError(MindInsightException): http_code=400) -class DataTypeError(MindInsightException): - """Data_type does not support.""" - def __init__(self, error_detail): - error_msg = f'Data type does not support. Detail: {error_detail}' - super(DataTypeError, self).__init__(DataVisualErrors.DATA_TYPE_NOT_SUPPORT, - error_msg, - http_code=400) - - class TrainJobDetailNotInCacheError(MindInsightException): """Detail info of given train job is not in cache.""" def __init__(self, error_detail="no detail provided."): diff --git a/mindinsight/datavisual/data_transform/tensor_container.py b/mindinsight/datavisual/data_transform/tensor_container.py index 0bdc94d..1f41409 100644 --- a/mindinsight/datavisual/data_transform/tensor_container.py +++ b/mindinsight/datavisual/data_transform/tensor_container.py @@ -15,7 +15,6 @@ """Tensor data container.""" import numpy as np -from mindinsight.datavisual.common.exceptions import DataTypeError from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 @@ -253,19 +252,14 @@ class TensorContainer: Get ndarray of tensor. Args: - tensor (float16|float32|float64): tensor data. + tensor (mindinsight_anf_ir.proto.DataType): tensor data. Returns: numpy.ndarray, ndarray of tensor. - - Raises: - DataTypeError, If data type of tensor is not among float16 or float32 or float64. """ data_type_str = anf_ir_pb2.DataType.Name(self.data_type) if data_type_str == 'DT_FLOAT16': return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims) if data_type_str == 'DT_FLOAT32': return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims) - if data_type_str == 'DT_FLOAT64': - return np.array(tuple(tensor), dtype=np.float64).reshape(self.dims) - raise DataTypeError("Data type: {}.".format(data_type_str)) + return np.array(tuple(tensor)).reshape(self.dims) diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py index 3920e95..d964513 100644 --- a/mindinsight/utils/constant.py +++ b/mindinsight/utils/constant.py @@ -77,7 +77,6 @@ class DataVisualErrors(Enum): TENSOR_NOT_EXIST = 18 MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19 STEP_TENSOR_DATA_NOT_IN_CACHE = 20 - DATA_TYPE_NOT_SUPPORT = 21 class ScriptConverterErrors(Enum): -- GitLab