diff --git a/mindspore/train/summary/_event_writer.py b/mindspore/train/summary/_event_writer.py index c04308dcbc4e6cc86cc1a44026870ab5579ee0d1..fa3b09f5ee0e3d8d8c27edcc29157258a1496291 100644 --- a/mindspore/train/summary/_event_writer.py +++ b/mindspore/train/summary/_event_writer.py @@ -14,91 +14,77 @@ # ============================================================================ """Writes events to disk in a logdir.""" import os -import time import stat -from mindspore import log as logger +from collections import deque +from multiprocessing import Pool, Process, Queue, cpu_count + from ..._c_expression import EventWriter_ -from ._summary_adapter import package_init_event +from ._summary_adapter import package_summary_event -class _WrapEventWriter(EventWriter_): - """ - Wrap the c++ EventWriter object. +def _pack(result, step): + summary_event = package_summary_event(result, step) + return summary_event.SerializeToString() - Args: - full_file_name (str): Include directory and file name. - """ - def __init__(self, full_file_name): - if full_file_name is not None: - EventWriter_.__init__(self, full_file_name) - -class EventRecord: +class EventWriter(Process): """ - Creates a `EventFileWriter` and write event to file. + Creates a `EventWriter` and write event to file. Args: - full_file_name (str): Summary event file path and file name. - flush_time (int): The flush seconds to flush the pending events to disk. Default: 120. + filepath (str): Summary event file path and file name. + flush_interval (int): The flush seconds to flush the pending events to disk. Default: 120. """ - def __init__(self, full_file_name: str, flush_time: int = 120): - self.full_file_name = full_file_name - - # The first event will be flushed immediately. - self.flush_time = flush_time - self.next_flush_time = 0 - - # create event write object - self.event_writer = self._create_event_file() - self._init_event_file() - - # count the events - self.event_count = 0 - - def _create_event_file(self): - """Create the event write file.""" - with open(self.full_file_name, 'w'): - os.chmod(self.full_file_name, stat.S_IWUSR | stat.S_IRUSR) - - # create c++ event write object - event_writer = _WrapEventWriter(self.full_file_name) - return event_writer - - def _init_event_file(self): - """Send the init event to file.""" - self.event_writer.Write((package_init_event()).SerializeToString()) - self.flush() - return True - - def write_event_to_file(self, event_str): - """Write the event to file.""" - self.event_writer.Write(event_str) - - def get_data_count(self): - """Return the event count.""" - return self.event_count - - def flush_cycle(self): - """Flush file by timer.""" - self.event_count = self.event_count + 1 - # Flush the event writer every so often. - now = int(time.time()) - if now > self.next_flush_time: - self.flush() - # update the flush time - self.next_flush_time = now + self.flush_time - - def count_event(self): - """Count event.""" - logger.debug("Write the event count is %r", self.event_count) - self.event_count = self.event_count + 1 - return self.event_count + + def __init__(self, filepath: str, flush_interval: int) -> None: + super().__init__() + with open(filepath, 'w'): + os.chmod(filepath, stat.S_IWUSR | stat.S_IRUSR) + self._writer = EventWriter_(filepath) + self._queue = Queue(cpu_count() * 2) + self.start() + + def run(self): + + with Pool() as pool: + deq = deque() + while True: + while deq and deq[0].ready(): + self._writer.Write(deq.popleft().get()) + + if not self._queue.empty(): + action, data = self._queue.get() + if action == 'WRITE': + if not isinstance(data, (str, bytes)): + deq.append(pool.apply_async(_pack, data)) + else: + self._writer.Write(data) + elif action == 'FLUSH': + self._writer.Flush() + elif action == 'END': + break + for res in deq: + self._writer.Write(res.get()) + + self._writer.Shut() + + def write(self, data) -> None: + """ + Write the event to file. + + Args: + data (Optional[str, Tuple[list, int]]): The data to write. + """ + self._queue.put(('WRITE', data)) def flush(self): - """Flush the event file to disk.""" - self.event_writer.Flush() + """Flush the writer.""" + self._queue.put(('FLUSH', None)) + + def close(self) -> None: + """Close the writer.""" + self._queue.put(('END', None)) + self.join() - def close(self): - """Flush the event file to disk and close the file.""" - self.flush() - self.event_writer.Shut() + def __del__(self) -> None: + self.close() diff --git a/mindspore/train/summary/_summary_adapter.py b/mindspore/train/summary/_summary_adapter.py index 9669d0f0541bddb03d23f39780dd842ee2724008..1cfde39b837614898c182778531fe1cbc6622cde 100644 --- a/mindspore/train/summary/_summary_adapter.py +++ b/mindspore/train/summary/_summary_adapter.py @@ -13,17 +13,17 @@ # limitations under the License. # ============================================================================ """Generate the summary event which conform to proto format.""" -import time import socket -import math -from enum import Enum, unique +import time + import numpy as np from PIL import Image from mindspore import log as logger -from ..summary_pb2 import Event -from ..anf_ir_pb2 import ModelProto, DataType + from ..._checkparam import _check_str_by_regular +from ..anf_ir_pb2 import DataType, ModelProto +from ..summary_pb2 import Event # define the MindSpore image format MS_IMAGE_TENSOR_FORMAT = 'NCHW' @@ -32,55 +32,6 @@ EVENT_FILE_NAME_MARK = ".out.events.summary." # Set the init event of version and mark EVENT_FILE_INIT_VERSION_MARK = "Mindspore.Event:" EVENT_FILE_INIT_VERSION = 1 -# cache the summary data dict -# {id: SummaryData} -# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...] -g_summary_data_dict = {} - -def save_summary_data(data_id, data): - """Save the global summary cache.""" - global g_summary_data_dict - g_summary_data_dict[data_id] = data - - -def del_summary_data(data_id): - """Save the global summary cache.""" - global g_summary_data_dict - if data_id in g_summary_data_dict: - del g_summary_data_dict[data_id] - else: - logger.warning("Can't del the data because data_id(%r) " - "does not have data in g_summary_data_dict", data_id) - -def get_summary_data(data_id): - """Save the global summary cache.""" - ret = None - global g_summary_data_dict - if data_id in g_summary_data_dict: - ret = g_summary_data_dict.get(data_id) - else: - logger.warning("The data_id(%r) does not have data in g_summary_data_dict", data_id) - return ret - -@unique -class SummaryType(Enum): - """ - Summary type. - - Args: - SCALAR (Number): Summary Scalar enum. - TENSOR (Number): Summary TENSOR enum. - IMAGE (Number): Summary image enum. - GRAPH (Number): Summary graph enum. - HISTOGRAM (Number): Summary histogram enum. - INVALID (Number): Unknow type. - """ - SCALAR = 1 # Scalar summary - TENSOR = 2 # Tensor summary - IMAGE = 3 # Image summary - GRAPH = 4 # graph - HISTOGRAM = 5 # Histogram Summary - INVALID = 0xFF # unknow type def get_event_file_name(prefix, suffix): @@ -138,7 +89,7 @@ def package_graph_event(data): return graph_event -def package_summary_event(data_id, step): +def package_summary_event(data_list, step): """ Package the summary to event protobuffer. @@ -149,50 +100,37 @@ def package_summary_event(data_id, step): Returns: Summary, the summary event. """ - data_list = get_summary_data(data_id) - if data_list is None: - logger.error("The step(%r) does not have record data.", step) - del_summary_data(data_id) # create the event of summary summary_event = Event() summary = summary_event.summary + summary_event.wall_time = time.time() + summary_event.step = int(step) for value in data_list: - tag = value["name"] + summary_type = value["_type"] data = value["data"] - summary_type = value["type"] + tag = value["name"] + logger.debug("Now process %r summary, tag = %r", summary_type, tag) + + summary_value = summary.value.add() + summary_value.tag = tag # get the summary type and parse the tag - if summary_type is SummaryType.SCALAR: - logger.debug("Now process Scalar summary, tag = %r", tag) - summary_value = summary.value.add() - summary_value.tag = tag + if summary_type == 'Scalar': summary_value.scalar_value = _get_scalar_summary(tag, data) - elif summary_type is SummaryType.TENSOR: - logger.debug("Now process Tensor summary, tag = %r", tag) - summary_value = summary.value.add() - summary_value.tag = tag + elif summary_type == 'Tensor': summary_tensor = summary_value.tensor _get_tensor_summary(tag, data, summary_tensor) - elif summary_type is SummaryType.IMAGE: - logger.debug("Now process Image summary, tag = %r", tag) - summary_value = summary.value.add() - summary_value.tag = tag + elif summary_type == 'Image': summary_image = summary_value.image _get_image_summary(tag, data, summary_image, MS_IMAGE_TENSOR_FORMAT) - elif summary_type is SummaryType.HISTOGRAM: - logger.debug("Now process Histogram summary, tag = %r", tag) - summary_value = summary.value.add() - summary_value.tag = tag + elif summary_type == 'Histogram': summary_histogram = summary_value.histogram _fill_histogram_summary(tag, data, summary_histogram) else: # The data is invalid ,jump the data - logger.error("Summary type is error, tag = %r", tag) - continue + logger.error("Summary type(%r) is error, tag = %r", summary_type, tag) - summary_event.wall_time = time.time() - summary_event.step = int(step) return summary_event @@ -255,11 +193,11 @@ def _get_scalar_summary(tag: str, np_value): # So consider the dim = 1, shape = (1,) tensor is scalar scalar_value = np_value[0] if np_value.shape != (1,): - logger.error("The tensor is not Scalar, tag = %r, Value = %r", tag, np_value) + logger.error("The tensor is not Scalar, tag = %r, Shape = %r", tag, np_value.shape) else: np_list = np_value.reshape(-1).tolist() scalar_value = np_list[0] - logger.error("The value is not Scalar, tag = %r, Value = %r", tag, np_value) + logger.error("The value is not Scalar, tag = %r, ndim = %r", tag, np_value.ndim) logger.debug("The tag(%r) value is: %r", tag, scalar_value) return scalar_value @@ -307,8 +245,7 @@ def _calc_histogram_bins(count): Returns: int, number of histogram bins. """ - number_per_bucket = 10 - max_bins = 90 + max_bins, max_per_bin = 90, 10 if not count: return 1 @@ -318,78 +255,50 @@ def _calc_histogram_bins(count): return 3 if count <= 880: # note that math.ceil(881/10) + 1 equals 90 - return int(math.ceil(count / number_per_bucket) + 1) + return count // max_per_bin + 1 return max_bins -def _fill_histogram_summary(tag: str, np_value: np.array, summary_histogram) -> None: +def _fill_histogram_summary(tag: str, np_value: np.ndarray, summary) -> None: """ Package the histogram summary. Args: tag (str): Summary tag describe. - np_value (np.array): Summary data. - summary_histogram (summary_pb2.Summary.Histogram): Summary histogram data. + np_value (np.ndarray): Summary data. + summary (summary_pb2.Summary.Histogram): Summary histogram data. """ logger.debug("Set(%r) the histogram summary value", tag) # Default bucket for tensor with no valid data. - default_bucket_left = -0.5 - default_bucket_width = 1.0 - - if np_value.size == 0: - bucket = summary_histogram.buckets.add() - bucket.left = default_bucket_left - bucket.width = default_bucket_width - bucket.count = 0 - - summary_histogram.nan_count = 0 - summary_histogram.pos_inf_count = 0 - summary_histogram.neg_inf_count = 0 - - summary_histogram.max = 0 - summary_histogram.min = 0 - summary_histogram.sum = 0 - - summary_histogram.count = 0 - - return - - summary_histogram.nan_count = np.count_nonzero(np.isnan(np_value)) - summary_histogram.pos_inf_count = np.count_nonzero(np.isposinf(np_value)) - summary_histogram.neg_inf_count = np.count_nonzero(np.isneginf(np_value)) - summary_histogram.count = np_value.size - - masked_value = np.ma.masked_invalid(np_value) - tensor_max = masked_value.max() - tensor_min = masked_value.min() - tensor_sum = masked_value.sum() - - # No valid value in tensor. - if tensor_max is np.ma.masked: - bucket = summary_histogram.buckets.add() - bucket.left = default_bucket_left - bucket.width = default_bucket_width - bucket.count = 0 - - summary_histogram.max = np.nan - summary_histogram.min = np.nan - summary_histogram.sum = 0 - - return - - bin_number = _calc_histogram_bins(masked_value.count()) - counts, edges = np.histogram(np_value, bins=bin_number, range=(tensor_min, tensor_max)) + ma_value = np.ma.masked_invalid(np_value) + total, valid = np_value.size, ma_value.count() + invalids = [] + for isfn in np.isnan, np.isposinf, np.isneginf: + if total - valid > sum(invalids): + count = np.count_nonzero(isfn(np_value)) + invalids.append(count) + else: + invalids.append(0) - for ind, count in enumerate(counts): - bucket = summary_histogram.buckets.add() - bucket.left = edges[ind] - bucket.width = edges[ind + 1] - edges[ind] - bucket.count = count + summary.count = total + summary.nan_count, summary.pos_inf_count, summary.neg_inf_count = invalids + if not valid: + logger.warning('There are no valid values in the ndarray(size=%d, shape=%d)', total, np_value.shape) + # summary.{min, max, sum} are 0s by default, no need to explicitly set + else: + summary.min = ma_value.min() + summary.max = ma_value.max() + summary.sum = ma_value.sum() + bins = _calc_histogram_bins(valid) + range_ = summary.min, summary.max + hists, edges = np.histogram(np_value, bins=bins, range=range_) - summary_histogram.max = tensor_max - summary_histogram.min = tensor_min - summary_histogram.sum = tensor_sum + for hist, edge1, edge2 in zip(hists, edges, edges[1:]): + bucket = summary.buckets.add() + bucket.width = edge2 - edge1 + bucket.count = hist + bucket.left = edge1 def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'): @@ -407,7 +316,7 @@ def _get_image_summary(tag: str, np_value, summary_image, input_format='NCHW'): """ logger.debug("Set(%r) the image summary value", tag) if np_value.ndim != 4: - logger.error("The value is not Image, tag = %r, Value = %r", tag, np_value) + logger.error("The value is not Image, tag = %r, ndim = %r", tag, np_value.ndim) # convert the tensor format tensor = _convert_image_format(np_value, input_format) @@ -469,8 +378,8 @@ def _convert_image_format(np_tensor, input_format, out_format='HWC'): """ out_tensor = None if np_tensor.ndim != len(input_format): - logger.error("The tensor(%r) can't convert the format(%r) because dim not same", - np_tensor, input_format) + logger.error("The tensor with dim(%r) can't convert the format(%r) because dim not same", np_tensor.ndim, + input_format) return out_tensor input_format = input_format.upper() @@ -512,7 +421,7 @@ def _make_canvas_for_imgs(tensor, col_imgs=8): # check the tensor format if tensor.ndim != 4 or tensor.shape[1] != 3: - logger.error("The image tensor(%r) is not 'NCHW' format", tensor) + logger.error("The image tensor with ndim(%r) and shape(%r) is not 'NCHW' format", tensor.ndim, tensor.shape) return out_canvas # expand the N diff --git a/mindspore/train/summary/_summary_scheduler.py b/mindspore/train/summary/_summary_scheduler.py deleted file mode 100644 index 3327b02fa7c864a8c931950526fbd95bba811823..0000000000000000000000000000000000000000 --- a/mindspore/train/summary/_summary_scheduler.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Schedule the event writer process.""" -import multiprocessing as mp -from enum import Enum, unique -from mindspore import log as logger -from ..._c_expression import Tensor -from ._summary_adapter import SummaryType, package_summary_event, save_summary_data - -# define the type of summary -FORMAT_SCALAR_STR = "Scalar" -FORMAT_TENSOR_STR = "Tensor" -FORMAT_IMAGE_STR = "Image" -FORMAT_HISTOGRAM_STR = "Histogram" -FORMAT_BEGIN_SLICE = "[:" -FORMAT_END_SLICE = "]" - -# cache the summary data dict -# {id: SummaryData} -# |---[{"name": tag_name, "data": numpy}, {"name": tag_name, "data": numpy},...] -g_summary_data_id = 0 -g_summary_data_dict = {} -# cache the summary data file -g_summary_writer_id = 0 -g_summary_file = {} - - -@unique -class ScheduleMethod(Enum): - """Schedule method type.""" - FORMAL_WORKER = 0 # use the formal worker that receive small size data by queue - TEMP_WORKER = 1 # use the Temp worker that receive big size data by the global value(avoid copy) - CACHE_DATA = 2 # Cache data util have idle worker to process it - - -@unique -class WorkerStatus(Enum): - """Worker status.""" - WORKER_INIT = 0 # data is exist but not process - WORKER_PROCESSING = 1 # data is processing - WORKER_PROCESSED = 2 # data already processed - - -def _parse_tag_format(tag: str): - """ - Parse the tag. - - Args: - tag (str): Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor]. - - Returns: - Tuple, (SummaryType, summary_tag). - """ - - summary_type = SummaryType.INVALID - summary_tag = tag - if tag is None: - logger.error("The tag is None") - return summary_type, summary_tag - - # search the slice - slice_begin = FORMAT_BEGIN_SLICE - slice_end = FORMAT_END_SLICE - index = tag.rfind(slice_begin) - if index is -1: - logger.error("The tag(%s) have not the key slice.", tag) - return summary_type, summary_tag - - # slice the tag - summary_tag = tag[:index] - - # check the slice end - if tag[-1:] != slice_end: - logger.error("The tag(%s) end format is error", tag) - return summary_type, summary_tag - - # check the type - type_str = tag[index + 2: -1] - logger.debug("The summary_tag is = %r", summary_tag) - logger.debug("The type_str value is = %r", type_str) - if type_str == FORMAT_SCALAR_STR: - summary_type = SummaryType.SCALAR - elif type_str == FORMAT_TENSOR_STR: - summary_type = SummaryType.TENSOR - elif type_str == FORMAT_IMAGE_STR: - summary_type = SummaryType.IMAGE - elif type_str == FORMAT_HISTOGRAM_STR: - summary_type = SummaryType.HISTOGRAM - else: - logger.error("The tag(%s) type is invalid.", tag) - summary_type = SummaryType.INVALID - - return summary_type, summary_tag - - -class SummaryDataManager: - """Manage the summary global data cache.""" - def __init__(self): - global g_summary_data_dict - self.size = len(g_summary_data_dict) - - @classmethod - def summary_data_save(cls, data): - """Save the global summary cache.""" - global g_summary_data_id - data_id = g_summary_data_id - save_summary_data(data_id, data) - g_summary_data_id += 1 - return data_id - - @classmethod - def summary_file_set(cls, event_writer): - """Support the many event_writer.""" - global g_summary_file, g_summary_writer_id - g_summary_writer_id += 1 - g_summary_file[g_summary_writer_id] = event_writer - return g_summary_writer_id - - @classmethod - def summary_file_get(cls, writer_id=1): - ret = None - global g_summary_file - if writer_id in g_summary_file: - ret = g_summary_file.get(writer_id) - return ret - - -class WorkerScheduler: - """ - Create worker and schedule data to worker. - - Args: - writer_id (int): The index of writer. - """ - def __init__(self, writer_id): - # Create the process of write event file - self.write_lock = mp.Lock() - # Schedule info for all worker - # Format: {worker: (step, WorkerStatus)} - self.schedule_table = {} - # write id - self.writer_id = writer_id - self.has_graph = False - - def dispatch(self, step, data): - """ - Select schedule strategy and dispatch data. - - Args: - step (Number): The number of step index. - data (Object): The data of recode for summary. - - Retruns: - bool, run successfully or not. - """ - # save the data to global cache , convert the tensor to numpy - result, size, data = self._data_convert(data) - if result is False: - logger.error("The step(%r) summary data(%r) is invalid.", step, size) - return False - - data_id = SummaryDataManager.summary_data_save(data) - self._start_worker(step, data_id) - return True - - def _start_worker(self, step, data_id): - """ - Start worker. - - Args: - step (Number): The index of recode. - data_id (str): The id of work. - - Return: - bool, run successfully or not. - """ - # assign the worker - policy = self._make_policy() - if policy == ScheduleMethod.TEMP_WORKER: - worker = SummaryDataProcess(step, data_id, self.write_lock, self.writer_id) - # update the schedule table - self.schedule_table[worker] = (step, data_id, WorkerStatus.WORKER_INIT) - # start the worker - worker.start() - else: - logger.error("Do not support the other scheduler policy now.") - - # update the scheduler infor - self._update_scheduler() - return True - - def _data_convert(self, data_list): - """Convert the data.""" - if data_list is None: - logger.warning("The step does not have record data.") - return False, 0, None - - # convert the summary to numpy - size = 0 - for v_dict in data_list: - tag = v_dict["name"] - data = v_dict["data"] - # confirm the data is valid - summary_type, summary_tag = _parse_tag_format(tag) - if summary_type == SummaryType.INVALID: - logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data) - return False, 0, None - if isinstance(data, Tensor): - # get the summary type and parse the tag - v_dict["name"] = summary_tag - v_dict["type"] = summary_type - v_dict["data"] = data.asnumpy() - size += v_dict["data"].size - else: - logger.error("The data type is invalid, tag = %r, tensor = %r", tag, data) - return False, 0, None - - return True, size, data_list - - def _update_scheduler(self): - """Check the worker status and update schedule table.""" - workers = list(self.schedule_table.keys()) - for worker in workers: - if not worker.is_alive(): - # update the table - worker.join() - del self.schedule_table[worker] - - def close(self): - """Confirm all worker is end.""" - workers = self.schedule_table.keys() - for worker in workers: - if worker.is_alive(): - worker.join() - - def _make_policy(self): - """Select the schedule strategy by data.""" - # now only support the temp worker - return ScheduleMethod.TEMP_WORKER - - -class SummaryDataProcess(mp.Process): - """ - Process that consume the summarydata. - - Args: - step (int): The index of step. - data_id (int): The index of summary data. - write_lock (Lock): The process lock for writer same file. - writer_id (int): The index of writer. - """ - def __init__(self, step, data_id, write_lock, writer_id): - super(SummaryDataProcess, self).__init__() - self.daemon = True - self.writer_id = writer_id - self.writer = SummaryDataManager.summary_file_get(self.writer_id) - if self.writer is None: - logger.error("The writer_id(%r) does not have writer", writer_id) - self.step = step - self.data_id = data_id - self.write_lock = write_lock - self.name = "SummaryDataConsumer_" + str(self.step) - - def run(self): - """The consumer is process the step data and exit.""" - # convert the data to event - # All exceptions need to be caught and end the queue - try: - logger.debug("process(%r) process a data(%r)", self.name, self.step) - # package the summary event - summary_event = package_summary_event(self.data_id, self.step) - # send the event to file - self._write_summary(summary_event) - except Exception as e: - logger.error("Summary data mq consumer exception occurred, value = %r", e) - - def _write_summary(self, summary_event): - """ - Write the summary to event file. - - Note: - The write record format: - 1 uint64 : data length. - 2 uint32 : mask crc value of data length. - 3 bytes : data. - 4 uint32 : mask crc value of data. - - Args: - summary_event (Event): The summary event of proto. - - """ - event_str = summary_event.SerializeToString() - self.write_lock.acquire() - self.writer.write_event_to_file(event_str) - self.writer.flush() - self.write_lock.release() diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index 4c60dce862975bd8ad7d23a48c658b0cc20e02a4..43baebccf97e7d8cf257d32805f437a725fbcdfc 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -14,17 +14,22 @@ # ============================================================================ """Record the summary event.""" import os +import re import threading + from mindspore import log as logger -from ._summary_scheduler import WorkerScheduler, SummaryDataManager -from ._summary_adapter import get_event_file_name, package_graph_event -from ._event_writer import EventRecord -from .._utils import _make_directory + +from ..._c_expression import Tensor from ..._checkparam import _check_str_by_regular +from .._utils import _make_directory +from ._event_writer import EventWriter +from ._summary_adapter import get_event_file_name, package_graph_event, package_init_event +# for the moment, this lock is for caution's sake, +# there are actually no any concurrencies happening. +_summary_lock = threading.Lock() # cache the summary data _summary_tensor_cache = {} -_summary_lock = threading.Lock() def _cache_summary_tensor_data(summary): @@ -34,14 +39,18 @@ def _cache_summary_tensor_data(summary): Args: summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...]. """ - _summary_lock.acquire() - if "SummaryRecord" in _summary_tensor_cache: - for record in summary: - _summary_tensor_cache["SummaryRecord"].append(record) - else: - _summary_tensor_cache["SummaryRecord"] = summary - _summary_lock.release() - return True + with _summary_lock: + for item in summary: + _summary_tensor_cache[item['name']] = item['data'] + return True + + +def _get_summary_tensor_data(): + global _summary_tensor_cache + with _summary_lock: + data = _summary_tensor_cache + _summary_tensor_cache = {} + return data class SummaryRecord: @@ -71,6 +80,7 @@ class SummaryRecord: >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6, >>> file_prefix="xxx_", file_suffix="_yyy") """ + def __init__(self, log_dir, queue_max_size=0, @@ -101,26 +111,18 @@ class SummaryRecord: self.prefix = file_prefix self.suffix = file_suffix + self.network = network + self.has_graph = False + self._closed = False # create the summary writer file self.event_file_name = get_event_file_name(self.prefix, self.suffix) - if self.log_path[-1:] == '/': - self.full_file_name = self.log_path + self.event_file_name - else: - self.full_file_name = self.log_path + '/' + self.event_file_name - try: - self.full_file_name = os.path.realpath(self.full_file_name) + self.full_file_name = os.path.join(self.log_path, self.event_file_name) except Exception as ex: raise RuntimeError(ex) - self.event_writer = EventRecord(self.full_file_name, self.flush_time) - self.writer_id = SummaryDataManager.summary_file_set(self.event_writer) - self.worker_scheduler = WorkerScheduler(self.writer_id) - - self.step = 0 - self._closed = False - self.network = network - self.has_graph = False + self.event_writer = EventWriter(self.full_file_name, self.flush_time) + self.event_writer.write(package_init_event().SerializeToString()) def record(self, step, train_network=None): """ @@ -145,42 +147,34 @@ class SummaryRecord: if not isinstance(step, int) or isinstance(step, bool): raise ValueError("`step` should be int") # Set the current summary of train step - self.step = step - if self.network is not None and self.has_graph is False: + if self.network is not None and not self.has_graph: graph_proto = self.network.get_func_graph_proto() if graph_proto is None and train_network is not None: graph_proto = train_network.get_func_graph_proto() if graph_proto is None: logger.error("Failed to get proto for graph") else: - self.event_writer.write_event_to_file( - package_graph_event(graph_proto).SerializeToString()) - self.event_writer.flush() + self.event_writer.write(package_graph_event(graph_proto).SerializeToString()) self.has_graph = True - data = _summary_tensor_cache.get("SummaryRecord") - if data is None: + if not _summary_tensor_cache: return True - data = _summary_tensor_cache.get("SummaryRecord") - if data is None: - logger.error("The step(%r) does not have record data.", self.step) + data = _get_summary_tensor_data() + if not data: + logger.error("The step(%r) does not have record data.", step) return False if self.queue_max_size > 0 and len(data) > self.queue_max_size: logger.error("The size of data record is %r, which is greater than queue_max_size %r.", len(data), self.queue_max_size) - # clean the data of cache - del _summary_tensor_cache["SummaryRecord"] - # process the data - self.worker_scheduler.dispatch(self.step, data) - - # count & flush - self.event_writer.count_event() - self.event_writer.flush_cycle() - - logger.debug("Send the summary data to scheduler for saving, step = %d", self.step) + result = self._data_convert(data) + if not result: + logger.error("The step(%r) summary data is invalid.", step) + return False + self.event_writer.write((result, step)) + logger.debug("Send the summary data to scheduler for saving, step = %d", step) return True @property @@ -196,7 +190,7 @@ class SummaryRecord: Returns: String, the full path of log file. """ - return self.event_writer.full_file_name + return self.full_file_name def flush(self): """ @@ -224,20 +218,44 @@ class SummaryRecord: >>> summary_record.close() """ if not self._closed: - self._check_data_before_close() - self.worker_scheduler.close() # event writer flush and close self.event_writer.close() self._closed = True - def __del__(self): - """Process exit is called.""" - if hasattr(self, "worker_scheduler"): - if self.worker_scheduler: - self.close() - - def _check_data_before_close(self): - "Check whether there is any data in the cache, and if so, call record" - data = _summary_tensor_cache.get("SummaryRecord") - if data is not None: - self.record(self.step) + def _data_convert(self, summary): + """Convert the data.""" + # convert the summary to numpy + result = [] + for name, data in summary.items(): + # confirm the data is valid + summary_tag, summary_type = SummaryRecord._parse_from(name) + if summary_tag is None: + logger.error("The data type is invalid, name = %r, tensor = %r", name, data) + return None + if isinstance(data, Tensor): + result.append({'name': summary_tag, 'data': data.asnumpy(), '_type': summary_type}) + else: + logger.error("The data type is invalid, name = %r, tensor = %r", name, data) + return None + + return result + + @staticmethod + def _parse_from(name: str = None): + """ + Parse the tag and type from name. + + Args: + name (str): Format: TAG[:TYPE]. + + Returns: + Tuple, (summary_tag, summary_type). + """ + if name is None: + logger.error("The name is None") + return None, None + match = re.match(r'(.+)\[:(.+)\]', name) + if match: + return match.groups() + logger.error("The name(%r) format is invalid, expected 'TAG[:TYPE]'.", name) + return None, None