From b0fa50e2e57cd40bff2ee2f9414c2e3d2dd3101d Mon Sep 17 00:00:00 2001 From: wangshuide2020 <7511764+wangshuide2020@user.noreply.gitee.com> Date: Thu, 3 Sep 2020 17:45:19 +0800 Subject: [PATCH] use default np.float64 to create ndarray to reduce error of calculation precision --- .../datavisual/data_transform/tensor_container.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/mindinsight/datavisual/data_transform/tensor_container.py b/mindinsight/datavisual/data_transform/tensor_container.py index 1f41409..ef430b3 100644 --- a/mindinsight/datavisual/data_transform/tensor_container.py +++ b/mindinsight/datavisual/data_transform/tensor_container.py @@ -17,7 +17,6 @@ import numpy as np from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket -from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.utils.utils import calc_histogram_bins from mindinsight.utils.exceptions import ParamValueError @@ -192,10 +191,8 @@ class TensorContainer: self._stats = get_statistics_from_tensor(self._np_array) original_buckets = calc_original_buckets(self._np_array, self._stats) self._count = sum(bucket.count for bucket in original_buckets) - # convert the type of max and min value to np.float64 so that it cannot overflow - # when calculating width of histogram. - self._max = np.float64(self._stats.max) - self._min = np.float64(self._stats.min) + self._max = self._stats.max + self._min = self._stats.min self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count) @property @@ -257,9 +254,4 @@ class TensorContainer: Returns: numpy.ndarray, ndarray of tensor. """ - data_type_str = anf_ir_pb2.DataType.Name(self.data_type) - if data_type_str == 'DT_FLOAT16': - return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims) - if data_type_str == 'DT_FLOAT32': - return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims) return np.array(tuple(tensor)).reshape(self.dims) -- GitLab