Commit cf9260a9, authored by mindspore-ci-bot, committed by Gitee

!562 Remove redundant data to save memory and simplify TensorContainer.

Merge pull request !562 from wangshuide/wsd0727
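
At a high level, this change makes TensorContainer keep a single eagerly built numpy array instead of both a tuple copy of the proto's float_data and a lazily converted ndarray. A rough, self-contained illustration (not code from this commit; exact sizes vary by platform) of why the duplicate tuple was expensive:

```python
import sys
import numpy as np

# The old container held both a tuple copy of the values (self._data) and
# an ndarray built from it (self._np_array). A tuple of Python floats costs
# a pointer per element plus a ~24-byte float object each, versus 4 bytes
# per element in a float32 ndarray.
values = np.random.rand(1000).astype(np.float32)
as_tuple = tuple(values.tolist())
as_array = np.array(as_tuple, dtype=np.float32)

tuple_bytes = sys.getsizeof(as_tuple) + sum(sys.getsizeof(v) for v in as_tuple)
print(tuple_bytes)      # tens of kilobytes for 1000 elements
print(as_array.nbytes)  # 4000 bytes
```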
......@@ -209,7 +209,7 @@ class GunicornLogger(Logger):
         super(GunicornLogger, self).__init__(cfg)

     def now(self):
-        """return the log format"""
+        """Get log format."""
         return time.strftime('[%Y-%m-%d-%H:%M:%S %z]')

     def setup(self, cfg):
......
......@@ -188,6 +188,15 @@ class ResponseDataExceedMaxValueError(MindInsightException):
                          http_code=400)


+class DataTypeError(MindInsightException):
+    """Data type is not supported."""
+    def __init__(self, error_detail):
+        error_msg = f'Data type is not supported. Detail: {error_detail}'
+        super(DataTypeError, self).__init__(DataVisualErrors.DATA_TYPE_NOT_SUPPORT,
+                                            error_msg,
+                                            http_code=400)
+
+
 class TrainJobDetailNotInCacheError(MindInsightException):
     """Detail info of given train job is not in cache."""
     def __init__(self, error_detail="no detail provided."):
......
......@@ -296,6 +296,7 @@ class _SummaryParser(_Parser):
                 self._load_single_file(self._summary_file_handler, executor)
                 # Wait for data in this file to be processed to avoid loading multiple files at the same time.
                 executor.wait_all_tasks_finish()
+                logger.info("Parse summary file finished, file path: %s.", file_path)
         except UnknownError as ex:
             logger.warning("Parse summary file failed, detail: %r, "
                            "file path: %s.", str(ex), file_path)
......@@ -383,7 +384,7 @@ class _SummaryParser(_Parser):
         # read the header
         header_str = file_handler.read(HEADER_SIZE)
         if not header_str:
-            logger.info("End of file, file_path=%s.", file_handler.file_path)
+            logger.info("Load summary file finished, file_path=%s.", file_handler.file_path)
             return None
         header_crc_str = file_handler.read(CRC_STR_SIZE)
         if not header_crc_str:
......@@ -441,12 +442,9 @@
         elif plugin == PluginNameEnum.TENSOR.value:
             tensor_event_value = TensorContainer(tensor_event_value)
-            tensor_count = 1
-            for d in tensor_event_value.dims:
-                tensor_count *= d
-            if tensor_count > MAX_TENSOR_COUNT:
+            if tensor_event_value.size > MAX_TENSOR_COUNT:
                 logger.warning('tag: %s/tensor, dims: %s, tensor count: %d exceeds %d and drop it.',
-                               value.tag, tensor_event_value.dims, tensor_count, MAX_TENSOR_COUNT)
+                               value.tag, tensor_event_value.dims, tensor_event_value.size, MAX_TENSOR_COUNT)
                 return None

         elif plugin == PluginNameEnum.IMAGE.value:
......
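
For context on the check above: `numpy.ndarray.size` is exactly the product of the dims that the removed loop computed by hand, so the new `size` property (added further down) is a drop-in replacement. A quick sanity check:

```python
import numpy as np
from functools import reduce
from operator import mul

arr = np.zeros((4, 5, 6), dtype=np.float32)
manual_count = reduce(mul, arr.shape, 1)  # what the removed loop computed
assert manual_count == arr.size == 120    # ndarray.size is the product of all dims
```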
......@@ -13,12 +13,12 @@
 # limitations under the License.
 # ============================================================================
 """Tensor data container."""
-import threading
-
 import numpy as np

+from mindinsight.datavisual.common.exceptions import DataTypeError
 from mindinsight.datavisual.common.log import logger
 from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
 from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
 from mindinsight.datavisual.utils.utils import calc_histogram_bins
 from mindinsight.utils.exceptions import ParamValueError
......@@ -139,19 +139,6 @@ def get_statistics_from_tensor(tensors):
     return statistics


-def _get_data_from_tensor(tensor):
-    """
-    Get data from tensor and convert to tuple.
-
-    Args:
-        tensor (TensorProto): Tensor proto data.
-
-    Returns:
-        tuple, the item of tensor value.
-    """
-    return tuple(tensor.float_data)
-
-
 def calc_original_buckets(np_value, stats):
     """
     Calculate buckets from tensor data.
......@@ -199,19 +186,22 @@ class TensorContainer:
     """

     def __init__(self, tensor_message):
-        self._lock = threading.Lock
         # Original dims cannot be pickled for transfer to another process, so a tuple is used.
         self._dims = tuple(tensor_message.dims)
         self._data_type = tensor_message.data_type
-        self._np_array = None
-        self._data = _get_data_from_tensor(tensor_message)
-        self._stats = get_statistics_from_tensor(self.get_or_calc_ndarray())
-        original_buckets = calc_original_buckets(self.get_or_calc_ndarray(), self._stats)
+        self._np_array = self.get_ndarray(tensor_message.float_data)
+        self._stats = get_statistics_from_tensor(self._np_array)
+        original_buckets = calc_original_buckets(self._np_array, self._stats)
         self._count = sum(bucket.count for bucket in original_buckets)
         self._max = self._stats.max
         self._min = self._stats.min
         self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count)

+    @property
+    def size(self):
+        """Get size of tensor."""
+        return self._np_array.size
+
     @property
     def dims(self):
         """Get dims of tensor."""
......@@ -222,6 +212,11 @@
         """Get data type of tensor."""
         return self._data_type

+    @property
+    def ndarray(self):
+        """Get ndarray of tensor."""
+        return self._np_array
+
     @property
     def max(self):
         """Get max value of tensor."""
......@@ -251,19 +246,24 @@
         """Get histogram buckets."""
         return self._histogram.buckets()

-    def get_or_calc_ndarray(self):
-        """Get or calculate ndarray."""
-        with self._lock():
-            if self._np_array is None:
-                self._convert_to_numpy_array()
-            return self._np_array
-
-    def _convert_to_numpy_array(self):
-        """Convert a list data to numpy array."""
-        try:
-            ndarray = np.array(self._data).reshape(self._dims)
-        except ValueError as ex:
-            logger.error("Reshape array fail, detail: %r", str(ex))
-            return
-
-        self._np_array = ndarray
+    def get_ndarray(self, tensor):
+        """
+        Get ndarray of tensor.
+
+        Args:
+            tensor (float16|float32|float64): Tensor data.
+
+        Returns:
+            numpy.ndarray, ndarray of tensor.
+
+        Raises:
+            DataTypeError: If data type of tensor is not float16, float32 or float64.
+        """
+        data_type_str = anf_ir_pb2.DataType.Name(self.data_type)
+        if data_type_str == 'DT_FLOAT16':
+            return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims)
+        if data_type_str == 'DT_FLOAT32':
+            return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims)
+        if data_type_str == 'DT_FLOAT64':
+            return np.array(tuple(tensor), dtype=np.float64).reshape(self.dims)
+        raise DataTypeError("Data type: {}.".format(data_type_str))
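
The three if-branches in `get_ndarray` amount to a name-to-dtype lookup followed by one reshape. A hypothetical table-driven equivalent (illustration only, not the committed code; the helper name is ours):

```python
import numpy as np

# Hypothetical helper mirroring get_ndarray's dispatch.
_DTYPE_BY_PROTO_NAME = {
    'DT_FLOAT16': np.float16,
    'DT_FLOAT32': np.float32,
    'DT_FLOAT64': np.float64,
}

def to_ndarray(values, dims, data_type_str):
    dtype = _DTYPE_BY_PROTO_NAME.get(data_type_str)
    if dtype is None:
        raise ValueError("Data type: {}.".format(data_type_str))
    return np.array(tuple(values), dtype=dtype).reshape(dims)

print(to_ndarray([1.0, 2.0, 3.0, 4.0], (2, 2), 'DT_FLOAT32'))
```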
......@@ -99,9 +99,9 @@ def get_statistics_dict(stats):
         dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
     """
     statistics = {
-        "max": stats.max,
-        "min": stats.min,
-        "avg": stats.avg,
+        "max": float(stats.max),
+        "min": float(stats.min),
+        "avg": float(stats.avg),
         "count": stats.count,
         "nan_count": stats.nan_count,
         "neg_inf_count": stats.neg_inf_count,
......@@ -302,8 +302,7 @@ class TensorProcessor(BaseProcessor):
             if step != tensor.step:
                 continue
             step_in_cache = True
-            ndarray = value.get_or_calc_ndarray()
-            res_data = get_specific_dims_data(ndarray, dims, list(value.dims))
+            res_data = get_specific_dims_data(value.ndarray, dims, list(value.dims))
             flatten_data = res_data.flatten().tolist()
             if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE:
                 raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}."
......@@ -326,7 +325,7 @@ class TensorProcessor(BaseProcessor):
             elif np.isposinf(data):
                 transfer_data[index] = 'INF'
             else:
-                transfer_data[index] = data
+                transfer_data[index] = float(data)
         return transfer_data

     stats = get_statistics_from_tensor(res_data)
......
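
The surrounding NAN/INF branches and the new `float(data)` exist because strict JSON has no NaN or Infinity literals; Python's `json` emits them only as a non-standard extension. A small check:

```python
import json

print(json.dumps([float('nan'), float('inf')]))  # '[NaN, Infinity]' -- not valid JSON
try:
    json.dumps([float('nan')], allow_nan=False)  # strict mode refuses them
except ValueError as err:
    print(err)
print(json.dumps(['NAN', 'INF']))  # the string substitution serializes cleanly
```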
......@@ -77,6 +77,7 @@ class DataVisualErrors(Enum):
     TENSOR_NOT_EXIST = 18
     MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19
     STEP_TENSOR_DATA_NOT_IN_CACHE = 20
+    DATA_TYPE_NOT_SUPPORT = 21


 class ScriptConverterErrors(Enum):
......