提交 a9b99418 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!580 store data with default datatype when call numpy.array in TensorContainer...

!580 store data with default datatype when call numpy.array in TensorContainer and remove limitaion of datatype
Merge pull request !580 from wangshuide/wsd0827_r07
...@@ -188,15 +188,6 @@ class ResponseDataExceedMaxValueError(MindInsightException): ...@@ -188,15 +188,6 @@ class ResponseDataExceedMaxValueError(MindInsightException):
http_code=400) http_code=400)
class DataTypeError(MindInsightException):
"""Data_type does not support."""
def __init__(self, error_detail):
error_msg = f'Data type does not support. Detail: {error_detail}'
super(DataTypeError, self).__init__(DataVisualErrors.DATA_TYPE_NOT_SUPPORT,
error_msg,
http_code=400)
class TrainJobDetailNotInCacheError(MindInsightException): class TrainJobDetailNotInCacheError(MindInsightException):
"""Detail info of given train job is not in cache.""" """Detail info of given train job is not in cache."""
def __init__(self, error_detail="no detail provided."): def __init__(self, error_detail="no detail provided."):
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
"""Tensor data container.""" """Tensor data container."""
import numpy as np import numpy as np
from mindinsight.datavisual.common.exceptions import DataTypeError
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
...@@ -253,19 +252,14 @@ class TensorContainer: ...@@ -253,19 +252,14 @@ class TensorContainer:
Get ndarray of tensor. Get ndarray of tensor.
Args: Args:
tensor (float16|float32|float64): tensor data. tensor (mindinsight_anf_ir.proto.DataType): tensor data.
Returns: Returns:
numpy.ndarray, ndarray of tensor. numpy.ndarray, ndarray of tensor.
Raises:
DataTypeError, If data type of tensor is not among float16 or float32 or float64.
""" """
data_type_str = anf_ir_pb2.DataType.Name(self.data_type) data_type_str = anf_ir_pb2.DataType.Name(self.data_type)
if data_type_str == 'DT_FLOAT16': if data_type_str == 'DT_FLOAT16':
return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims) return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims)
if data_type_str == 'DT_FLOAT32': if data_type_str == 'DT_FLOAT32':
return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims) return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims)
if data_type_str == 'DT_FLOAT64': return np.array(tuple(tensor)).reshape(self.dims)
return np.array(tuple(tensor), dtype=np.float64).reshape(self.dims)
raise DataTypeError("Data type: {}.".format(data_type_str))
...@@ -77,7 +77,6 @@ class DataVisualErrors(Enum): ...@@ -77,7 +77,6 @@ class DataVisualErrors(Enum):
TENSOR_NOT_EXIST = 18 TENSOR_NOT_EXIST = 18
MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19 MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19
STEP_TENSOR_DATA_NOT_IN_CACHE = 20 STEP_TENSOR_DATA_NOT_IN_CACHE = 20
DATA_TYPE_NOT_SUPPORT = 21
class ScriptConverterErrors(Enum): class ScriptConverterErrors(Enum):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册