提交 c872622d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!591 use default np.float64 to create ndarray to reduce error of calculation...

!591 use default np.float64 to create ndarray to reduce error of calculation precision in resample of tensor histogram
Merge pull request !591 from wangshuide/wsd0727
......@@ -17,7 +17,6 @@ import numpy as np
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.utils.utils import calc_histogram_bins
from mindinsight.utils.exceptions import ParamValueError
......@@ -192,10 +191,8 @@ class TensorContainer:
self._stats = get_statistics_from_tensor(self._np_array)
original_buckets = calc_original_buckets(self._np_array, self._stats)
self._count = sum(bucket.count for bucket in original_buckets)
# convert the type of max and min value to np.float64 so that it cannot overflow
# when calculating width of histogram.
self._max = np.float64(self._stats.max)
self._min = np.float64(self._stats.min)
self._max = self._stats.max
self._min = self._stats.min
self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count)
@property
......@@ -257,9 +254,4 @@ class TensorContainer:
Returns:
numpy.ndarray, ndarray of tensor.
"""
data_type_str = anf_ir_pb2.DataType.Name(self.data_type)
if data_type_str == 'DT_FLOAT16':
return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims)
if data_type_str == 'DT_FLOAT32':
return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims)
return np.array(tuple(tensor)).reshape(self.dims)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册