Commit 7068e708 authored by mindspore-ci-bot, committed by Gitee

!1063 Revert 'Pull Request !713 : Use a resident process to write summary files'

Merge pull request !1063 from guozhijian/revert-merge-713-master
@@ -14,77 +14,91 @@
# ============================================================================
"""Writes events to disk in a logdir."""
import os
import time
import stat
from collections import deque
from multiprocessing import Pool, Process, Queue, cpu_count
from mindspore import log as logger
from ..._c_expression import EventWriter_
from ._summary_adapter import package_summary_event
from ._summary_adapter import package_init_event
def _pack(result, step):
summary_event = package_summary_event(result, step)
return summary_event.SerializeToString()
class _WrapEventWriter(EventWriter_):
"""
Wrap the c++ EventWriter object.
Args:
full_file_name (str): Include directory and file name.
"""
def __init__(self, full_file_name):
if full_file_name is not None:
EventWriter_.__init__(self, full_file_name)
class EventWriter(Process):
class EventRecord:
"""
Creates an `EventWriter` and writes events to file.
Creates an `EventFileWriter` and writes events to file.
Args:
filepath (str): Summary event file path and file name.
flush_interval (int): Interval in seconds at which pending events are flushed to disk. Default: 120.
full_file_name (str): Summary event file path and file name.
flush_time (int): Interval in seconds at which pending events are flushed to disk. Default: 120.
"""
def __init__(self, filepath: str, flush_interval: int) -> None:
super().__init__()
with open(filepath, 'w'):
os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR)
self._writer = EventWriter_(filepath)
self._queue = Queue(cpu_count() * 2)
self.start()
def run(self):
with Pool() as pool:
deq = deque()
while True:
while deq and deq[0].ready():
self._writer.Write(deq.popleft().get())
if not self._queue.empty():
action, data = self._queue.get()
if action == 'WRITE':
if not isinstance(data, (str, bytes)):
deq.append(pool.apply_async(_pack, data))
else:
self._writer.Write(data)
elif action == 'FLUSH':
self._writer.Flush()
elif action == 'END':
break
for res in deq:
self._writer.Write(res.get())
self._writer.Shut()
def write(self, data) -> None:
"""
Write the event to file.
Args:
data (Union[str, bytes, Tuple[list, int]]): The data to write.
"""
self._queue.put(('WRITE', data))
def __init__(self, full_file_name: str, flush_time: int = 120):
self.full_file_name = full_file_name
# The first event will be flushed immediately.
self.flush_time = flush_time
self.next_flush_time = 0
# create event write object
self.event_writer = self._create_event_file()
self._init_event_file()
# count the events
self.event_count = 0
def _create_event_file(self):
"""Create the event write file."""
with open(self.full_file_name, 'w'):
os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR)
# create c++ event write object
event_writer = _WrapEventWriter(self.full_file_name)
return event_writer
def _init_event_file(self):
"""Send the init event to file."""
self.event_writer.Write((package_init_event()).SerializeToString())
self.flush()
return True
def write_event_to_file(self, event_str):
"""Write the event to file."""
self.event_writer.Write(event_str)
def get_data_count(self):
"""Return the event count."""
return self.event_count
def flush_cycle(self):
"""Flush file by timer."""
self.event_count = self.event_count + 1
# Flush the event writer every so often.
now = int(time.time())
if now > self.next_flush_time:
self.flush()
# update the flush time
self.next_flush_time = now + self.flush_time
def count_event(self):
"""Count event."""
logger.debug("Write the event count is %r", self.event_count)
self.event_count = self.event_count + 1
return self.event_count
def flush(self):
"""Flush the writer."""
self._queue.put(('FLUSH', None))
def close(self) -> None:
"""Close the writer."""
self._queue.put(('END', None))
self.join()
"""Flush the event file to disk."""
self.event_writer.Flush()
def __del__(self) -> None:
self.close()
def close(self):
"""Flush the event file to disk and close the file."""
self.flush()
self.event_writer.Shut()
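For orientation, a minimal usage sketch of the restored EventRecord (not part of the diff; the path and the summary_event variable are hypothetical):

record = EventRecord("./train_log/events.summary.demo", flush_time=120)  # hypothetical path
record.write_event_to_file(summary_event.SerializeToString())  # summary_event: an Event proto built elsewhere
record.flush_cycle()  # count the event and flush once the flush interval has elapsed
record.close()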
@@ -13,17 +13,17 @@
# limitations under the License.
# ============================================================================
"""Generate the summary event which conform to proto format."""
import socket
import time
import socket
import math
from enum import Enum, unique
import numpy as np
from PIL import Image
from mindspore import log as logger
from ..._checkparam import _check_str_by_regular
from ..anf_ir_pb2 import DataType, ModelProto
from ..summary_pb2 import Event
from ..anf_ir_pb2 import ModelProto, DataType
from ..._checkparam import _check_str_by_regular
# define the MindSpore image format
MS_IMAGE_TENSOR_FORMAT = 'NCHW'
@@ -32,6 +32,55 @@ EVENT_FILE_NAME_MARK = ".out.events.summary."
# Set the init event of version and mark
EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:"
EVENT_FILE_INIT_VERSION = 1
# cache the summary data dict
# {id: SummaryData}
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_dict = {}
def save_summary_data(data_id, data):
"""Save the global summary cache."""
global g_summary_data_dict
g_summary_data_dict[data_id] = data
def del_summary_data(data_id):
"""Save the global summary cache."""
global g_summary_data_dict
if data_id in g_summary_data_dict:
del g_summary_data_dict[data_id]
else:
logger.warning("Can't del the data because data_id(%r) "
"does not have data in g_summary_data_dict", data_id)
def get_summary_data(data_id):
"""Save the global summary cache."""
ret = None
global g_summary_data_dict
if data_id in g_summary_data_dict:
ret = g_summary_data_dict.get(data_id)
else:
logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id)
return ret
@unique
class SummaryType(Enum):
"""
Summary type.
Args:
SCALAR (Number): Summary Scalar enum.
TENSOR (Number): Summary TENSOR enum.
IMAGE (Number): Summary image enum.
GRAPH (Number): Summary graph enum.
HISTOGRAM (Number): Summary histogram enum.
INVALID (Number): Unknown type.
"""
SCALAR = 1 # Scalar summary
TENSOR = 2 # Tensor summary
IMAGE = 3 # Image summary
GRAPH = 4 # graph
HISTOGRAM = 5 # Histogram Summary
INVALID = 0xFF # unknown type
def get_event_file_name(prefix, suffix):
@@ -89,7 +138,7 @@ def package_graph_event(data):
return graph_event
def package_summary_event(data_list, step):
def package_summary_event(data_id, step):
"""
Package the summary to event protobuffer.
@@ -100,37 +149,50 @@ def package_summary_event(data_list, step):
Returns:
Summary, the summary event.
"""
data_list = get_summary_data(data_id)
if data_list is None:
logger.error("The step(%r) does not have record data.", step)
del_summary_data(data_id)
# create the event of summary
summary_event = Event()
summary = summary_event.summary
summary_event.wall_time = time.time()
summary_event.step = int(step)
for value in data_list:
summary_type = value["_type"]
data = value["data"]
tag = value["name"]
data = value["data"]
summary_type = value["type"]
logger.debug("Now process %r summary, tag = %r", summary_type, tag)
summary_value = summary.value.add()
summary_value.tag = tag
# get the summary type and parse the tag
if summary_type == 'Scalar':
if summary_type is SummaryType.SCALAR:
logger.debug("Now process Scalar summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
summary_value.scalar_value = _get_scalar_summary(tag, data)
elif summary_type == 'Tensor':
elif summary_type is SummaryType.TENSOR:
logger.debug("Now process Tensor summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
summary_tensor = summary_value.tensor
_get_tensor_summary(tag, data, summary_tensor)
elif summary_type == 'Image':
elif summary_type is SummaryType.IMAGE:
logger.debug("Now process Image summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
summary_image = summary_value.image
_get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT)
elif summary_type == 'Histogram':
elif summary_type is SummaryType.HISTOGRAM:
logger.debug("Now process Histogram summary, tag = %r", tag)
summary_value = summary.value.add()
summary_value.tag = tag
summary_histogram = summary_value.histogram
_fill_histogram_summary(tag, data, summary_histogram)
else:
# The data is invalid, so skip it
logger.error("Summary type(%r) is error, tag = %r", summary_type, tag)
logger.error("Summary type is error, tag = %r", tag)
continue
summary_event.wall_time = time.time()
summary_event.step = int(step)
return summary_event
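A short sketch of how the restored cache-based flow drives this function (the tag and values are illustrative, not from the commit):

import numpy as np
save_summary_data(0, [{"name": "loss", "data": np.array([0.25]), "type": SummaryType.SCALAR}])
event = package_summary_event(data_id=0, step=1)  # reads entry 0 from the cache, then deletes it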
@@ -193,11 +255,11 @@ def _get_scalar_summary(tag: str, np_value):
# So treat a tensor with dim = 1 and shape (1,) as a scalar
scalar_value = np_value[0]
if np_value.shape != (1,):
logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape)
logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value)
else:
np_list = np_value.reshape(-1).tolist()
scalar_value = np_list[0]
logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim)
logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value)
logger.debug("The tag(%r) value is: %r", tag, scalar_value)
return scalar_value
@@ -245,7 +307,8 @@ def _calc_histogram_bins(count):
Returns:
int, number of histogram bins.
"""
max_bins, max_per_bin = 90, 10
number_per_bucket = 10
max_bins = 90
if not count:
return 1
@@ -255,50 +318,78 @@ def _calc_histogram_bins(count):
return 3
if count <= 880:
# note that math.ceil(881/10) + 1 equals 90
return count // max_per_bin + 1
return int(math.ceil(count / number_per_bucket) + 1)
return max_bins
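A quick worked check of the restored bin rule (illustrative values, not from the commit):

import math
number_per_bucket, max_bins = 10, 90
assert int(math.ceil(880 / number_per_bucket) + 1) == 89  # largest count on the <= 880 branch
# any count above 880 falls through to max_bins, so the result never exceeds 90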
def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None:
def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None:
"""
Package the histogram summary.
Args:
tag (str): Summary tag describe.
np_value (np.ndarray): Summary data.
summary (summary_pb2.Summary.Histogram): Summary histogram data.
np_value (np.array): Summary data.
summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data.
"""
logger.debug("Set(%r) the histogram summary value", tag)
# Default bucket for tensor with no valid data.
ma_value = np.ma.masked_invalid(np_value)
total, valid = np_value.size, ma_value.count()
invalids = []
for isfn in np.isnan, np.isposinf, np.isneginf:
if total - valid > sum(invalids):
count = np.count_nonzero(isfn(np_value))
invalids.append(count)
else:
invalids.append(0)
default_bucket_left = -0.5
default_bucket_width = 1.0
summary.count = total
summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids
if not valid:
logger.warning('There are no valid values in the ndarray(size=%d, shape=%s)', total, np_value.shape)
# summary.{min, max, sum} are 0s by default, no need to explicitly set
else:
summary.min = ma_value.min()
summary.max = ma_value.max()
summary.sum = ma_value.sum()
bins = _calc_histogram_bins(valid)
range_ = summary.min, summary.max
hists, edges = np.histogram(np_value, bins=bins, range=range_)
if np_value.size == 0:
bucket = summary_histogram.buckets.add()
bucket.left = default_bucket_left
bucket.width = default_bucket_width
bucket.count = 0
summary_histogram.nan_count = 0
summary_histogram.pos_inf_count = 0
summary_histogram.neg_inf_count = 0
summary_histogram.max = 0
summary_histogram.min = 0
summary_histogram.sum = 0
summary_histogram.count = 0
return
summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value))
summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value))
summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value))
summary_histogram.count = np_value.size
masked_value = np.ma.masked_invalid(np_value)
tensor_max = masked_value.max()
tensor_min = masked_value.min()
tensor_sum = masked_value.sum()
# No valid value in tensor.
if tensor_max is np.ma.masked:
bucket = summary_histogram.buckets.add()
bucket.left = default_bucket_left
bucket.width = default_bucket_width
bucket.count = 0
summary_histogram.max = np.nan
summary_histogram.min = np.nan
summary_histogram.sum = 0
return
bin_number = _calc_histogram_bins(masked_value.count())
counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max))
for ind, count in enumerate(counts):
bucket = summary_histogram.buckets.add()
bucket.left = edges[ind]
bucket.width = edges[ind + 1] - edges[ind]
bucket.count = count
for hist, edge1, edge2 in zip(hists, edges, edges[1:]):
bucket = summary.buckets.add()
bucket.width = edge2 - edge1
bucket.count = hist
bucket.left = edge1
summary_histogram.max = tensor_max
summary_histogram.min = tensor_min
summary_histogram.sum = tensor_sum
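The bucket loop above copies numpy's histogram output field by field; a self-contained sketch with made-up data:

import numpy as np
values = np.array([0.1, 0.35, 0.4, 0.8])
counts, edges = np.histogram(values, bins=3, range=(values.min(), values.max()))
# bucket i receives left=edges[i], width=edges[i + 1] - edges[i], count=counts[i]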
def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
@@ -316,7 +407,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'):
"""
logger.debug("Set(%r) the image summary value", tag)
if np_value.ndim != 4:
logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim)
logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value)
# convert the tensor format
tensor = _convert_image_format(np_value, input_format)
@@ -378,8 +469,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'):
"""
out_tensor = None
if np_tensor.ndim != len(input_format):
logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim,
input_format)
logger.error("The tensor(%r) can't convert the format(%r) because dim not same",
np_tensor, input_format)
return out_tensor
input_format = input_format.upper()
@@ -421,7 +512,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8):
# check the tensor format
if tensor.ndim != 4 or tensor.shape[1] != 3:
logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape)
logger.error("The image tensor(%r) is not 'NCHW' format", tensor)
return out_canvas
# expand the N
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Schedule the event writer process."""
import multiprocessing as mp
from enum import Enum, unique
from mindspore import log as logger
from ..._c_expression import Tensor
from ._summary_adapter import SummaryType, package_summary_event, save_summary_data
# define the type of summary
FORMAT_SCALAR_STR = "Scalar"
FORMAT_TENSOR_STR = "Tensor"
FORMAT_IMAGE_STR = "Image"
FORMAT_HISTOGRAM_STR = "Histogram"
FORMAT_BEGIN_SLICE = "[:"
FORMAT_END_SLICE = "]"
# cache the summary data dict
# {id: SummaryData}
# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...]
g_summary_data_id = 0
g_summary_data_dict = {}
# cache the summary data file
g_summary_writer_id = 0
g_summary_file = {}
@unique
class ScheduleMethod(Enum):
"""Schedule method type."""
FORMAL_WORKER = 0 # use the formal worker that receives small data via the queue
TEMP_WORKER = 1 # use the temp worker that receives large data via the global cache (avoids a copy)
CACHE_DATA = 2 # cache data until an idle worker can process it
@unique
class WorkerStatus(Enum):
"""Worker status."""
WORKER_INIT = 0 # data exists but has not been processed
WORKER_PROCESSING = 1 # data is being processed
WORKER_PROCESSED = 2 # data has been processed
def _parse_tag_format(tag: str):
"""
Parse the tag.
Args:
tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor].
Returns:
Tuple, (SummaryType, summary_tag).
"""
summary_type = SummaryType.INVALID
summary_tag = tag
if tag is None:
logger.error("The tag is None")
return summary_type, summary_tag
# search the slice
slice_begin = FORMAT_BEGIN_SLICE
slice_end = FORMAT_END_SLICE
index = tag.rfind(slice_begin)
if index == -1:
logger.error("The tag(%s) does not contain the type slice.", tag)
return summary_type, summary_tag
# slice the tag
summary_tag = tag[:index]
# check the slice end
if tag[-1:] != slice_end:
logger.error("The tag(%s) end format is error", tag)
return summary_type, summary_tag
# check the type
type_str = tag[index + 2: -1]
logger.debug("The summary_tag is = %r", summary_tag)
logger.debug("The type_str value is = %r", type_str)
if type_str == FORMAT_SCALAR_STR:
summary_type = SummaryType.SCALAR
elif type_str == FORMAT_TENSOR_STR:
summary_type = SummaryType.TENSOR
elif type_str == FORMAT_IMAGE_STR:
summary_type = SummaryType.IMAGE
elif type_str == FORMAT_HISTOGRAM_STR:
summary_type = SummaryType.HISTOGRAM
else:
logger.error("The tag(%s) type is invalid.", tag)
summary_type = SummaryType.INVALID
return summary_type, summary_tag
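Illustrative calls against the parser above (hypothetical tags, not from the commit):

assert _parse_tag_format("loss[:Scalar]") == (SummaryType.SCALAR, "loss")
assert _parse_tag_format("conv1.weight[:Histogram]") == (SummaryType.HISTOGRAM, "conv1.weight")
assert _parse_tag_format("no_suffix")[0] is SummaryType.INVALID  # logged as an error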
class SummaryDataManager:
"""Manage the summary global data cache."""
def __init__(self):
global g_summary_data_dict
self.size = len(g_summary_data_dict)
@classmethod
def summary_data_save(cls, data):
"""Save the global summary cache."""
global g_summary_data_id
data_id = g_summary_data_id
save_summary_data(data_id, data)
g_summary_data_id += 1
return data_id
@classmethod
def summary_file_set(cls, event_writer):
"""Support the many event_writer."""
global g_summary_file, g_summary_writer_id
g_summary_writer_id += 1
g_summary_file[g_summary_writer_id] = event_writer
return g_summary_writer_id
@classmethod
def summary_file_get(cls, writer_id=1):
ret = None
global g_summary_file
if writer_id in g_summary_file:
ret = g_summary_file.get(writer_id)
return ret
class WorkerScheduler:
"""
Create workers and schedule data to them.
Args:
writer_id (int): The index of writer.
"""
def __init__(self, writer_id):
# Create the process of write event file
self.write_lock = mp.Lock()
# Schedule info for all worker
# Format: {worker: (step, WorkerStatus)}
self.schedule_table = {}
# write id
self.writer_id = writer_id
self.has_graph = False
def dispatch(self, step, data):
"""
Select schedule strategy and dispatch data.
Args:
step (Number): The number of step index.
data (Object): The summary record data.
Returns:
bool, whether the dispatch succeeded.
"""
# save the data to the global cache and convert the tensors to numpy
result, size, data = self._data_convert(data)
if result is False:
logger.error("The step(%r) summary data(%r) is invalid.", step, size)
return False
data_id = SummaryDataManager.summary_data_save(data)
self._start_worker(step, data_id)
return True
def _start_worker(self, step, data_id):
"""
Start worker.
Args:
step (Number): The index of the record.
data_id (int): The id of the cached summary data.
Returns:
bool, whether the worker started successfully.
"""
# assign the worker
policy = self._make_policy()
if policy == ScheduleMethod.TEMP_WORKER:
worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id)
# update the schedule table
self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT)
# start the worker
worker.start()
else:
logger.error("Do not support the other scheduler policy now.")
# update the scheduler info
self._update_scheduler()
return True
def _data_convert(self, data_list):
"""Convert the data."""
if data_list is None:
logger.warning("The step does not have record data.")
return False, 0, None
# convert the summary to numpy
size = 0
for v_dict in data_list:
tag = v_dict["name"]
data = v_dict["data"]
# confirm the data is valid
summary_type, summary_tag = _parse_tag_format(tag)
if summary_type == SummaryType.INVALID:
logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
return False, 0, None
if isinstance(data, Tensor):
# get the summary type and parse the tag
v_dict["name"] = summary_tag
v_dict["type"] = summary_type
v_dict["data"] = data.asnumpy()
size += v_dict["data"].size
else:
logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data)
return False, 0, None
return True, size, data_list
def _update_scheduler(self):
"""Check the worker status and update schedule table."""
workers = list(self.schedule_table.keys())
for worker in workers:
if not worker.is_alive():
# update the table
worker.join()
del self.schedule_table[worker]
def close(self):
"""Confirm all worker is end."""
workers = self.schedule_table.keys()
for worker in workers:
if worker.is_alive():
worker.join()
def _make_policy(self):
"""Select the schedule strategy by data."""
# now only support the temp worker
return ScheduleMethod.TEMP_WORKER
class SummaryDataProcess(mp.Process):
"""
Process that consumes the summary data.
Args:
step (int): The index of step.
data_id (int): The index of summary data.
write_lock (Lock): The process lock for writing to the same file.
writer_id (int): The index of writer.
"""
def __init__(self, step, data_id, write_lock, writer_id):
super(SummaryDataProcess, self).__init__()
self.daemon = True
self.writer_id = writer_id
self.writer = SummaryDataManager.summary_file_get(self.writer_id)
if self.writer is None:
logger.error("The writer_id(%r) does not have writer", writer_id)
self.step = step
self.data_id = data_id
self.write_lock = write_lock
self.name = "SummaryDataConsumer_" + str(self.step)
def run(self):
"""The consumer is process the step data and exit."""
# convert the data to event
# All exceptions need to be caught so the process ends cleanly
try:
logger.debug("process(%r) process a data(%r)", self.name, self.step)
# package the summary event
summary_event = package_summary_event(self.data_id, self.step)
# send the event to file
self._write_summary(summary_event)
except Exception as e:
logger.error("Summary data mq consumer exception occurred, value = %r", e)
def _write_summary(self, summary_event):
"""
Write the summary to event file.
Note:
The write record format:
1 uint64 : data length.
2 uint32 : mask crc value of data length.
3 bytes : data.
4 uint32 : mask crc value of data.
Args:
summary_event (Event): The summary event of proto.
"""
event_str = summary_event.SerializeToString()
self.write_lock.acquire()
self.writer.write_event_to_file(event_str)
self.writer.flush()
self.write_lock.release()
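The four-field layout documented above can be sketched in pure Python; _masked_crc32 below is a hypothetical stand-in (the real masking lives in the C++ EventWriter_ and uses crc32c rather than zlib's crc32):

import struct
import zlib

def _masked_crc32(data: bytes) -> int:
    # Hypothetical stand-in for the writer's masked CRC, for illustration only.
    crc = zlib.crc32(data) & 0xFFFFFFFF
    rot = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rot + 0xa282ead8) & 0xFFFFFFFF

def _serialize_record(data: bytes) -> bytes:
    header = struct.pack('<Q', len(data))               # 1. uint64: data length
    return (header
            + struct.pack('<I', _masked_crc32(header))  # 2. uint32: masked crc of length
            + data                                      # 3. bytes: data
            + struct.pack('<I', _masked_crc32(data)))   # 4. uint32: masked crc of data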
@@ -14,22 +14,17 @@
# ============================================================================
"""Record the summary event."""
import os
import re
import threading
from mindspore import log as logger
from ..._c_expression import Tensor
from ..._checkparam import _check_str_by_regular
from ._summary_scheduler import WorkerScheduler, SummaryDataManager
from ._summary_adapter import get_event_file_name, package_graph_event
from ._event_writer import EventRecord
from .._utils import _make_directory
from ._event_writer import EventWriter
from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event
from ..._checkparam import _check_str_by_regular
# for the moment, this lock is for caution's sake,
# there is actually no concurrency happening.
_summary_lock = threading.Lock()
# cache the summary data
_summary_tensor_cache = {}
_summary_lock = threading.Lock()
def _cache_summary_tensor_data(summary):
@@ -39,18 +34,14 @@ def _cache_summary_tensor_data(summary):
Args:
summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
"""
with _summary_lock:
for item in summary:
_summary_tensor_cache[item['name']] = item['data']
return True
def _get_summary_tensor_data():
global _summary_tensor_cache
with _summary_lock:
data = _summary_tensor_cache
_summary_tensor_cache = {}
return data
_summary_lock.acquire()
if "SummaryRecord" in _summary_tensor_cache:
for record in summary:
_summary_tensor_cache["SummaryRecord"].append(record)
else:
_summary_tensor_cache["SummaryRecord"] = summary
_summary_lock.release()
return True
class SummaryRecord:
@@ -80,7 +71,6 @@ class SummaryRecord:
>>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
>>> file_prefix="xxx_", file_suffix="_yyy")
"""
def __init__(self,
log_dir,
queue_max_size=0,
@@ -111,18 +101,26 @@
self.prefix = file_prefix
self.suffix = file_suffix
self.network = network
self.has_graph = False
self._closed = False
# create the summary writer file
self.event_file_name = get_event_file_name(self.prefix, self.suffix)
if self.log_path[-1:] == '/':
self.full_file_name = self.log_path + self.event_file_name
else:
self.full_file_name = self.log_path + '/' + self.event_file_name
try:
self.full_file_name = os.path.join(self.log_path, self.event_file_name)
self.full_file_name = os.path.realpath(self.full_file_name)
except Exception as ex:
raise RuntimeError(ex)
self.event_writer = EventWriter(self.full_file_name, self.flush_time)
self.event_writer.write(package_init_event().SerializeToString())
self.event_writer = EventRecord(self.full_file_name, self.flush_time)
self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
self.worker_scheduler = WorkerScheduler(self.writer_id)
self.step = 0
self._closed = False
self.network = network
self.has_graph = False
def record(self, step, train_network=None):
"""
@@ -147,34 +145,42 @@
if not isinstance(step, int) or isinstance(step, bool):
raise ValueError("`step` should be int")
# Set the current summary of train step
self.step = step
if self.network is not None and not self.has_graph:
if self.network is not None and self.has_graph is False:
graph_proto = self.network.get_func_graph_proto()
if graph_proto is None and train_network is not None:
graph_proto = train_network.get_func_graph_proto()
if graph_proto is None:
logger.error("Failed to get proto for graph")
else:
self.event_writer.write(package_graph_event(graph_proto).SerializeToString())
self.event_writer.write_event_to_file(
package_graph_event(graph_proto).SerializeToString())
self.event_writer.flush()
self.has_graph = True
if not _summary_tensor_cache:
data = _summary_tensor_cache.get("SummaryRecord")
if data is None:
return True
data = _get_summary_tensor_data()
if not data:
logger.error("The step(%r) does not have record data.", step)
data = _summary_tensor_cache.get("SummaryRecord")
if data is None:
logger.error("The step(%r) does not have record data.", self.step)
return False
if self.queue_max_size > 0 and len(data) > self.queue_max_size:
logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data),
self.queue_max_size)
# clean the data of cache
del _summary_tensor_cache["SummaryRecord"]
# process the data
result = self._data_convert(data)
if not result:
logger.error("The step(%r) summary data is invalid.", step)
return False
self.event_writer.write((result, step))
logger.debug("Send the summary data to scheduler for saving, step = %d", step)
self.worker_scheduler.dispatch(self.step, data)
# count & flush
self.event_writer.count_event()
self.event_writer.flush_cycle()
logger.debug("Send the summary data to scheduler for saving, step = %d", self.step)
return True
@property
@@ -190,7 +196,7 @@
Returns:
String, the full path of log file.
"""
return self.full_file_name
return self.event_writer.full_file_name
def flush(self):
"""
@@ -218,44 +224,20 @@
>>> summary_record.close()
"""
if not self._closed:
self._check_data_before_close()
self.worker_scheduler.close()
# event writer flush and close
self.event_writer.close()
self._closed = True
def _data_convert(self, summary):
"""Convert the data."""
# convert the summary to numpy
result = []
for name, data in summary.items():
# confirm the data is valid
summary_tag, summary_type = SummaryRecord._parse_from(name)
if summary_tag is None:
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
return None
if isinstance(data, Tensor):
result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type})
else:
logger.error("The data type is invalid, name = %r, tensor = %r", name, data)
return None
return result
@staticmethod
def _parse_from(name: str = None):
"""
Parse the tag and type from name.
Args:
name (str): Format: TAG[:TYPE].
Returns:
Tuple, (summary_tag, summary_type).
"""
if name is None:
logger.error("The name is None")
return None, None
match = re.match(r'(.+)\[:(.+)\]', name)
if match:
return match.groups()
logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name)
return None, None
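Illustrative calls against the removed helper (hypothetical names, not from the commit):

assert SummaryRecord._parse_from("loss[:Scalar]") == ("loss", "Scalar")
assert SummaryRecord._parse_from("no_type_suffix") == (None, None)  # logged as an error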
def __del__(self):
"""Process exit is called."""
if hasattr(self, "worker_scheduler"):
if self.worker_scheduler:
self.close()
def _check_data_before_close(self):
"Check whether there is any data in the cache, and if so, call record"
data = _summary_tensor_cache.get("SummaryRecord")
if data is not None:
self.record(self.step)