提交 4d43bd6c 编写于 作者: W wangshuide2020

Store data with default datatype in TensorContainer and remove limitation of datatype.

上级 1fd40dc0
...@@ -188,15 +188,6 @@ class ResponseDataExceedMaxValueError(MindInsightException): ...@@ -188,15 +188,6 @@ class ResponseDataExceedMaxValueError(MindInsightException):
http_code=400) http_code=400)
class DataTypeError(MindInsightException):
"""Data_type does not support."""
def __init__(self, error_detail):
error_msg = f'Data type does not support. Detail: {error_detail}'
super(DataTypeError, self).__init__(DataVisualErrors.DATA_TYPE_NOT_SUPPORT,
error_msg,
http_code=400)
class TrainJobDetailNotInCacheError(MindInsightException): class TrainJobDetailNotInCacheError(MindInsightException):
"""Detail info of given train job is not in cache.""" """Detail info of given train job is not in cache."""
def __init__(self, error_detail="no detail provided."): def __init__(self, error_detail="no detail provided."):
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
"""Tensor data container.""" """Tensor data container."""
import numpy as np import numpy as np
from mindinsight.datavisual.common.exceptions import DataTypeError
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
...@@ -253,19 +252,14 @@ class TensorContainer: ...@@ -253,19 +252,14 @@ class TensorContainer:
Get ndarray of tensor. Get ndarray of tensor.
Args: Args:
tensor (float16|float32|float64): tensor data. tensor (mindinsight_anf_ir.proto.DataType): tensor data.
Returns: Returns:
numpy.ndarray, ndarray of tensor. numpy.ndarray, ndarray of tensor.
Raises:
DataTypeError, If data type of tensor is not among float16 or float32 or float64.
""" """
data_type_str = anf_ir_pb2.DataType.Name(self.data_type) data_type_str = anf_ir_pb2.DataType.Name(self.data_type)
if data_type_str == 'DT_FLOAT16': if data_type_str == 'DT_FLOAT16':
return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims) return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims)
if data_type_str == 'DT_FLOAT32': if data_type_str == 'DT_FLOAT32':
return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims) return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims)
if data_type_str == 'DT_FLOAT64': return np.array(tuple(tensor)).reshape(self.dims)
return np.array(tuple(tensor), dtype=np.float64).reshape(self.dims)
raise DataTypeError("Data type: {}.".format(data_type_str))
...@@ -77,7 +77,6 @@ class DataVisualErrors(Enum): ...@@ -77,7 +77,6 @@ class DataVisualErrors(Enum):
TENSOR_NOT_EXIST = 18 TENSOR_NOT_EXIST = 18
MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19 MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19
STEP_TENSOR_DATA_NOT_IN_CACHE = 20 STEP_TENSOR_DATA_NOT_IN_CACHE = 20
DATA_TYPE_NOT_SUPPORT = 21
class ScriptConverterErrors(Enum): class ScriptConverterErrors(Enum):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册