Commit e9104ef1 authored by Megvii Engine Team

fix(mge/parampack): fix copy stream, import cycle

GitOrigin-RevId: 673e11c5b6f1e87b2b498f111ba20f0406d2e57c
Parent e283663a
@@ -2,8 +2,8 @@ from collections import defaultdict
 from contextlib import contextmanager

 from ..core.autodiff.grad import Grad
-from ..distributed.util import Future
 from ..tensor import tensor
+from ..utils.future import Future

 backwarding_grad_manager = None
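Why this import move matters for the "import cycle" part of the commit title: Future previously lived in megengine/distributed/util.py, so the autodiff package had to import from the distributed package, which in turn depends on autodiff machinery (see AllreduceCallback below). Relocating the class to megengine/utils/future.py, which needs only the standard-library threading module, removes that edge. A sketch of the dependency shape, inferred from the hunks in this commit rather than stated anywhere in it:

# before: megengine.autodiff -> megengine.distributed.util -> ... -> megengine.autodiff  (cycle)
# after:  megengine.autodiff -> megengine.utils.future  (leaf module, imports threading only)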
@@ -26,6 +26,7 @@ class GradManager:
             self._param_dict[id(p)] = p
             for cb in callbacks:
                 self._call_back_dict[id(p)].append(cb)
+        return self

     def register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
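The added return self makes the registration method chainable. A hypothetical usage sketch; the method name attach, the import path, and the model object are assumptions, since this hunk does not show the enclosing signature:

from megengine.autodiff import GradManager

gm = GradManager()
# chaining now works because registration returns the manager itself
gm.attach(model.parameters()).register_after_backward_callback(lambda: None)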
@@ -45,7 +46,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0) for y in ys]
+            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
......
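The changed default seed above matters because backward propagates one incoming-gradient element per element of each output y; a bare tensor(1.0) only covers scalar outputs. A numpy-only sketch of the shape requirement (not MegEngine code):

import numpy as np

y = np.zeros((2, 3), dtype="float32")           # a non-scalar output
dy = np.broadcast_to(np.float32(1.0), y.shape)  # plays the role of tensor(1.0).broadcast(y.shape)
assert dy.shape == y.shape                      # the seed must match the output's shape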
@@ -178,8 +178,6 @@ class Grad:
         assert len(ys) == len(dys)
         ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
-        if len(ids) == 0:
-            return
         ys = [y for i, y in enumerate(ys) if i in ids]
         dys = [dy for i, dy in enumerate(dys) if i in ids]
......
@@ -18,9 +18,9 @@ from megengine.device import get_default_device, get_device_count
 from ..functional.param_pack import get_offsets, pack_allreduce_split
 from ..functional.utils import copy
+from ..utils.future import Future
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
-from .util import Future


 class FakeTensor(Future):
@@ -77,7 +77,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._gm_set = set()
+        self._marked_gm = set()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
@@ -87,6 +87,7 @@ class AllreduceCallback:
         self._futures_dict = dict()
         self._packing_list = defaultdict(list)
         self._packing_size = defaultdict(int)
+        self._grad_origin_device = dict()

     def _pack(self, dtype):
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
@@ -102,27 +103,28 @@ class AllreduceCallback:
     def __call__(self, param, grad):
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
-        if gm not in self._gm_set:
+        if gm not in self._marked_gm:
             gm.register_after_backward_callback(self._flush)
-            self._gm_set.add(gm)
+            self._marked_gm.add(gm)
         self._params.append(param)
         self._futures_dict[param] = FakeTensor(ack=False)
         self._gradients_dict[param] = grad
-        self._packing_list[param.dtype].append(param)
-        self._packing_size[param.dtype] += (
-            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
-        )
-        if self._packing_size[param.dtype] > self._param_pack_thd:
-            self._pack(param.dtype)
+        self._grad_origin_device[param] = str(grad.device)
+        dtype_str = str(np.dtype(param.dtype))
+        dtype_size = np.dtype(param.dtype).itemsize
+        self._packing_list[dtype_str].append(param)
+        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
+        if self._packing_size[dtype_str] > self._param_pack_thd:
+            self._pack(dtype_str)
         return self._futures_dict[param]

     def _flush(self):
-        for dtype in self._packing_list.keys():
+        for dtype in sorted(self._packing_list.keys()):
             self._pack(dtype)
         for param in self._params:
             grad = self._gradients_dict[param]
-            grad = copy(grad, get_default_device())
+            grad = copy(grad, self._grad_origin_device[param])
             self._futures_dict[param].set(grad)
         self._reset()
......
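Three behavioral changes in this hunk are worth spelling out: gradients are now bucketed under string dtype keys, so buckets can be flushed in a deterministic sorted order on every rank; each gradient's original device is recorded; and after the all-reduce the result is copied back to that recorded device instead of to get_default_device(), which is the "copy stream" part of the fix. A self-contained plain-Python/numpy sketch of the bucketing policy; the function names and the print stand-in for the collective are illustrative only:

import numpy as np
from collections import defaultdict

PACK_THD = 10 * 1024 * 1024            # bytes, mirrors _param_pack_thd

packing_list = defaultdict(list)
packing_size = defaultdict(int)

def pack(dtype):
    bucket = packing_list.pop(dtype, [])
    packing_size[dtype] = 0
    print("pack", dtype, len(bucket), "grads")   # real code concats, all-reduces, splits

def on_grad(name, grad):
    dtype = str(grad.dtype)                      # string key, as in the new code
    packing_list[dtype].append((name, grad))
    packing_size[dtype] += grad.size * grad.itemsize
    if packing_size[dtype] > PACK_THD:           # flush a bucket once it exceeds 10 MiB
        pack(dtype)

def flush():
    for dtype in sorted(packing_list):           # deterministic order across ranks
        pack(dtype)

on_grad("w", np.ones((2048, 2048), dtype="float32"))
flush()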
@@ -16,7 +16,8 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer

 from ..core._imperative_rt.utils import create_mm_server
-from .util import Future, get_free_ports
+from ..utils.future import Future
+from .util import get_free_ports


 class Methods:
......
@@ -8,28 +8,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
-import threading
 from typing import List

-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value

 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """
......
@@ -8,8 +8,8 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

+from ..functional.distributed import all_reduce_sum
 from ..tensor import Tensor
-from .distributed import all_reduce_sum
 from .tensor import param_pack_concat, param_pack_split
@@ -29,6 +29,6 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
-        packed_grads /= dist_group.size
+        packed_grads /= group.size
     grads = param_pack_split(packed_grads, offsets_val, shapes)
     return grads
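For reference, the round trip pack_allreduce_split performs is concat -> all-reduce -> split (the old dist_group.size was a reference to an undefined name, fixed above to group.size). A numpy-only sketch with the collective faked as a sum over per-rank buffers; all names here are illustrative, not MegEngine API:

import numpy as np

def pack_allreduce_split_sketch(grads_per_rank):
    shapes = [g.shape for g in grads_per_rank[0]]
    sizes = [int(np.prod(s)) for s in shapes]
    # param_pack_concat equivalent: one flat buffer per rank
    packed = [np.concatenate([g.ravel() for g in grads]) for grads in grads_per_rank]
    reduced = np.sum(packed, axis=0)             # stands in for all_reduce_sum
    # param_pack_split equivalent: cut the reduced buffer back into tensors
    out, offset = [], 0
    for shape, size in zip(shapes, sizes):
        out.append(reduced[offset:offset + size].reshape(shape))
        offset += size
    return out

grads = [[np.ones((2, 2)), np.ones(3)] for _ in range(4)]   # four fake ranks
print([g.tolist() for g in pack_allreduce_split_sketch(grads)])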
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import threading


class Future:
    def __init__(self, ack=True):
        self.ready = threading.Event()
        self.ack = threading.Event() if ack else None

    def set(self, value):
        self.value = value
        self.ready.set()
        if self.ack:
            self.ack.wait()

    def get(self):
        self.ready.wait()
        if self.ack:
            self.ack.set()
        return self.value
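A minimal usage sketch of the relocated Future: the consumer blocks in get() until a producer thread publishes a value, and with ack=True the producer's set() in turn blocks until the consumer has picked the value up, giving a simple two-way rendezvous. The import path matches the one the other hunks switch to:

import threading
from megengine.utils.future import Future

fut = Future(ack=True)

def producer():
    fut.set(42)        # blocks here until the consumer calls get()

t = threading.Thread(target=producer)
t.start()
print(fut.get())       # -> 42; also releases the producer via ack
t.join()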