提交 995f7d64 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!413 support tensor visualization

Merge pull request !413 from wangshuide/wsd_tensor_visualization
...@@ -25,6 +25,7 @@ from mindinsight.conf import settings ...@@ -25,6 +25,7 @@ from mindinsight.conf import settings
from mindinsight.datavisual.utils.tools import get_train_id from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.datavisual.utils.tools import if_nan_inf_to_none from mindinsight.datavisual.utils.tools import if_nan_inf_to_none
from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor from mindinsight.datavisual.processors.histogram_processor import HistogramProcessor
from mindinsight.datavisual.processors.tensor_processor import TensorProcessor
from mindinsight.datavisual.processors.images_processor import ImageProcessor from mindinsight.datavisual.processors.images_processor import ImageProcessor
from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor from mindinsight.datavisual.processors.scalars_processor import ScalarsProcessor
from mindinsight.datavisual.processors.graph_processor import GraphProcessor from mindinsight.datavisual.processors.graph_processor import GraphProcessor
...@@ -173,6 +174,25 @@ def get_scalars(): ...@@ -173,6 +174,25 @@ def get_scalars():
return jsonify({'scalars': scalars}) return jsonify({'scalars': scalars})
@BLUEPRINT.route("/datavisual/tensors", methods=["GET"])
def get_tensors():
    """
    Interface to obtain tensor data.

    The repeated ``train_id`` and ``tag`` query parameters are read as lists,
    while ``step``, ``dims`` and ``detail`` are read as single optional values
    (``None`` when absent). Everything is forwarded, unparsed, to
    ``TensorProcessor.get_tensors``.

    Returns:
        Response, which contains a JSON object.
    """
    query = request.args
    train_ids = query.getlist('train_id')
    tags = query.getlist('tag')
    # Missing single-valued parameters come back as None.
    step, dims, detail = (query.get(key) for key in ("step", "dims", "detail"))

    response = TensorProcessor(DATA_MANAGER).get_tensors(train_ids, tags, step, dims, detail)
    return jsonify(response)
def init_module(app): def init_module(app):
""" """
Init module entry. Init module entry.
......
...@@ -57,3 +57,5 @@ MAX_IMAGE_STEP_SIZE_PER_TAG = 10 ...@@ -57,3 +57,5 @@ MAX_IMAGE_STEP_SIZE_PER_TAG = 10
MAX_SCALAR_STEP_SIZE_PER_TAG = 1000 MAX_SCALAR_STEP_SIZE_PER_TAG = 1000
MAX_GRAPH_STEP_SIZE_PER_TAG = 1 MAX_GRAPH_STEP_SIZE_PER_TAG = 1
MAX_HISTOGRAM_STEP_SIZE_PER_TAG = 50 MAX_HISTOGRAM_STEP_SIZE_PER_TAG = 50
MAX_TENSOR_STEP_SIZE_PER_TAG = 50
MAX_TENSOR_RESPONSE_DATA_SIZE = 300000
...@@ -38,6 +38,7 @@ class PluginNameEnum(BaseEnum): ...@@ -38,6 +38,7 @@ class PluginNameEnum(BaseEnum):
SCALAR = 'scalar' SCALAR = 'scalar'
GRAPH = 'graph' GRAPH = 'graph'
HISTOGRAM = 'histogram' HISTOGRAM = 'histogram'
TENSOR = 'tensor'
@enum.unique @enum.unique
......
...@@ -161,6 +161,33 @@ class HistogramNotExistError(MindInsightException): ...@@ -161,6 +161,33 @@ class HistogramNotExistError(MindInsightException):
http_code=400) http_code=400)
class TensorNotExistError(MindInsightException):
    """
    Unable to get tensor values based on a given condition.

    Args:
        error_detail: Extra information that is appended to the error message.
    """

    def __init__(self, error_detail):
        error_msg = f'Tensor value is not exist. Detail: {error_detail}'
        # Reported with error code TENSOR_NOT_EXIST and HTTP status 400.
        super().__init__(DataVisualErrors.TENSOR_NOT_EXIST,
                         error_msg,
                         http_code=400)
class StepTensorDataNotInCacheError(MindInsightException):
    """
    Tensor data with specific step does not in cache.

    Args:
        error_detail: Extra information that is appended to the error message.
    """

    def __init__(self, error_detail):
        error_msg = f'Tensor data not in cache. Detail: {error_detail}'
        # Reported with error code STEP_TENSOR_DATA_NOT_IN_CACHE and HTTP status 400.
        super().__init__(DataVisualErrors.STEP_TENSOR_DATA_NOT_IN_CACHE,
                         error_msg,
                         http_code=400)
class ResponseDataExceedMaxValueError(MindInsightException):
    """
    Response data exceed max value based on a given condition.

    Args:
        error_detail: Extra information that is appended to the error message.
    """

    def __init__(self, error_detail):
        error_msg = f'Response data exceed max value. Detail: {error_detail}'
        # Reported with error code MAX_RESPONSE_DATA_EXCEEDED_ERROR and HTTP status 400.
        super().__init__(DataVisualErrors.MAX_RESPONSE_DATA_EXCEEDED_ERROR,
                         error_msg,
                         http_code=400)
class TrainJobDetailNotInCacheError(MindInsightException): class TrainJobDetailNotInCacheError(MindInsightException):
"""Detail info of given train job is not in cache.""" """Detail info of given train job is not in cache."""
def __init__(self, error_detail="no detail provided."): def __init__(self, error_detail="no detail provided."):
......
...@@ -41,7 +41,8 @@ CONFIG = { ...@@ -41,7 +41,8 @@ CONFIG = {
PluginNameEnum.SCALAR.value: settings.MAX_SCALAR_STEP_SIZE_PER_TAG, PluginNameEnum.SCALAR.value: settings.MAX_SCALAR_STEP_SIZE_PER_TAG,
PluginNameEnum.IMAGE.value: settings.MAX_IMAGE_STEP_SIZE_PER_TAG, PluginNameEnum.IMAGE.value: settings.MAX_IMAGE_STEP_SIZE_PER_TAG,
PluginNameEnum.GRAPH.value: settings.MAX_GRAPH_STEP_SIZE_PER_TAG, PluginNameEnum.GRAPH.value: settings.MAX_GRAPH_STEP_SIZE_PER_TAG,
PluginNameEnum.HISTOGRAM.value: settings.MAX_HISTOGRAM_STEP_SIZE_PER_TAG PluginNameEnum.HISTOGRAM.value: settings.MAX_HISTOGRAM_STEP_SIZE_PER_TAG,
PluginNameEnum.TENSOR.value: settings.MAX_TENSOR_STEP_SIZE_PER_TAG
} }
} }
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Histogram data."""
import math
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.utils.utils import calc_histogram_bins
def mask_invalid_number(num):
    """
    Mask invalid number to 0.

    Args:
        num: Number to check.

    Returns:
        ``num`` itself when it is finite, otherwise a zero built with ``type(num)``
        so that the result keeps the type of the input.
    """
    if math.isfinite(num):
        return num
    # NaN or +/-INF: replace it with a zero of the same type.
    return type(num)(0)
class Bucket:
    """
    Bucket data class.

    A bucket is described by its left edge, its width and the number of values
    that fall into it; the right edge is derived from left edge and width.

    Args:
        left (double): Left edge of the histogram bucket.
        width (double): Width of the histogram bucket.
        count (int): Count of numbers fallen in the histogram bucket.
    """

    def __init__(self, left, width, count):
        self._left, self._width, self._count = left, width, count

    @property
    def left(self):
        """Gets left edge of the histogram bucket."""
        return self._left

    @property
    def width(self):
        """Gets width of the histogram bucket."""
        return self._width

    @property
    def count(self):
        """Gets count of numbers fallen in the histogram bucket."""
        return self._count

    @property
    def right(self):
        """Gets right edge of the histogram bucket (left edge plus width)."""
        return self._left + self._width

    def as_tuple(self):
        """Gets the bucket as a ``(left, width, count)`` tuple."""
        return (self._left, self._width, self._count)

    def __repr__(self):
        """Returns repr(self)."""
        template = "Bucket(left={}, width={}, count={})"
        return template.format(self._left, self._width, self._count)
class Histogram:
    """
    Histogram data class.

    Keeps the original buckets of one histogram together with a visual range
    (max, min, bins). The original buckets are re-sampled lazily into evenly
    spaced visual buckets the first time ``buckets()`` is called after the
    visual range was set.

    Args:
        buckets (tuple[Bucket]): Original buckets. The re-sampling walks them
            front to back, so they are expected to be ordered by left edge.
        max_val: Initial upper bound of the visual range.
        min_val: Initial lower bound of the visual range.
        count (int): Total count of values; used to pick the default bin number
            and, when zero, to short-circuit re-sampling to empty buckets.
    """
    # Max quantity of original buckets.
    MAX_ORIGINAL_BUCKETS_COUNT = 90

    def __init__(self, buckets, max_val, min_val, count):
        self._original_buckets = buckets
        self._count = count
        self._visual_min = min_val
        self._visual_max = max_val
        # default bin number
        self._visual_bins = calc_histogram_bins(count)
        # Note that tuple is immutable, so sharing tuple is often safe.
        self._re_sampled_buckets = ()

    @property
    def original_buckets_count(self):
        """Gets original buckets quantity."""
        return len(self._original_buckets)

    def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
        """
        Sets visual range for later re-sampling.

        It's caller's duty to ensure input is valid.

        Why a visual range? Buckets that are aligned between steps let users
        follow the trend of a tensor. With thinner buckets every count gets
        lower and with thicker buckets every count gets higher, so histograms
        with different bucket widths shown together are misleading. Setting one
        common visual range unifies the buckets across steps.

        Args:
            max_val (float): Max value for visual histogram.
            min_val (float): Min value for visual histogram.
            bins (int): Bins number for visual histogram.

        Raises:
            ParamValueError, If max_val is less than min_val or bins is less than 1.
        """
        if max_val < min_val:
            raise ParamValueError(
                "Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val))
        if bins < 1:
            raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins))

        self._visual_max, self._visual_min, self._visual_bins = max_val, min_val, bins

        # Drop the cached visual buckets so that the next buckets() call re-samples.
        self._re_sampled_buckets = ()

    def _calc_intersection_len(self, max1, min1, max2, min2):
        """Calculates intersection length of [min1, max1] and [min2, max2]."""
        if max1 < min1:
            raise ParamValueError(
                "Invalid input. max1({}) is less than min1({}).".format(max1, min1))
        if max2 < min2:
            raise ParamValueError(
                "Invalid input. max2({}) is less than min2({}).".format(max2, min2))

        if min1 <= min2:
            if max1 <= min2:
                # No overlap. The zero must be calculated by max1.__sub__ to keep its type.
                return max1 - max1
            # Overlap starts at min2 and ends at the smaller upper bound.
            return (max1 if max1 <= max2 else max2) - min2

        # min1 > min2
        if max2 <= min1:
            # No overlap. The zero is calculated by max2.__sub__ to keep its type.
            return max2 - max2
        # Overlap starts at min1 and ends at the smaller upper bound.
        return (max2 if max2 <= max1 else max1) - min1

    def _re_sample_buckets(self):
        """Re-samples buckets according to visual_max, visual_min and visual_bins."""
        if self._visual_max == self._visual_min:
            # Adjust visual range if max equals min.
            self._visual_max += 0.5
            self._visual_min -= 0.5

        bins = self._visual_bins
        width = (self._visual_max - self._visual_min) / bins

        if not self._count:
            # Nothing to distribute: every visual bucket is empty.
            self._re_sampled_buckets = tuple(
                Bucket(self._visual_min + width * idx, width, 0)
                for idx in range(bins))
            return

        sources = self._original_buckets
        total = len(sources)
        pos = 0
        source = sources[pos]
        visual = []
        for idx in range(bins):
            left = self._visual_min + width * idx
            right = left + width
            estimate = 0.0

            # The visual bucket lies completely before the current original bucket.
            if right <= source.left:
                visual.append(Bucket(left, width, math.ceil(estimate)))
                continue

            # Skip original buckets that end before this visual bucket starts.
            while left >= source.right:
                pos += 1
                if pos >= total:
                    break
                source = sources[pos]

            # Here: right > source.left and left < source.right (unless sources are exhausted).
            while pos < total:
                source = sources[pos]
                overlap = self._calc_intersection_len(
                    min1=left, max1=right,
                    min2=source.left, max2=source.right)
                if source.width:
                    # Share of the original count proportional to the overlap.
                    estimate += (overlap / source.width) * source.count
                else:
                    # A zero-width original bucket contributes its whole count.
                    estimate += source.count

                if not right > source.right:
                    # Current visual bucket has taken all intersect buckets into account.
                    break
                # Need to sample next original bucket to this visual bucket.
                pos += 1

            visual.append(Bucket(left, width, math.ceil(estimate)))

        self._re_sampled_buckets = tuple(visual)

    def buckets(self, convert_to_tuple=True):
        """
        Get visual buckets instead of original buckets.

        The visual buckets are computed on first use and whenever the cached
        result is empty.

        Args:
            convert_to_tuple (bool): Whether convert bucket object to tuple.

        Returns:
            tuple, contains buckets.
        """
        if not self._re_sampled_buckets:
            self._re_sample_buckets()

        if convert_to_tuple:
            return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets)
        return self._re_sampled_buckets
...@@ -13,90 +13,27 @@ ...@@ -13,90 +13,27 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Histogram data container.""" """Histogram data container."""
import math from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket, mask_invalid_number
from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.utils.utils import calc_histogram_bins
def _mask_invalid_number(num):
"""Mask invalid number to 0."""
if math.isnan(num) or math.isinf(num):
return type(num)(0)
return num
class Bucket:
"""
Bucket data class.
Args:
left (double): Left edge of the histogram bucket.
width (double): Width of the histogram bucket.
count (int): Count of numbers fallen in the histogram bucket.
"""
def __init__(self, left, width, count):
self._left = left
self._width = width
self._count = count
@property
def left(self):
"""Gets left edge of the histogram bucket."""
return self._left
@property
def count(self):
"""Gets count of numbers fallen in the histogram bucket."""
return self._count
@property
def width(self):
"""Gets width of the histogram bucket."""
return self._width
@property
def right(self):
"""Gets right edge of the histogram bucket."""
return self._left + self._width
def as_tuple(self):
"""Gets the bucket as tuple."""
return self._left, self._width, self._count
def __repr__(self):
"""Returns repr(self)."""
return "Bucket(left={}, width={}, count={})".format(self._left, self._width, self._count)
class HistogramContainer: class HistogramContainer:
""" """
Histogram data container. Histogram data container.
Args: Args:
histogram_message (Summary.Histogram): Histogram message in summary file. histogram_message (Summary.Histogram): Histogram message in summary file.
""" """
# Max quantity of original buckets.
MAX_ORIGINAL_BUCKETS_COUNT = 90
def __init__(self, histogram_message: Summary.Histogram): def __init__(self, histogram_message: Summary.Histogram):
self._msg = histogram_message self._msg = histogram_message
original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets] original_buckets = [Bucket(bucket.left, bucket.width, bucket.count) for bucket in self._msg.buckets]
# Ensure buckets are sorted from min to max. # Ensure buckets are sorted from min to max.
original_buckets.sort(key=lambda bucket: bucket.left) original_buckets.sort(key=lambda bucket: bucket.left)
self._original_buckets = tuple(original_buckets) self._count = sum(bucket.count for bucket in original_buckets)
self._count = sum(bucket.count for bucket in self._original_buckets) self._max = mask_invalid_number(histogram_message.max)
self._max = _mask_invalid_number(histogram_message.max) self._min = mask_invalid_number(histogram_message.min)
self._min = _mask_invalid_number(histogram_message.min) self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count)
self._visual_max = self._max
self._visual_min = self._min
# default bin number
self._visual_bins = calc_histogram_bins(self._count)
# Note that tuple is immutable, so sharing tuple is often safe.
self._re_sampled_buckets = ()
@property @property
def max(self): def max(self):
...@@ -114,148 +51,10 @@ class HistogramContainer: ...@@ -114,148 +51,10 @@ class HistogramContainer:
return self._count return self._count
@property @property
def original_msg(self): def histogram(self):
"""Gets original proto message.""" """Gets histogram data"""
return self._msg return self._histogram
@property
def original_buckets_count(self):
"""Gets original buckets quantity."""
return len(self._original_buckets)
def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
"""
Sets visual range for later re-sampling.
It's caller's duty to ensure input is valid.
Why we need visual range for histograms? Aligned buckets between steps can help users know about the trend of
tensors. Miss aligned buckets between steps might miss-lead users about the trend of a tensor. Because for
given tensor, if you have thinner buckets, count of every bucket will get lower, however, if you have
thicker buckets, count of every bucket will get higher. When they are displayed together, user might think
the histogram with thicker buckets has more values. This is miss-leading. So we need to unify buckets across
steps. Visual range for histogram is a technology for unifying buckets.
Args:
max_val (float): Max value for visual histogram.
min_val (float): Min value for visual histogram.
bins (int): Bins number for visual histogram.
"""
if max_val < min_val:
raise ParamValueError(
"Invalid input. max_val({}) is less or equal than min_val({}).".format(max_val, min_val))
if bins < 1:
raise ParamValueError("Invalid input bins({}). Must be greater than 0.".format(bins))
self._visual_max = max_val
self._visual_min = min_val
self._visual_bins = bins
# mark _re_sampled_buckets to empty
self._re_sampled_buckets = ()
def _calc_intersection_len(self, max1, min1, max2, min2):
"""Calculates intersection length of [min1, max1] and [min2, max2]."""
if max1 < min1:
raise ParamValueError(
"Invalid input. max1({}) is less than min1({}).".format(max1, min1))
if max2 < min2:
raise ParamValueError(
"Invalid input. max2({}) is less than min2({}).".format(max2, min2))
if min1 <= min2:
if max1 <= min2:
# return value must be calculated by max1.__sub__
return max1 - max1
if max1 <= max2:
return max1 - min2
# max1 > max2
return max2 - min2
# min1 > min2
if max2 <= min1:
return max2 - max2
if max2 <= max1:
return max2 - min1
return max1 - min1
def _re_sample_buckets(self):
"""Re-samples buckets according to visual_max, visual_min and visual_bins."""
if self._visual_max == self._visual_min:
# Adjust visual range if max equals min.
self._visual_max += 0.5
self._visual_min -= 0.5
width = (self._visual_max - self._visual_min) / self._visual_bins
if not self.count:
self._re_sampled_buckets = tuple(
Bucket(self._visual_min + width * i, width, 0)
for i in range(self._visual_bins))
return
re_sampled = []
original_pos = 0
original_bucket = self._original_buckets[original_pos]
for i in range(self._visual_bins):
cur_left = self._visual_min + width * i
cur_right = cur_left + width
cur_estimated_count = 0.0
# Skip no bucket range.
if cur_right <= original_bucket.left:
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
continue
# Skip no intersect range.
while cur_left >= original_bucket.right:
original_pos += 1
if original_pos >= len(self._original_buckets):
break
original_bucket = self._original_buckets[original_pos]
# entering with this condition: cur_right > original_bucket.left and cur_left < original_bucket.right
while True:
if original_pos >= len(self._original_buckets):
break
original_bucket = self._original_buckets[original_pos]
intersection = self._calc_intersection_len(
min1=cur_left, max1=cur_right,
min2=original_bucket.left, max2=original_bucket.right)
if not original_bucket.width:
estimated_count = original_bucket.count
else:
estimated_count = (intersection / original_bucket.width) * original_bucket.count
cur_estimated_count += estimated_count
if cur_right > original_bucket.right:
# Need to sample next original bucket to this visual bucket.
original_pos += 1
else:
# Current visual bucket has taken all intersect buckets into account.
break
re_sampled.append(Bucket(cur_left, width, math.ceil(cur_estimated_count)))
self._re_sampled_buckets = tuple(re_sampled)
def buckets(self, convert_to_tuple=True):
"""
Get visual buckets instead of original buckets.
Args:
convert_to_tuple (bool): Whether convert bucket object to tuple.
Returns:
tuple, contains buckets.
"""
if not self._re_sampled_buckets:
self._re_sample_buckets()
if not convert_to_tuple:
return self._re_sampled_buckets
return tuple(bucket.as_tuple() for bucket in self._re_sampled_buckets) def buckets(self):
"""Gets histogram buckets"""
return self._histogram.buckets()
...@@ -36,7 +36,9 @@ from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summar ...@@ -36,7 +36,9 @@ from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summar
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.utils import crc32 from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import UnknownError from mindinsight.utils.exceptions import UnknownError
from mindinsight.datavisual.data_transform.histogram import Histogram
from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer
HEADER_SIZE = 8 HEADER_SIZE = 8
CRC_STR_SIZE = 4 CRC_STR_SIZE = 4
...@@ -390,6 +392,7 @@ class _SummaryParser(_Parser): ...@@ -390,6 +392,7 @@ class _SummaryParser(_Parser):
'scalar_value': PluginNameEnum.SCALAR, 'scalar_value': PluginNameEnum.SCALAR,
'image': PluginNameEnum.IMAGE, 'image': PluginNameEnum.IMAGE,
'histogram': PluginNameEnum.HISTOGRAM, 'histogram': PluginNameEnum.HISTOGRAM,
'tensor': PluginNameEnum.TENSOR
} }
if event.HasField('summary'): if event.HasField('summary'):
...@@ -404,10 +407,12 @@ class _SummaryParser(_Parser): ...@@ -404,10 +407,12 @@ class _SummaryParser(_Parser):
tensor_event_value = HistogramContainer(tensor_event_value) tensor_event_value = HistogramContainer(tensor_event_value)
# Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT # Drop steps if original_buckets_count exceeds HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT
# to avoid time-consuming re-sample process. # to avoid time-consuming re-sample process.
if tensor_event_value.original_buckets_count > HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT: if tensor_event_value.histogram.original_buckets_count > Histogram.MAX_ORIGINAL_BUCKETS_COUNT:
logger.info('original_buckets_count exceeds ' logger.info('original_buckets_count exceeds '
'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT') 'HistogramContainer.MAX_ORIGINAL_BUCKETS_COUNT')
continue continue
elif plugin == 'tensor':
tensor_event_value = TensorContainer(tensor_event_value)
tensor_event = TensorEvent(wall_time=event.wall_time, tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step, step=event.step,
......
...@@ -205,12 +205,12 @@ class HistogramReservoir(Reservoir): ...@@ -205,12 +205,12 @@ class HistogramReservoir(Reservoir):
visual_range = _VisualRange() visual_range = _VisualRange()
max_count = 0 max_count = 0
for sample in self._samples: for sample in self._samples:
histogram = sample.value histogram_container = sample.value
if histogram.count == 0: if histogram_container.count == 0:
# ignore empty tensor # ignore empty tensor
continue continue
max_count = max(histogram.count, max_count) max_count = max(histogram_container.count, max_count)
visual_range.update(histogram.max, histogram.min) visual_range.update(histogram_container.max, histogram_container.min)
if visual_range.max == visual_range.min and not max_count: if visual_range.max == visual_range.min and not max_count:
logger.info("Max equals to min. Count is zero.") logger.info("Max equals to min. Count is zero.")
...@@ -225,7 +225,7 @@ class HistogramReservoir(Reservoir): ...@@ -225,7 +225,7 @@ class HistogramReservoir(Reservoir):
bins, bins,
max_count) max_count)
for sample in self._samples: for sample in self._samples:
histogram = sample.value histogram = sample.value.histogram
histogram.set_visual_range(visual_range.max, visual_range.min, bins) histogram.set_visual_range(visual_range.max, visual_range.min, bins)
self._visual_range_up_to_date = True self._visual_range_up_to_date = True
...@@ -245,6 +245,6 @@ class ReservoirFactory: ...@@ -245,6 +245,6 @@ class ReservoirFactory:
Returns: Returns:
Reservoir, reservoir instance for given plugin name. Reservoir, reservoir instance for given plugin name.
""" """
if plugin_name == PluginNameEnum.HISTOGRAM.value: if plugin_name in (PluginNameEnum.HISTOGRAM.value, PluginNameEnum.TENSOR.value):
return HistogramReservoir(size) return HistogramReservoir(size)
return Reservoir(size) return Reservoir(size)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensor data container."""
import threading
import numpy as np
from mindinsight.datavisual.common.log import logger
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket
from mindinsight.datavisual.utils.utils import calc_histogram_bins
from mindinsight.utils.exceptions import ParamValueError
F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
class Statistics:
    """
    Statistics data class.

    A read-only holder for the summary numbers of one tensor; every
    constructor argument is exposed through a property.

    Args:
        max_value (float): max value of tensor data.
        min_value (float): min value of tensor data.
        avg_value (float): avg value of tensor data.
        count (int): total count of tensor data.
        nan_count (int): count of NAN.
        neg_inf_count (int): count of negative INF.
        pos_inf_count (int): count of positive INF.
    """

    def __init__(self, max_value=0, min_value=0, avg_value=0,
                 count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
        # Value statistics.
        self._max, self._min, self._avg = max_value, min_value, avg_value
        # Element counters.
        self._count = count
        self._nan_count = nan_count
        self._neg_inf_count = neg_inf_count
        self._pos_inf_count = pos_inf_count

    @property
    def max(self):
        """Get max value of tensor."""
        return self._max

    @property
    def min(self):
        """Get min value of tensor."""
        return self._min

    @property
    def avg(self):
        """Get avg value of tensor."""
        return self._avg

    @property
    def count(self):
        """Get total count of tensor."""
        return self._count

    @property
    def nan_count(self):
        """Get count of NAN."""
        return self._nan_count

    @property
    def neg_inf_count(self):
        """Get count of negative INF."""
        return self._neg_inf_count

    @property
    def pos_inf_count(self):
        """Get count of positive INF."""
        return self._pos_inf_count
def get_statistics_from_tensor(tensors):
    """
    Calculates statistics data of tensor.

    NaN and +/-INF elements are masked out before max, min and average are
    computed; they are reported through the dedicated counters instead. The
    ``count`` of the result is always the total number of elements, invalid
    ones included.

    Args:
        tensors (numpy.ndarray): An numpy.ndarray of tensor data.

    Returns:
        Statistics, statistics of the tensor. When the tensor holds no valid
        value at all, max, min and avg are reported as 0.
    """
    ma_value = np.ma.masked_invalid(tensors)
    total, valid = tensors.size, ma_value.count()

    # Count NaN, +INF and -INF in that order, but stop scanning the tensor as
    # soon as all invalid elements (total - valid) have been accounted for.
    invalids = []
    for isfn in np.isnan, np.isposinf, np.isneginf:
        if total - valid > sum(invalids):
            count = np.count_nonzero(isfn(tensors))
            invalids.append(count)
        else:
            invalids.append(0)
    nan_count, pos_inf_count, neg_inf_count = invalids

    if not valid:
        logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
        statistics = Statistics(max_value=0,
                                min_value=0,
                                avg_value=0,
                                count=total,
                                nan_count=nan_count,
                                neg_inf_count=neg_inf_count,
                                pos_inf_count=pos_inf_count)
        return statistics

    # BUG: max of a masked array with dtype np.float16 returns inf
    # See numpy issue#15077
    if issubclass(tensors.dtype.type, np.floating):
        # np.inf / -np.inf are used instead of the np.PINF / np.NINF aliases,
        # which no longer exist in NumPy 2.0 and later.
        tensor_min = ma_value.min(fill_value=np.inf)
        tensor_max = ma_value.max(fill_value=-np.inf)
        if tensor_min < F32_MIN or tensor_max > F32_MAX:
            logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
                           'behaviours hereafter.', tensor_min, tensor_max)
    else:
        tensor_min = ma_value.min()
        tensor_max = ma_value.max()

    # Accumulate in float64 regardless of the tensor dtype.
    tensor_sum = ma_value.sum(dtype=np.float64)
    statistics = Statistics(max_value=tensor_max,
                            min_value=tensor_min,
                            avg_value=tensor_sum / valid,
                            count=total,
                            nan_count=nan_count,
                            neg_inf_count=neg_inf_count,
                            pos_inf_count=pos_inf_count)
    return statistics
def _get_data_from_tensor(tensor):
    """
    Get data from tensor and convert to tuple.

    Args:
        tensor (TensorProto): Tensor proto data; only its ``float_data`` field is read.

    Returns:
        tuple, the item of tensor value.
    """
    return (*tensor.float_data,)
def calc_original_buckets(np_value, stats):
    """
    Calculate buckets from tensor data.

    The bin number comes from ``calc_histogram_bins`` applied to the count of
    valid (non-NaN, non-INF) values, and the bin edges are spread evenly between
    ``stats.min`` and ``stats.max``.

    Args:
        np_value (numpy.ndarray): An numpy.ndarray of tensor data.
        stats (Statistics): An instance of Statistics about tensor data.

    Returns:
        list, a list of bucket about tensor data. Empty when there is no valid value.

    Raises:
        ParamValueError, If np_value or stats is None.
    """
    if np_value is None or stats is None:
        raise ParamValueError("Invalid input. np_value or stats is None.")

    valid_count = stats.count - stats.nan_count - stats.neg_inf_count - stats.pos_inf_count
    if not valid_count:
        return []

    bin_number = calc_histogram_bins(valid_count)
    low, high = stats.min, stats.max
    if not low < high:
        # Degenerate range: widen it so that the edges are not all identical.
        low, high = low - 0.5, high + 0.5

    bin_edges = np.linspace(low, high, bin_number + 1, dtype=np_value.dtype)
    hists, edges = np.histogram(np_value, bins=bin_edges)

    # One bucket per pair of neighbouring edges.
    return [Bucket(start, stop - start, hist)
            for hist, start, stop in zip(hists, edges, edges[1:])]
class TensorContainer:
    """
    Tensor data container.

    On construction the tensor values are converted to a numpy array, the
    statistics of that array are calculated and a histogram is built from it.

    Args:
        tensor_message (Summary.TensorProto): Tensor message in summary file.
    """

    def __init__(self, tensor_message):
        # One lock instance per container. Storing the Lock class and calling
        # it inside ``with`` would create a fresh lock on every access and
        # therefore protect nothing.
        # NOTE(review): a lock instance cannot be pickled or deep-copied; confirm
        # that containers are never transferred that way.
        self._lock = threading.Lock()
        self._msg = tensor_message
        self._dims = tensor_message.dims
        self._data_type = tensor_message.data_type
        self._np_array = None
        self._data = _get_data_from_tensor(tensor_message)
        ndarray = self.get_or_calc_ndarray()
        self._stats = get_statistics_from_tensor(ndarray)
        original_buckets = calc_original_buckets(ndarray, self._stats)
        self._count = sum(bucket.count for bucket in original_buckets)
        self._max = self._stats.max
        self._min = self._stats.min
        self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count)

    @property
    def dims(self):
        """Get dims of tensor."""
        return self._dims

    @property
    def data_type(self):
        """Get data type of tensor."""
        return self._data_type

    @property
    def max(self):
        """Get max value of tensor."""
        return self._max

    @property
    def min(self):
        """Get min value of tensor."""
        return self._min

    @property
    def stats(self):
        """Get statistics data of tensor."""
        return self._stats

    @property
    def count(self):
        """Get count value of tensor (sum of the original bucket counts)."""
        return self._count

    @property
    def histogram(self):
        """Get histogram data."""
        return self._histogram

    def buckets(self):
        """Get histogram buckets."""
        return self._histogram.buckets()

    def get_or_calc_ndarray(self):
        """
        Get or calculate ndarray.

        Returns:
            numpy.ndarray, the tensor data reshaped to the tensor dims, or None
            if the conversion failed.
        """
        # Hold the container's lock so that concurrent callers do not convert twice.
        with self._lock:
            if self._np_array is None:
                self._convert_to_numpy_array()
            return self._np_array

    def _convert_to_numpy_array(self):
        """Convert a list data to numpy array."""
        try:
            ndarray = np.array(self._data).reshape(self._dims)
        except ValueError as ex:
            # The data does not fit the declared dims; leave the array unset.
            logger.error("Reshape array fail, detail: %r", str(ex))
            return

        # The original message is not needed any more once the array exists.
        self._msg = None
        self._np_array = ndarray
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensor Processor APIs."""
from urllib.parse import unquote
import numpy as np
from mindinsight.datavisual.utils.tools import to_int
from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError
from mindinsight.datavisual.common.exceptions import ResponseDataExceedMaxValueError
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer, get_statistics_from_tensor
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
def convert_array_from_str(dims, limit=0):
    """
    Parse a textual dims specification into a list of indices.

    Square brackets are dropped, the remainder is split on commas and each
    piece is stripped. A piece equal to ':' marks a flexible dimension and
    becomes None; any other piece is converted with `to_int`.

    Args:
        dims (str): Dims specification, e.g. "[0, 0, :, :]".
        limit (int): Maximum number of flexible (':') dimensions allowed.
            Default: 0, which means no limitation.

    Returns:
        list, e.g. "[0, 0, :, :]" is converted to [0, 0, None, None].

    Raises:
        ParamValueError, If the number of flexible dimensions exceeds `limit`.
    """
    pieces = [piece.strip() for piece in dims.replace('[', '').replace(']', '').split(',')]
    dims_list = [None if piece == ':' else to_int(piece, "dim") for piece in pieces]

    flexible_count = pieces.count(':')
    if limit and flexible_count > limit:
        raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
                              .format(limit, flexible_count))
    return dims_list
def get_specific_dims_data(ndarray, dims, tensor_dims):
    """
    Select a sub-array of a tensor according to the requested dims.

    Each entry of `dims` is either an integer index for that axis or None,
    which selects the whole axis.

    Args:
        ndarray (numpy.ndarray): Tensor data.
        dims (list): Requested index per axis; None means "take the whole axis".
        tensor_dims (list): Size of every axis of the tensor.

    Returns:
        numpy.ndarray, the selected data. It is a numpy scalar when every axis
        is fixed to an index.

    Raises:
        ParamValueError, If `dims` and `tensor_dims` differ in length, or an
            index falls outside the range [-size, size) of its axis.
    """
    if len(dims) != len(tensor_dims):
        raise ParamValueError("The length of param dims: {}, is not equal to the "
                              "length of tensor dims: {}.".format(len(dims), len(tensor_dims)))

    indices = []
    for index, size in zip(dims, tensor_dims):
        if index is None:
            indices.append(slice(0, size))
            continue
        # Negative indices inside the axis keep working as before (numpy wraps
        # them around). Anything outside [-size, size) is rejected here so the
        # caller gets a parameter error and not an uncaught numpy IndexError.
        if not -size <= index < size:
            raise ParamValueError("The index: {} of param dims out of range: {}.".format(index, size))
        indices.append(index)
    return ndarray[tuple(indices)]
def get_statistics_dict(tensor_container, tensors):
    """
    Build the statistics dict for a tensor.

    When `tensors` is None the statistics already held by `tensor_container`
    are used. Otherwise they are computed from `tensors` with
    `get_statistics_from_tensor`, after converting it to a numpy array if
    it is not one already.

    Args:
        tensor_container (TensorContainer): Container whose `stats` are used
            when `tensors` is None.
        tensors (numpy.ndarray or number): Tensor data to compute statistics
            from, or None.

    Returns:
        dict, with keys 'max', 'min', 'avg', 'count', 'nan_count',
        'neg_inf_count' and 'pos_inf_count'.
    """
    if tensors is None:
        source = tensor_container.stats
    else:
        array = tensors if isinstance(tensors, np.ndarray) else np.array(tensors)
        source = get_statistics_from_tensor(array)

    field_names = ("max", "min", "avg", "count", "nan_count", "neg_inf_count", "pos_inf_count")
    return {name: getattr(source, name) for name in field_names}
class TensorProcessor(BaseProcessor):
    """Tensor Processor."""

    def get_tensors(self, train_ids, tags, step, dims, detail):
        """
        Get tensor data for given train_ids, tags, step, dims and detail.

        Args:
            train_ids (list): Specify list of train job ID. The entries are
                URL-unquoted in place.
            tags (list): Specify list of tag.
            step (int): Specify step of tag, it's necessary when detail is equal to 'data'.
            dims (str): Specify dims of step, it's necessary when detail is equal to 'data'.
            detail (str): Specify which data to query, available values: 'stats', 'histogram' and 'data'.

        Returns:
            dict, a dict including the `tensors`.

        Raises:
            UrlDecodeError, If unquote train id error with strict mode.
        """
        Validation.check_param_empty(train_id=train_ids, tag=tags)
        for index, train_id in enumerate(train_ids):
            try:
                train_ids[index] = unquote(train_id, errors='strict')
            except UnicodeDecodeError as ex:
                # Chain the cause so the original decode failure stays visible in tracebacks.
                raise UrlDecodeError('Unquote train id error with strict mode') from ex

        tensors = []
        for train_id in train_ids:
            tensors.extend(self._get_train_tensors(train_id, tags, step, dims, detail))

        return {"tensors": tensors}

    def _get_train_tensors(self, train_id, tags, step, dims, detail):
        """
        Get tensor data for given train_id, tags, step, dims and detail.

        Args:
            train_id (str): Specify train job ID.
            tags (list): Specify list of tag.
            step (int): Specify step of tensor, it's necessary when detail is set to 'data'.
            dims (str): Specify dims of tensor, it's necessary when detail is set to 'data'.
            detail (str): Specify which data to query, available values: 'stats', 'histogram' and 'data'.

        Returns:
            list[dict], a list of dictionaries containing the `train_id`, `tag`, `values`.

        Raises:
            TensorNotExistError, If tensor with specific train_id and tag is not exist in cache.
            ParamValueError, If the value of detail is not within available values:
                'stats', 'histogram' and 'data'.
        """
        tensors_response = []
        for tag in tags:
            try:
                tensors = self._data_manager.list_tensors(train_id, tag)
            except ParamValueError as err:
                raise TensorNotExistError(err.message) from err

            # Only the first item is inspected to decide whether the tag holds tensor data.
            if tensors and not isinstance(tensors[0].value, TensorContainer):
                raise TensorNotExistError("there is no tensor data in this tag: {}".format(tag))

            if detail is None or detail == 'stats':
                values = self._get_tensors_summary(detail, tensors)
            elif detail == 'data':
                Validation.check_param_empty(step=step, dims=dims)
                values = self._get_tensors_data(to_int(step, "step"), dims, tensors)
            elif detail == 'histogram':
                values = self._get_tensors_histogram(tensors)
            else:
                raise ParamValueError('Can not support this value: {} of detail.'.format(detail))

            tensors_response.append({
                "train_id": train_id,
                "tag": tag,
                "values": values
            })

        return tensors_response

    def _get_tensors_summary(self, detail, tensors):
        """
        Builds a JSON-serializable object with information about tensor summary.

        Args:
            detail (str): Specify which data to query, detail value is None or 'stats' at this method.
            tensors (list): The list of _Tensor data.

        Returns:
            list[dict], one dict with `wall_time`, `step` and `value` for each tensor.
                {
                    "wall_time": 0,
                    "step": 0,
                    "value": {
                        "dims": [1],
                        "data_type": "DT_FLOAT32",
                        "statistics": {
                            "max": 0,
                            "min": 0,
                            "avg": 0,
                            "count": 1,
                            "nan_count": 0,
                            "neg_inf_count": 0,
                            "pos_inf_count": 0
                        }   This dict is only set when detail is equal to 'stats'.
                    }
                }
        """
        values = []
        for tensor in tensors:
            # This value is an instance of TensorContainer
            value = tensor.value
            value_dict = {
                "dims": tuple(value.dims),
                "data_type": anf_ir_pb2.DataType.Name(value.data_type)
            }
            if detail == 'stats':
                value_dict["statistics"] = get_statistics_dict(value, None)

            values.append({
                "wall_time": tensor.wall_time,
                "step": tensor.step,
                "value": value_dict
            })

        return values

    def _get_tensors_data(self, step, dims, tensors):
        """
        Builds a JSON-serializable object with information about tensor dims data.

        Only the first tensor whose step equals `step` is returned.

        Args:
            step (int): Specify step of tensor.
            dims (str): Specify dims of tensor.
            tensors (list): The list of _Tensor data.

        Returns:
            list[dict], at most one dict with `wall_time`, `step` and `value`.
                {
                    "wall_time": 0,
                    "step": 0,
                    "value": {
                        "dims": [1],
                        "data_type": "DT_FLOAT32",
                        "data": [[0.1]],
                        "statistics": {
                            "max": 0,
                            "min": 0,
                            "avg": 0,
                            "count": 1,
                            "nan_count": 0,
                            "neg_inf_count": 0,
                            "pos_inf_count": 0
                        }
                    }
                }

        Raises:
            ResponseDataExceedMaxValueError, If the size of response data exceed max value.
            StepTensorDataNotInCacheError, If query step is not in cache.
        """
        values = []
        step_in_cache = False
        # At most two dimensions may be left flexible (':') in a data query.
        dims = convert_array_from_str(dims, limit=2)
        for tensor in tensors:
            if step != tensor.step:
                continue
            step_in_cache = True
            # This value is an instance of TensorContainer
            value = tensor.value
            ndarray = value.get_or_calc_ndarray()
            res_data = get_specific_dims_data(ndarray, dims, list(value.dims))

            # Check the element count before materializing any Python list, so an
            # oversized request is rejected without first allocating the full copy.
            data_size = res_data.size
            if data_size > MAX_TENSOR_RESPONSE_DATA_SIZE:
                raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}."
                                                      .format(data_size, MAX_TENSOR_RESPONSE_DATA_SIZE))

            flatten_data = res_data.flatten().tolist()
            values.append({
                "wall_time": tensor.wall_time,
                "step": tensor.step,
                "value": {
                    "dims": tuple(value.dims),
                    "data_type": anf_ir_pb2.DataType.Name(value.data_type),
                    "data": res_data.tolist(),
                    "statistics": get_statistics_dict(value, flatten_data)
                }
            })
            break

        if not step_in_cache:
            raise StepTensorDataNotInCacheError("this step: {} data may has been dropped.".format(step))

        return values

    def _get_tensors_histogram(self, tensors):
        """
        Builds a JSON-serializable object with information about tensor histogram data.

        Args:
            tensors (list): The list of _Tensor data.

        Returns:
            list[dict], one dict with `wall_time`, `step` and `value` for each tensor.
                {
                    "wall_time": 0,
                    "step": 0,
                    "value": {
                        "dims": [1],
                        "data_type": "DT_FLOAT32",
                        "histogram_buckets": [[0.1, 0.2, 3]],
                        "statistics": {
                            "max": 0,
                            "min": 0,
                            "avg": 0,
                            "count": 1,
                            "nan_count": 0,
                            "neg_inf_count": 0,
                            "pos_inf_count": 0
                        }
                    }
                }
        """
        values = []
        for tensor in tensors:
            # This value is an instance of TensorContainer
            value = tensor.value
            values.append({
                "wall_time": tensor.wall_time,
                "step": tensor.step,
                "value": {
                    "dims": tuple(value.dims),
                    "data_type": anf_ir_pb2.DataType.Name(value.data_type),
                    "histogram_buckets": value.buckets(),
                    "statistics": get_statistics_dict(value, None)
                }
            })

        return values
...@@ -71,6 +71,9 @@ class DataVisualErrors(Enum): ...@@ -71,6 +71,9 @@ class DataVisualErrors(Enum):
HISTOGRAM_NOT_EXIST = 15 HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16 TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
QUERY_STRING_CONTAINS_NULL_BYTE = 17 QUERY_STRING_CONTAINS_NULL_BYTE = 17
TENSOR_NOT_EXIST = 18
MAX_RESPONSE_DATA_EXCEEDED_ERROR = 19
STEP_TENSOR_DATA_NOT_IN_CACHE = 20
class ScriptConverterErrors(Enum): class ScriptConverterErrors(Enum):
......
...@@ -29,9 +29,9 @@ class TestHistogram: ...@@ -29,9 +29,9 @@ class TestHistogram:
mocked_bucket.width = 1 mocked_bucket.width = 1
mocked_bucket.count = 1 mocked_bucket.count = 1
mocked_input.buckets = [mocked_bucket] mocked_input.buckets = [mocked_bucket]
histogram = hist.HistogramContainer(mocked_input) histogram_container = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=1, min_val=0, bins=1) histogram_container.histogram.set_visual_range(max_val=1, min_val=0, bins=1)
buckets = histogram.buckets() buckets = histogram_container.buckets()
assert buckets == ((0.0, 1.0, 1),) assert buckets == ((0.0, 1.0, 1),)
def test_re_sample_buckets_split_original(self): def test_re_sample_buckets_split_original(self):
...@@ -42,9 +42,9 @@ class TestHistogram: ...@@ -42,9 +42,9 @@ class TestHistogram:
mocked_bucket.width = 1 mocked_bucket.width = 1
mocked_bucket.count = 1 mocked_bucket.count = 1
mocked_input.buckets = [mocked_bucket] mocked_input.buckets = [mocked_bucket]
histogram = hist.HistogramContainer(mocked_input) histogram_container = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=1, min_val=0, bins=3) histogram_container.histogram.set_visual_range(max_val=1, min_val=0, bins=3)
buckets = histogram.buckets() buckets = histogram_container.buckets()
assert buckets == ((0.0, 0.3333333333333333, 1), (0.3333333333333333, 0.3333333333333333, 1), assert buckets == ((0.0, 0.3333333333333333, 1), (0.3333333333333333, 0.3333333333333333, 1),
(0.6666666666666666, 0.3333333333333333, 1)) (0.6666666666666666, 0.3333333333333333, 1))
...@@ -60,9 +60,9 @@ class TestHistogram: ...@@ -60,9 +60,9 @@ class TestHistogram:
mocked_bucket2.width = 1 mocked_bucket2.width = 1
mocked_bucket2.count = 2 mocked_bucket2.count = 2
mocked_input.buckets = [mocked_bucket, mocked_bucket2] mocked_input.buckets = [mocked_bucket, mocked_bucket2]
histogram = hist.HistogramContainer(mocked_input) histogram_container = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=3, min_val=-1, bins=4) histogram_container.histogram.set_visual_range(max_val=3, min_val=-1, bins=4)
buckets = histogram.buckets() buckets = histogram_container.buckets()
assert buckets == ((-1.0, 1.0, 0), (0.0, 1.0, 1), (1.0, 1.0, 2), (2.0, 1.0, 0)) assert buckets == ((-1.0, 1.0, 0), (0.0, 1.0, 1), (1.0, 1.0, 2), (2.0, 1.0, 0))
def test_re_sample_buckets_merge_bucket(self): def test_re_sample_buckets_merge_bucket(self):
...@@ -77,9 +77,9 @@ class TestHistogram: ...@@ -77,9 +77,9 @@ class TestHistogram:
mocked_bucket2.width = 1 mocked_bucket2.width = 1
mocked_bucket2.count = 10 mocked_bucket2.count = 10
mocked_input.buckets = [mocked_bucket, mocked_bucket2] mocked_input.buckets = [mocked_bucket, mocked_bucket2]
histogram = hist.HistogramContainer(mocked_input) histogram_container = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=3, min_val=-1, bins=5) histogram_container.histogram.set_visual_range(max_val=3, min_val=-1, bins=5)
buckets = histogram.buckets() buckets = histogram_container.buckets()
assert buckets == ( assert buckets == (
(-1.0, 0.8, 0), (-0.19999999999999996, 0.8, 1), (0.6000000000000001, 0.8, 5), (1.4000000000000004, 0.8, 6), (-1.0, 0.8, 0), (-0.19999999999999996, 0.8, 1), (0.6000000000000001, 0.8, 5), (1.4000000000000004, 0.8, 6),
(2.2, 0.8, 0)) (2.2, 0.8, 0))
...@@ -96,9 +96,9 @@ class TestHistogram: ...@@ -96,9 +96,9 @@ class TestHistogram:
mocked_bucket2.width = 0 mocked_bucket2.width = 0
mocked_bucket2.count = 2 mocked_bucket2.count = 2
mocked_input.buckets = [mocked_bucket, mocked_bucket2] mocked_input.buckets = [mocked_bucket, mocked_bucket2]
histogram = hist.HistogramContainer(mocked_input) histogram_container = hist.HistogramContainer(mocked_input)
histogram.set_visual_range(max_val=2, min_val=0, bins=3) histogram_container.histogram.set_visual_range(max_val=2, min_val=0, bins=3)
buckets = histogram.buckets() buckets = histogram_container.buckets()
assert buckets == ( assert buckets == (
(0.0, 0.6666666666666666, 1), (0.0, 0.6666666666666666, 1),
(0.6666666666666666, 0.6666666666666666, 3), (0.6666666666666666, 0.6666666666666666, 3),
......
...@@ -69,7 +69,7 @@ class TestTrainTaskManager: ...@@ -69,7 +69,7 @@ class TestTrainTaskManager:
def load_data(self): def load_data(self):
"""Load data.""" """Load data."""
log_operation = LogOperations() log_operation = LogOperations()
self._plugins_id_map = {'image': [], 'scalar': [], 'graph': [], 'histogram': []} self._plugins_id_map = {'image': [], 'scalar': [], 'graph': [], 'histogram': [], 'tensor': []}
self._events_names = [] self._events_names = []
self._train_id_list = [] self._train_id_list = []
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Log generator for tensor data."""
import time
from operator import mul
from functools import reduce
import numpy as np
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from .log_generator import LogGenerator
class TensorLogGenerator(LogGenerator):
    """
    Log generator for tensor data.

    This is a log generator writing tensor data. User can use it to generate fake
    summary logs about tensor.
    """

    def generate_event(self, values):
        """
        Method for generating tensor event.

        Args:
            values (dict): A dict contains:
                {
                    wall_time (float): Timestamp.
                    step (int): Train step.
                    value (dict): Tensor description with the keys
                        `dims`, `data_type` and `float_data`.
                    tag (str): Tag name.
                }

        Returns:
            summary_pb2.Event.
        """
        tensor_event = summary_pb2.Event()
        tensor_event.wall_time = values.get('wall_time')
        tensor_event.step = values.get('step')

        value = tensor_event.summary.value.add()
        value.tag = values.get('tag')

        tensor = values.get('value')
        value.tensor.dims[:] = tensor.get('dims')
        value.tensor.data_type = tensor.get('data_type')
        value.tensor.float_data[:] = tensor.get('float_data')

        return tensor_event

    def generate_log(self, file_path, steps_list, tag_name):
        """
        Generate log for external calls.

        One tensor with four random dimensions (each in [1, 10)) and random
        float data is written per step.

        Args:
            file_path (str): Path to write logs.
            steps_list (list): A list consists of step.
            tag_name (str): Tag name.

        Returns:
            list[dict], generated tensor metadata, one entry per step.
            dict, generated tensors keyed by step.
        """
        tensor_metadata = []
        tensor_values = {}
        for step in steps_list:
            wall_time = time.time()
            dims = list(np.random.randint(1, 10, 4))
            # Total number of elements needed to fill a tensor of shape `dims`.
            element_count = reduce(mul, dims)
            tensor = {
                'wall_time': wall_time,
                'step': step,
                'tag': tag_name,
                'value': {
                    "dims": dims,
                    "data_type": anf_ir_pb2.DataType.DT_FLOAT32,
                    "float_data": np.random.randn(element_count)
                }
            }
            tensor_metadata.append(tensor)
            tensor_values[step] = tensor
            self._write_log_one_step(file_path, tensor)

        return tensor_metadata, tensor_values
if __name__ == "__main__":
    # Manual entry point: write a sample tensor summary log for steps 1, 3 and 5
    # into a timestamped file in the current directory.
    sample_file_name = 'tensor.summary.{}'.format(time.time())
    TensorLogGenerator().generate_log(sample_file_name, [1, 3, 5], "test_tensor_tag_name")
...@@ -25,12 +25,14 @@ from .log_generators.graph_log_generator import GraphLogGenerator ...@@ -25,12 +25,14 @@ from .log_generators.graph_log_generator import GraphLogGenerator
from .log_generators.images_log_generator import ImagesLogGenerator from .log_generators.images_log_generator import ImagesLogGenerator
from .log_generators.scalars_log_generator import ScalarsLogGenerator from .log_generators.scalars_log_generator import ScalarsLogGenerator
from .log_generators.histogram_log_generator import HistogramLogGenerator from .log_generators.histogram_log_generator import HistogramLogGenerator
from .log_generators.tensor_log_generator import TensorLogGenerator
log_generators = { log_generators = {
PluginNameEnum.GRAPH.value: GraphLogGenerator(), PluginNameEnum.GRAPH.value: GraphLogGenerator(),
PluginNameEnum.IMAGE.value: ImagesLogGenerator(), PluginNameEnum.IMAGE.value: ImagesLogGenerator(),
PluginNameEnum.SCALAR.value: ScalarsLogGenerator(), PluginNameEnum.SCALAR.value: ScalarsLogGenerator(),
PluginNameEnum.HISTOGRAM.value: HistogramLogGenerator() PluginNameEnum.HISTOGRAM.value: HistogramLogGenerator(),
PluginNameEnum.TENSOR.value: TensorLogGenerator()
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册