Commit e9104ef1, authored by Megvii Engine Team

fix(mge/parampack): fix copy stream, import cycle

GitOrigin-RevId: 673e11c5b6f1e87b2b498f111ba20f0406d2e57c
Parent: e283663a
@@ -2,8 +2,8 @@ from collections import defaultdict
 from contextlib import contextmanager

 from ..core.autodiff.grad import Grad
-from ..distributed.util import Future
 from ..tensor import tensor
+from ..utils.future import Future

 backwarding_grad_manager = None
@@ -26,6 +26,7 @@ class GradManager:
             self._param_dict[id(p)] = p
             for cb in callbacks:
                 self._call_back_dict[id(p)].append(cb)
+        return self

     def register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
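The `return self` added at the end of attach() makes registration chainable. A minimal sketch of the call style this enables; the module path and the Parameter usage are assumptions about this revision of the API:

    import megengine as mge
    from megengine.autodiff import GradManager  # module path assumed

    w = mge.Parameter([1.0])        # a parameter to track
    gm = GradManager().attach([w])  # chaining enabled by `return self`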
@@ -45,7 +46,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0) for y in ys]
+            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
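The new default for `dys` seeds each output's gradient with ones broadcast to that output's shape instead of a bare scalar. A numpy-only sketch of the seed this produces, with numpy as a stand-in for `tensor(1.0).broadcast(y.shape)`:

    import numpy as np

    y = np.zeros((2, 3), dtype=np.float32)          # stand-in for an output tensor
    dy = np.broadcast_to(np.float32(1.0), y.shape)  # the seed gradient
    assert dy.shape == y.shape                      # one entry of ones per element of y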
@@ -178,8 +178,6 @@ class Grad:
         assert len(ys) == len(dys)

         ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
-        if len(ids) == 0:
-            return

         ys = [y for i, y in enumerate(ys) if i in ids]
         dys = [dy for i, dy in enumerate(dys) if i in ids]
@@ -18,9 +18,9 @@ from megengine.device import get_default_device, get_device_count
 from ..functional.param_pack import get_offsets, pack_allreduce_split
 from ..functional.utils import copy
+from ..utils.future import Future
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
-from .util import Future

 class FakeTensor(Future):
@@ -77,7 +77,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._gm_set = set()
+        self._marked_gm = set()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
@@ -87,6 +87,7 @@ class AllreduceCallback:
         self._futures_dict = dict()
         self._packing_list = defaultdict(list)
         self._packing_size = defaultdict(int)
+        self._grad_origin_device = dict()

     def _pack(self, dtype):
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
@@ -102,27 +103,28 @@ class AllreduceCallback:
     def __call__(self, param, grad):
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
-        if gm not in self._gm_set:
+        if gm not in self._marked_gm:
             gm.register_after_backward_callback(self._flush)
-            self._gm_set.add(gm)
+            self._marked_gm.add(gm)
         self._params.append(param)
         self._futures_dict[param] = FakeTensor(ack=False)
         self._gradients_dict[param] = grad
-        self._packing_list[param.dtype].append(param)
-        self._packing_size[param.dtype] += (
-            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
-        )
-        if self._packing_size[param.dtype] > self._param_pack_thd:
-            self._pack(param.dtype)
+        self._grad_origin_device[param] = str(grad.device)
+        dtype_str = str(np.dtype(param.dtype))
+        dtype_size = np.dtype(param.dtype).itemsize
+        self._packing_list[dtype_str].append(param)
+        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
+        if self._packing_size[dtype_str] > self._param_pack_thd:
+            self._pack(dtype_str)
         return self._futures_dict[param]

     def _flush(self):
-        for dtype in self._packing_list.keys():
+        for dtype in sorted(self._packing_list.keys()):
             self._pack(dtype)
         for param in self._params:
             grad = self._gradients_dict[param]
-            grad = copy(grad, get_default_device())
+            grad = copy(grad, self._grad_origin_device[param])
             self._futures_dict[param].set(grad)
         self._reset()
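Three behavioural changes sit in this hunk: packing buckets are now keyed by the canonical dtype string rather than by whatever object `param.dtype` happens to be; `_flush` iterates dtypes in sorted order, presumably so every rank issues its all-reduces in the same sequence; and each reduced gradient is copied back to the device it originally came from instead of to `get_default_device()`, which appears to be the "copy stream" fix named in the commit title. A small numpy sketch of the dtype normalization and the size accounting:

    import numpy as np

    # str(np.dtype(...)) collapses the different spellings of one dtype into a
    # single dictionary key, so all float32 params share one packing bucket.
    keys = {str(np.dtype(d)) for d in (np.float32, "float32", np.dtype("float32"))}
    assert keys == {"float32"}

    # Size accounting as in __call__: bytes = number of elements * itemsize.
    shape = (128, 256)
    nbytes = int(np.prod(shape)) * np.dtype("float32").itemsize
    assert nbytes == 128 * 256 * 4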
@@ -16,7 +16,8 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer

 from ..core._imperative_rt.utils import create_mm_server
-from .util import Future, get_free_ports
+from ..utils.future import Future
+from .util import get_free_ports

 class Methods:
@@ -8,28 +8,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
-import threading
 from typing import List


-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value
-
-
 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """
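The hunk above shows only the signature and docstring of get_free_ports. For context, helpers like this typically use the bind-to-port-0 idiom, where the OS assigns a currently free port; this is a sketch of that idiom, not necessarily MegEngine's exact code:

    import socket
    from typing import List

    def get_free_ports_sketch(num: int) -> List[int]:
        """Ask the OS for `num` free TCP ports by binding to port 0."""
        socks, ports = [], []
        for _ in range(num):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind(("", 0))                   # port 0: OS picks a free port
            socks.append(sock)
            ports.append(sock.getsockname()[1])
        for sock in socks:                       # release the ports for the caller
            sock.close()
        return ports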
@@ -8,8 +8,8 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

+from ..functional.distributed import all_reduce_sum
 from ..tensor import Tensor
-from .distributed import all_reduce_sum
 from .tensor import param_pack_concat, param_pack_split

@@ -29,6 +29,6 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
-        packed_grads /= dist_group.size
+        packed_grads /= group.size
     grads = param_pack_split(packed_grads, offsets_val, shapes)
     return grads
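`pack_allreduce_split` concatenates the flattened gradients into one buffer, all-reduces that buffer once, and splits the result back into the original shapes; the one-line fix replaces the undefined name `dist_group` with the function's `group` parameter. A numpy mock of the concat/reduce/split round trip, where the helpers are hypothetical stand-ins for param_pack_concat, all_reduce_sum and param_pack_split:

    import numpy as np

    def pack_split_roundtrip(grads, world_size=2):
        shapes = [g.shape for g in grads]
        packed = np.concatenate([g.ravel() for g in grads])  # like param_pack_concat
        packed = packed * world_size                         # stand-in for all_reduce_sum
        out, ofs = [], 0
        for s in shapes:                                     # like param_pack_split
            n = int(np.prod(s))
            out.append(packed[ofs:ofs + n].reshape(s))
            ofs += n
        return out

    a, b = np.ones((2, 2), np.float32), np.arange(3, dtype=np.float32)
    ra, rb = pack_split_roundtrip([a, b])
    assert ra.shape == (2, 2) and rb.shape == (3,)

The Future helper deleted from the distributed util module above reappears as a new file (presumably megengine/utils/future.py):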
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import threading


class Future:
    def __init__(self, ack=True):
        self.ready = threading.Event()
        self.ack = threading.Event() if ack else None

    def set(self, value):
        self.value = value
        self.ready.set()
        if self.ack:
            self.ack.wait()

    def get(self):
        self.ready.wait()
        if self.ack:
            self.ack.set()
        return self.value
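A minimal usage sketch of the relocated Future: get() blocks until set() publishes a value, and with ack=True, set() in turn blocks until the getter acknowledges receipt, giving a two-way handshake between threads. This assumes the class definition above is in scope:

    import threading

    fut = Future(ack=True)

    def producer():
        fut.set(42)         # blocks here until the getter acks

    t = threading.Thread(target=producer)
    t.start()
    assert fut.get() == 42  # wakes the producer via the ack event
    t.join()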