提交 18580a78 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!261 Add summary module python code to extract histogram info from tensor

Merge pull request !261 from wenkai/histogram_dev3cp
......@@ -61,6 +61,30 @@ message Summary {
required bytes encoded_image = 4;
}
message Histogram {
message bucket{
// Count number of values fallen in [left, left + width).
// For the right most bucket, range is [left, left + width].
required double left = 1;
required double width = 2;
required int64 count = 3;
}
repeated bucket buckets = 1;
optional int64 nan_count = 2;
optional int64 pos_inf_count = 3;
optional int64 neg_inf_count = 4;
// max, min, sum will not take nan and inf into account.
// If there is no valid value in tensor, max will be nan, min will be nan, sum will be 0.
optional double max = 5;
optional double min = 6;
optional double sum = 7;
// total number of values, including nan and inf
optional int64 count = 8;
}
message Value {
// Tag name for the data.
required string tag = 1;
......@@ -70,6 +94,7 @@ message Summary {
float scalar_value = 3;
Image image = 4;
TensorProto tensor = 8;
Histogram histogram = 9;
}
}
......
......@@ -71,12 +71,14 @@ class SummaryType(Enum):
TENSOR (Number): Summary TENSOR enum.
IMAGE (Number): Summary image enum.
GRAPH (Number): Summary graph enum.
HISTOGRAM (Number): Summary histogram enum.
INVALID (Number): Unknow type.
"""
SCALAR = 1 # Scalar summary
TENSOR = 2 # Tensor summary
IMAGE = 3 # Image summary
GRAPH = 4 # graph
HISTOGRAM = 5 # Histogram Summary
INVALID = 0xFF # unknow type
......@@ -148,7 +150,7 @@ def package_summary_event(data_id, step):
"""
data_list = get_summary_data(data_id)
if data_list is None:
logger.error("The step(%r) does not have record data.", self.step)
logger.error("The step(%r) does not have record data.", step)
del_summary_data(data_id)
# create the event of summary
summary_event = Event()
......@@ -177,6 +179,12 @@ def package_summary_event(data_id, step):
summary_value.tag = tag
summary_image = summary_value.image
_get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
elif summary_type is SummaryType.HISTOGRAM:
logger.debug("Now process Histogram summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
summary_histogram = summary_value.histogram
_fill_histogram_summary(tag, data, summary_histogram)
else:
# The data is invalid ,jump the data
logger.error("Summary type is error, tag = %r", tag)
......@@ -284,6 +292,74 @@ def _get_tensor_summary(tag: str, np_value, summary_tensor):
return summary_tensor
def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
"""
Package the histogram summary.
Args:
tag (str): Summary tag describe.
np_value (np.array): Summary data.
summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data.
"""
logger.debug("Set(%r) the histogram summary value", tag)
# Default bucket for tensor with no valid data.
default_bucket_left = -0.5
default_bucket_width = 1.0
if np_value.size == 0:
bucket = summary_histogram.buckets.add()
bucket.left = default_bucket_left
bucket.width = default_bucket_width
bucket.count = 0
summary_histogram.nan_count = 0
summary_histogram.pos_inf_count = 0
summary_histogram.neg_inf_count = 0
summary_histogram.max = 0
summary_histogram.min = 0
summary_histogram.sum = 0
summary_histogram.count = 0
return
summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
summary_histogram.count = np_value.size
masked_value = np.ma.masked_invalid(np_value)
tensor_max = masked_value.max()
tensor_min = masked_value.min()
tensor_sum = masked_value.sum()
# No valid value in tensor.
if tensor_max is np.ma.masked:
bucket = summary_histogram.buckets.add()
bucket.left = default_bucket_left
bucket.width = default_bucket_width
bucket.count = 0
summary_histogram.max = np.nan
summary_histogram.min = np.nan
summary_histogram.sum = 0
return
counts, edges = np.histogram(np_value, bins='auto', range=(tensor_min, tensor_max))
for ind, count in enumerate(counts):
bucket = summary_histogram.buckets.add()
bucket.left = edges[ind]
bucket.width = edges[ind + 1] - edges[ind]
bucket.count = count
summary_histogram.max = tensor_max
summary_histogram.min = tensor_min
summary_histogram.sum = tensor_sum
def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
"""
Package the image summary.
......
......@@ -23,6 +23,7 @@ from ._summary_adapter import SummaryType, package_summary_event, save_summary_d
FORMAT_SCALAR_STR = "Scalar"
FORMAT_TENSOR_STR = "Tensor"
FORMAT_IMAGE_STR = "Image"
FORMAT_HISTOGRAM_STR = "Histogram"
FORMAT_BEGIN_SLICE = "[:"
FORMAT_END_SLICE = "]"
......@@ -95,6 +96,8 @@ def _parse_tag_format(tag: str):
summary_type = SummaryType.TENSOR
elif type_str == FORMAT_IMAGE_STR:
summary_type = SummaryType.IMAGE
elif type_str == FORMAT_HISTOGRAM_STR:
summary_type = SummaryType.HISTOGRAM
else:
logger.error("The tag(%s) type is invalid.", tag)
summary_type = SummaryType.INVALID
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Summary reader."""
import struct
import mindspore.train.summary_pb2 as summary_pb2
_HEADER_SIZE = 8
_HEADER_CRC_SIZE = 4
_DATA_CRC_SIZE = 4
class SummaryReader:
"""Read events from summary file."""
def __init__(self, file_name):
self._file_name = file_name
self._file_handler = open(self._file_name, "rb")
# skip version event
self.read_event()
def read_event(self):
"""Read next event."""
file_handler = self._file_handler
header = file_handler.read(_HEADER_SIZE)
data_len = struct.unpack('Q', header)[0]
file_handler.read(_HEADER_CRC_SIZE)
event_str = file_handler.read(data_len)
file_handler.read(_DATA_CRC_SIZE)
summary_event = summary_pb2.Event.FromString(event_str)
return summary_event
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test histogram summary."""
import logging
import os
import tempfile
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.train.summary.summary_record import SummaryRecord, _cache_summary_tensor_data
from .summary_reader import SummaryReader
CUR_DIR = os.getcwd()
SUMMARY_DIR = os.path.join(CUR_DIR, "/test_temp_summary_event_file/")
LOG = logging.getLogger("test")
LOG.setLevel(level=logging.ERROR)
def _wrap_test_data(input_data: Tensor):
"""
Wraps test data to summary format.
Args:
input_data (Tensor): Input data.
Returns:
dict, the wrapped data.
"""
return [{
"name": "test_data[:Histogram]",
"data": input_data
}]
def test_histogram_summary():
"""Test histogram summary."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
test_data = _wrap_test_data(Tensor([[1, 2, 3], [4, 5, 6]]))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
assert event.summary.value[0].histogram.count == 6
def test_histogram_multi_summary():
"""Test histogram multiple step."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
rng = np.random.RandomState(10)
size = 50
num_step = 5
for i in range(num_step):
arr = rng.normal(size=size)
test_data = _wrap_test_data(Tensor(arr))
_cache_summary_tensor_data(test_data)
test_writer.record(step=i)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
for _ in range(num_step):
event = reader.read_event()
assert event.summary.value[0].histogram.count == size
def test_histogram_summary_scalar_tensor():
"""Test histogram summary, input is a scalar tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
test_data = _wrap_test_data(Tensor(1))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
assert event.summary.value[0].histogram.count == 1
def test_histogram_summary_empty_tensor():
"""Test histogram summary, input is an empty tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
test_data = _wrap_test_data(Tensor([]))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
assert event.summary.value[0].histogram.count == 0
def test_histogram_summary_same_value():
"""Test histogram summary, input is an ones tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
dim1 = 100
dim2 = 100
test_data = _wrap_test_data(Tensor(np.ones([dim1, dim2])))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
LOG.debug(event)
assert len(event.summary.value[0].histogram.buckets) == 1
def test_histogram_summary_high_dims():
"""Test histogram summary, input is a 4-dimension tensor."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
dim = 10
rng = np.random.RandomState(0)
tensor_data = rng.normal(size=[dim, dim, dim, dim])
test_data = _wrap_test_data(Tensor(tensor_data))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
LOG.debug(event)
assert event.summary.value[0].histogram.count == tensor_data.size
def test_histogram_summary_nan_inf():
"""Test histogram summary, input tensor has nan."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
dim1 = 100
dim2 = 100
arr = np.ones([dim1, dim2])
arr[0][0] = np.nan
arr[0][1] = np.inf
arr[0][2] = -np.inf
test_data = _wrap_test_data(Tensor(arr))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
LOG.debug(event)
assert event.summary.value[0].histogram.nan_count == 1
def test_histogram_summary_all_nan_inf():
"""Test histogram summary, input tensor has no valid number."""
with tempfile.TemporaryDirectory() as tmp_dir:
test_writer = SummaryRecord(tmp_dir, file_suffix="_MS_HISTOGRAM")
test_data = _wrap_test_data(Tensor(np.array([np.nan, np.nan, np.nan, np.inf, -np.inf])))
_cache_summary_tensor_data(test_data)
test_writer.record(step=1)
test_writer.close()
file_name = os.path.join(tmp_dir, test_writer.event_file_name)
reader = SummaryReader(file_name)
event = reader.read_event()
LOG.debug(event)
histogram = event.summary.value[0].histogram
assert histogram.nan_count == 3
assert histogram.pos_inf_count == 1
assert histogram.neg_inf_count == 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册