提交 b0fa50e2 编写于 作者: W wangshuide2020

use default np.float64 to create ndarray to reduce error of calculation precision

上级 b5bb6f17
...@@ -17,7 +17,6 @@ import numpy as np ...@@ -17,7 +17,6 @@ import numpy as np
from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.utils.utils import calc_histogram_bins from mindinsight.datavisual.utils.utils import calc_histogram_bins
from mindinsight.utils.exceptions import ParamValueError from mindinsight.utils.exceptions import ParamValueError
...@@ -192,10 +191,8 @@ class TensorContainer: ...@@ -192,10 +191,8 @@ class TensorContainer:
self._stats = get_statistics_from_tensor(self._np_array) self._stats = get_statistics_from_tensor(self._np_array)
original_buckets = calc_original_buckets(self._np_array, self._stats) original_buckets = calc_original_buckets(self._np_array, self._stats)
self._count = sum(bucket.count for bucket in original_buckets) self._count = sum(bucket.count for bucket in original_buckets)
# convert the type of max and min value to np.float64 so that it cannot overflow self._max = self._stats.max
# when calculating width of histogram. self._min = self._stats.min
self._max = np.float64(self._stats.max)
self._min = np.float64(self._stats.min)
self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count) self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count)
@property @property
...@@ -257,9 +254,4 @@ class TensorContainer: ...@@ -257,9 +254,4 @@ class TensorContainer:
Returns: Returns:
numpy.ndarray, ndarray of tensor. numpy.ndarray, ndarray of tensor.
""" """
data_type_str = anf_ir_pb2.DataType.Name(self.data_type)
if data_type_str == 'DT_FLOAT16':
return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims)
if data_type_str == 'DT_FLOAT32':
return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims)
return np.array(tuple(tensor)).reshape(self.dims) return np.array(tuple(tensor)).reshape(self.dims)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册