Unverified commit 92d973c1, authored by Xu Jingxin, committed by GitHub

feature(xjx): multiprocess tblogger, fix circular reference problem (#156)

* Fix circular references in task and parallel, add distributed logger

* Update logger

* Clear ref list when exit task/parallel

* Put task in with statement

* Fix test

* Fix test

* Test is hard

* More comments
Parent 1040c9fc
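Taken together, the changes are used like this; a minimal sketch based on the tests added in this commit (the log path and step count are illustrative only):

```python
from ding.framework import Parallel
from ding.framework.task import Task
from ding.utils import DistributedWriter


def main():
    # Task as a context manager: middleware and listeners are cleared on exit,
    # which breaks the task's circular references (see Task.stop below).
    with Task() as task:
        # Only the "node.0" process owns the event file; all other processes
        # forward their add_* calls to it over RPC.
        tblogger = DistributedWriter("./log").plugin(task, is_writer=("node.0" in task.labels))
        task.use(lambda ctx: tblogger.add_scalar("step", ctx.total_step, ctx.total_step))
        task.run(max_step=10)


if __name__ == "__main__":
    Parallel.runner(n_parallel_workers=2)(main)
```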
......@@ -1419,3 +1419,4 @@ formatted_*
 eval_config.py
 collect_demo_data_config.py
 default*
+!ding/**/*.py
......@@ -22,7 +22,8 @@ class Context(dict):
        """
        ctx = Context()
        for key in self._kept_keys:
-            ctx[key] = self[key]
+            if key in self:
+                ctx[key] = self[key]
        return ctx

    def keep(self, *keys: str) -> None:
......
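The new `if key in self` guard means a kept key that was never assigned is simply skipped instead of raising a KeyError when the context is renewed. A minimal sketch of the intended semantics, assuming the hunk above is `Context.renew` and that `keep` only records key names:

```python
ctx = Context()
ctx.keep("env_step")      # mark the key for inheritance without ever assigning it
new_ctx = ctx.renew()     # previously raised KeyError; now the unset key is skipped
assert "env_step" not in new_ctx
```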
import atexit
import os
import random
import threading
import time
from mpire.pool import WorkerPool
import pynng
......@@ -27,10 +26,10 @@ class Parallel(metaclass=SingletonMetaclass):
        self._sock: Socket = None
        self._rpc = {}
        self._bind_addr = None
        self._lock = threading.Lock()
        self.is_active = False
        self.attach_to = None
        self.finished = False
        self.node_id = None

    def run(self, node_id: int, listen_to: str, attach_to: List[str] = None) -> None:
        self.node_id = node_id
......@@ -187,6 +186,10 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
    def register_rpc(self, fn_name: str, fn: Callable) -> None:
        self._rpc[fn_name] = fn

+    def unregister_rpc(self, fn_name: str) -> None:
+        if fn_name in self._rpc:
+            del self._rpc[fn_name]
+
    def send_rpc(self, func_name: str, *args, **kwargs) -> None:
        if self.is_active:
            payload = {"f": func_name, "a": args, "k": kwargs}
......@@ -198,7 +201,8 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
            except Exception as e:
                logging.warning("Error when unpacking message on node {}, msg: {}".format(self._bind_addr, e))
            if payload["f"] in self._rpc:
-                self._rpc[payload["f"]](*payload["a"], **payload["k"])
+                fn = self._rpc[payload["f"]]
+                fn(*payload["a"], **payload["k"])
            else:
                logging.warning("There was no function named {} in rpc table".format(payload["f"]))
......@@ -224,6 +228,7 @@ now there are {} ports and {} workers".format(len(ports), n_workers)
    def stop(self):
        logging.info("Stopping parallel worker on address: {}".format(self._bind_addr))
        self.finished = True
+        self._rpc.clear()
        time.sleep(0.03)
        if self._sock:
            self._sock.close()
......
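`unregister_rpc` and the new `self._rpc.clear()` in `stop` matter for the circular-reference fix: entries in the RPC table are often bound methods (e.g. `DistributedWriter._on_distributed_writer` below), and a bound method keeps its instance alive. A toy demonstration of that mechanism, not DI-engine code:

```python
import weakref


class Writer:

    def on_write(self, *args):
        pass


rpc_table = {}                         # stand-in for Parallel._rpc
writer = Writer()
wr = weakref.ref(writer)
rpc_table["distributed_writer"] = writer.on_write

del writer
assert wr() is not None               # the bound method keeps the writer alive

rpc_table.clear()                     # what Parallel.stop() now does
assert wr() is None                   # the writer can finally be collected
```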
......@@ -225,8 +225,14 @@ class Task:
            self.stop()

    def stop(self) -> None:
+        self.emit("exit")
        if self._thread_pool:
            self._thread_pool.shutdown()
+        # The middleware and listeners may hold methods that reference the task;
+        # if we do not clear them when the task exits, gc cannot clean up the task object.
+        self.middleware.clear()
+        self.event_listeners.clear()
+        self.once_listeners.clear()

    def sync(self) -> 'Task':
        if self._loop:
......
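A toy reproduction of the cycle that the new comment in `stop` describes: a middleware closure captures the task, so the task lingers until a cycle-detecting gc pass unless the lists are cleared. All names here are hypothetical:

```python
import weakref


class MiniTask:
    """Toy stand-in for Task: middleware closures capture the task itself."""

    def __init__(self):
        self.middleware = []

    def use(self, fn):
        self.middleware.append(fn)

    def stop(self):
        self.middleware.clear()   # break the task <-> closure cycle


def make_task():
    task = MiniTask()
    task.use(lambda ctx: task.middleware)   # closure cell references the task
    return task


task = make_task()
wr = weakref.ref(task)
task.stop()                       # with clear(): plain refcounting frees it promptly
del task
assert wr() is None               # freed without waiting for a gc cycle pass
```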
......@@ -16,14 +16,14 @@ def parallel_main():
    router.register_rpc("test_callback", test_callback)
    # Wait for nodes to bind
    time.sleep(0.7)
-    router.send_rpc("test_callback", "ping")
    for _ in range(30):
+        router.send_rpc("test_callback", "ping")
        if msg["ping"]:
            break
        time.sleep(0.03)
    assert msg["ping"]
    # Avoid being unable to receive each other's messages after exiting parallel
    time.sleep(0.7)
@pytest.mark.unittest
......
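The test now resends the RPC inside the polling loop because a single send can race against peers that are still binding their sockets. The same pattern as a hypothetical standalone helper:

```python
import time


def send_until(send, done, attempts=30, interval=0.03):
    """Resend until done() holds or the attempts run out; returns success."""
    for _ in range(attempts):
        send()
        if done():
            return True
        time.sleep(interval)
    return False


# e.g. send_until(lambda: router.send_rpc("test_callback", "ping"),
#                 lambda: msg["ping"])
```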
......@@ -23,21 +23,21 @@ def test_serial_pipeline():
        ctx.pipeline.append(1)

    # Execute step1, step2 twice
-    task = Task()
-    for _ in range(2):
-        task.forward(step0)
-        task.forward(step1)
-    assert task.ctx.pipeline == [0, 1, 0, 1]
-    # Renew and execute step1, step2
-    task.renew()
-    assert task.ctx.total_step == 1
-    task.forward(step0)
-    task.forward(step1)
-    assert task.ctx.pipeline == [0, 1]
-    # Test context inheritance
-    task.renew()
+    with Task() as task:
+        for _ in range(2):
+            task.forward(step0)
+            task.forward(step1)
+        assert task.ctx.pipeline == [0, 1, 0, 1]
+        # Renew and execute step1, step2
+        task.renew()
+        assert task.ctx.total_step == 1
+        task.forward(step0)
+        task.forward(step1)
+        assert task.ctx.pipeline == [0, 1]
+        # Test context inheritance
+        task.renew()
@pytest.mark.unittest
......@@ -52,12 +52,12 @@ def test_serial_yield_pipeline():
    def step1(ctx):
        ctx.pipeline.append(1)

-    task = Task()
-    task.forward(step0)
-    task.forward(step1)
-    task.backward()
-    assert task.ctx.pipeline == [0, 1, 0]
-    assert len(task._backward_stack) == 0
+    with Task() as task:
+        task.forward(step0)
+        task.forward(step1)
+        task.backward()
+        assert task.ctx.pipeline == [0, 1, 0]
+        assert len(task._backward_stack) == 0
@pytest.mark.unittest
......@@ -71,16 +71,16 @@ def test_async_pipeline():
        ctx.pipeline.append(1)

    # Execute step1, step2 twice
-    task = Task(async_mode=True)
-    for _ in range(2):
-        task.forward(step0)
-        time.sleep(0.1)
-        task.forward(step1)
-        time.sleep(0.1)
-    task.backward()
-    assert task.ctx.pipeline == [0, 1, 0, 1]
-    task.renew()
-    assert task.ctx.total_step == 1
+    with Task(async_mode=True) as task:
+        for _ in range(2):
+            task.forward(step0)
+            time.sleep(0.1)
+            task.forward(step1)
+            time.sleep(0.1)
+        task.backward()
+        assert task.ctx.pipeline == [0, 1, 0, 1]
+        task.renew()
+        assert task.ctx.total_step == 1
@pytest.mark.unittest
......@@ -97,17 +97,16 @@ def test_async_yield_pipeline():
        time.sleep(0.2)
        ctx.pipeline.append(1)

-    task = Task(async_mode=True)
-    task.forward(step0)
-    task.forward(step1)
-    time.sleep(0.3)
-    task.backward().sync()
-    assert task.ctx.pipeline == [0, 1, 0]
-    assert len(task._backward_stack) == 0
+    with Task(async_mode=True) as task:
+        task.forward(step0)
+        task.forward(step1)
+        time.sleep(0.3)
+        task.backward().sync()
+        assert task.ctx.pipeline == [0, 1, 0]
+        assert len(task._backward_stack) == 0


def parallel_main():
-    task = Task()
    sync_count = 0

    def on_sync_parallel_ctx(ctx):
......@@ -115,14 +114,14 @@ def parallel_main():
        assert isinstance(ctx, Context)
        sync_count += 1

-    task.on("sync_parallel_ctx", on_sync_parallel_ctx)
-    task.use(lambda _: time.sleep(0.2 + random.random() / 10))
-    task.run(max_step=10)
-    assert sync_count > 0
+    with Task() as task:
+        task.on("sync_parallel_ctx", on_sync_parallel_ctx)
+        task.use(lambda _: time.sleep(0.2 + random.random() / 10))
+        task.run(max_step=10)
+        assert sync_count > 0


def parallel_main_eager():
-    task = Task()
    sync_count = 0

    def on_sync_parallel_ctx(ctx):
......@@ -130,11 +129,12 @@ def parallel_main_eager():
        assert isinstance(ctx, Context)
        sync_count += 1

-    task.on("sync_parallel_ctx", on_sync_parallel_ctx)
-    for _ in range(10):
-        task.forward(lambda _: time.sleep(0.2 + random.random() / 10))
-        task.renew()
-    assert sync_count > 0
+    with Task() as task:
+        task.on("sync_parallel_ctx", on_sync_parallel_ctx)
+        for _ in range(10):
+            task.forward(lambda _: time.sleep(0.2 + random.random() / 10))
+            task.renew()
+        assert sync_count > 0
@pytest.mark.unittest
......@@ -143,14 +143,6 @@ def test_parallel_pipeline():
Parallel.runner(n_parallel_workers=2)(parallel_main)
-@pytest.mark.unittest
-def test_copy_task():
-    t1 = Task(async_mode=True, n_async_workers=1)
-    t2 = copy.copy(t1)
-    assert t2.async_mode
-    assert t1 is not t2
def attach_mode_main_task():
with Task() as task:
task.use(lambda _: time.sleep(0.1))
......@@ -199,14 +191,14 @@ def test_attach_mode():
@pytest.mark.unittest
def test_label():
-    task = Task()
-    result = {}
-    task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"])
-    task.use(lambda _: result.setdefault("has_me", True))
-    task.run(max_step=1)
-    assert "not_me" not in result
-    assert "has_me" in result
+    with Task() as task:
+        result = {}
+        task.use(lambda _: result.setdefault("not_me", True), filter_labels=["async"])
+        task.use(lambda _: result.setdefault("has_me", True))
+        task.run(max_step=1)
+        assert "not_me" not in result
+        assert "has_me" in result
def sync_parallel_ctx_main():
......
......@@ -24,12 +24,12 @@ def test_step_timer():
    # Lazy mode (with use statement)
    step_timer = StepTimer()
-    task = Task()
-    task.use_step_wrapper(step_timer)
-    task.use(step1)
-    task.use(step2)
-    task.use(task.sequence(step3, step4))
-    task.run(3)
+    with Task() as task:
+        task.use_step_wrapper(step_timer)
+        task.use(step1)
+        task.use(step2)
+        task.use(task.sequence(step3, step4))
+        task.run(3)

    assert len(step_timer.records) == 5
    for records in step_timer.records.values():
......@@ -37,12 +37,12 @@ def test_step_timer():
    # Eager mode (with forward statement)
    step_timer = StepTimer()
-    task = Task()
-    task.use_step_wrapper(step_timer)
-    for _ in range(3):
-        task.forward(step1)  # Step 1
-        task.forward(step2)  # Step 2
-        task.renew()
+    with Task() as task:
+        task.use_step_wrapper(step_timer)
+        for _ in range(3):
+            task.forward(step1)  # Step 1
+            task.forward(step2)  # Step 2
+            task.renew()

    assert len(step_timer.records) == 2
    for records in step_timer.records.values():
......@@ -51,15 +51,19 @@ def test_step_timer():
    # Wrapper in wrapper
    step_timer1 = StepTimer()
    step_timer2 = StepTimer()
-    task = Task()
-    task.use_step_wrapper(step_timer1)
-    task.use_step_wrapper(step_timer2)
-    task.use(step1)
-    task.use(step2)
-    task.run(3)
-    assert len(step_timer1.records) == 2
-    assert len(step_timer2.records) == 2
+    with Task() as task:
+        task.use_step_wrapper(step_timer1)
+        task.use_step_wrapper(step_timer2)
+        task.use(step1)
+        task.use(step2)
+        task.run(3)
+    try:
+        assert len(step_timer1.records) == 2
+        assert len(step_timer2.records) == 2
+    except AssertionError:
+        print("ExceptionStepTimer", step_timer2.records)
+        raise
    for records in step_timer1.records.values():
        assert len(records) == 3
    for records in step_timer2.records.values():
......
......@@ -25,6 +25,7 @@ from .time_helper import build_time_helper, EasyTimer, WatchDog
from .type_helper import SequenceType
from .scheduler_helper import Scheduler
from .profiler_helper import Profiler, register_profiler
+from .log_writer_helper import DistributedWriter
if ding.enable_linklink:
from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
......
......@@ -4,7 +4,7 @@ import os
import numpy as np
import yaml
from tabulate import tabulate
-from tensorboardX import SummaryWriter
+from .log_writer_helper import DistributedWriter
from typing import Optional, Tuple, Union, Dict, Any
......@@ -32,7 +32,7 @@ def build_logger(
name = 'default'
logger = LoggerFactory.create_logger(path, name=name) if need_text else None
tb_name = name + '_tb_logger'
-    tb_logger = SummaryWriter(os.path.join(path, tb_name)) if need_tb else None
+    tb_logger = DistributedWriter(os.path.join(path, tb_name)) if need_tb else None
return logger, tb_logger
......
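Callers of `build_logger` are unaffected by this swap: `DistributedWriter` subclasses `SummaryWriter` and behaves identically outside parallel mode. Illustrative usage; the path, name, and tag are made up:

```python
logger, tb_logger = build_logger("./log", name="demo", need_tb=True)
tb_logger.add_scalar("reward", 1.0, global_step=0)  # plain SummaryWriter behavior here
```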
from tensorboardX import SummaryWriter
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # TYPE_CHECKING is always False at runtime, but mypy (and IDEs) will evaluate
    # the contents of this block, so importing Task here provides code hints
    # without creating a circular import at runtime. A good explanation:
    # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
    from ding.framework import Task


class DistributedWriter(SummaryWriter):
    """
    Overview:
        A simple subclass of SummaryWriter that writes from only one process in
        multi-process mode. It is best used in conjunction with a ``task`` to take
        advantage of the task's message and event components (see ``writer.plugin``).
    """
    def __init__(self, *args, **kwargs):
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # We need to write data to files lazily, so we should not create the file
        # writer in __init__. Instead, we initialize the file writer the first
        # time the user calls an add_* function.
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._in_parallel = False
        self._task = None
        self._is_writer = False
        self._lazy_initialized = False
    def plugin(self, task: "Task", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Plug in a ``task``. When this writer is used in the task pipeline,
            non-writer processes automatically send their write requests to the
            main writer instead of writing to disk, so data from multiple
            processes is collected into a single file.
        Examples:
            >>> DistributedWriter().plugin(task, is_writer=("node.0" in task.labels))
        """
        if task.router.is_active:
            self._in_parallel = True
            self._task = task
            self._is_writer = is_writer
            if is_writer:
                self.initialize()
            self._lazy_initialized = True
            task.router.register_rpc("distributed_writer", self._on_distributed_writer)
            task.once("exit", lambda: self.close())
        return self
    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        if self._is_writer:
            getattr(self, fn_name)(*args, **kwargs)

    def initialize(self):
        self.close()
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        self.close()
def enable_parallel(fn_name, fn):

    def _parallel_fn(self: DistributedWriter, *args, **kwargs):
        if not self._lazy_initialized:
            self.initialize()
        if self._in_parallel and not self._is_writer:
            self._task.router.send_rpc("distributed_writer", fn_name, *args, **kwargs)
        else:
            fn(self, *args, **kwargs)

    return _parallel_fn
ready_to_parallel_fns = [
'add_audio',
'add_custom_scalars',
'add_custom_scalars_marginchart',
'add_custom_scalars_multilinechart',
'add_embedding',
'add_figure',
'add_graph',
'add_graph_deprecated',
'add_histogram',
'add_histogram_raw',
'add_hparams',
'add_image',
'add_image_with_boxes',
'add_images',
'add_mesh',
'add_onnx_graph',
'add_openvino_graph',
'add_pr_curve',
'add_pr_curve_raw',
'add_scalar',
'add_scalars',
'add_text',
'add_video',
]
for fn_name in ready_to_parallel_fns:
setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name)))
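For clarity, here is the unrolled equivalent of what `enable_parallel` produces for one method, `add_scalar`; every name in `ready_to_parallel_fns` gets the same treatment:

```python
def add_scalar(self, *args, **kwargs):
    if not self._lazy_initialized:
        self.initialize()
    if self._in_parallel and not self._is_writer:
        # non-writer processes forward the call to the writer process via RPC
        self._task.router.send_rpc("distributed_writer", "add_scalar", *args, **kwargs)
    else:
        # the writer process (or single-process mode) writes to disk directly
        SummaryWriter.add_scalar(self, *args, **kwargs)
```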
import random
from collections import deque
import numpy as np
import pytest
from easydict import EasyDict
import logging
from ding.utils.log_helper import build_logger, pretty_print
from ding.utils.file_helper import remove_file
......
import pytest
import time
import tempfile
import shutil
import os
from os import path
from ding.framework import Parallel
from ding.framework.task import Task
from ding.utils import DistributedWriter
def main_distributed_writer(tempdir):
    with Task() as task:
        # Sleep 0s and 1s respectively, so that if both nodes wrote,
        # they would create differently named event files
        time.sleep(task.router.node_id * 1)
        tblogger = DistributedWriter(tempdir).plugin(task, is_writer=("node.0" in task.labels))

        def _add_scalar(ctx):
            n = 10
            for i in range(n):
                tblogger.add_scalar(str(task.router.node_id), task.router.node_id, ctx.total_step * n + i)

        task.use(_add_scalar)
        task.use(lambda _: time.sleep(0.2))
        task.run(max_step=10)

        time.sleep(0.3 + (1 - task.router.node_id) * 2)
@pytest.mark.unittest
def test_distributed_writer():
    tempdir = path.join(tempfile.gettempdir(), "tblogger")
    try:
        Parallel.runner(n_parallel_workers=2)(main_distributed_writer, tempdir)
        assert path.exists(tempdir)
        assert len(os.listdir(tempdir)) == 1
    finally:
        if path.exists(tempdir):
            shutil.rmtree(tempdir)