From e9104ef1572fc2d2295a5a7dbdef9a0f676b6514 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 8 Sep 2020 15:22:40 +0800
Subject: [PATCH] fix(mge/parampack): fix copy stream, import cycle

GitOrigin-RevId: 673e11c5b6f1e87b2b498f111ba20f0406d2e57c
---
 .../python/megengine/autodiff/grad_manager.py |  5 ++--
 .../python/megengine/core/autodiff/grad.py    |  2 --
 .../python/megengine/distributed/helper.py    | 28 ++++++++++---------
 .../python/megengine/distributed/server.py    |  3 +-
 .../python/megengine/distributed/util.py      | 19 -------------
 .../python/megengine/functional/param_pack.py |  4 +--
 imperative/python/megengine/utils/future.py   | 26 +++++++++++++++++
 7 files changed, 48 insertions(+), 39 deletions(-)
 create mode 100644 imperative/python/megengine/utils/future.py

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index e34cec46f..528b805ee 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -2,8 +2,8 @@ from collections import defaultdict
 from contextlib import contextmanager
 
 from ..core.autodiff.grad import Grad
-from ..distributed.util import Future
 from ..tensor import tensor
+from ..utils.future import Future
 
 backwarding_grad_manager = None
 
@@ -26,6 +26,7 @@ class GradManager:
             self._param_dict[id(p)] = p
             for cb in callbacks:
                 self._call_back_dict[id(p)].append(cb)
+        return self
 
     def register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
@@ -45,7 +46,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0) for y in ys]
+            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
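
Two notes on the grad_manager.py hunks. The added `return self` lets
registration chain fluently (e.g. `gm = GradManager().register(params)`).
The `dys` change broadcasts the default gradient seed to each output's
shape: a bare scalar seed only matches a scalar output, while a ones
tensor shaped like y makes backward(y) compute the gradient of sum(y) for
non-scalar outputs too. A plain-NumPy sketch of that seed (illustrative
only, not the MegEngine API):

    import numpy as np

    # A hypothetical non-scalar output y and its default gradient seed dy:
    y = np.arange(12, dtype=np.float32).reshape(4, 3)
    dy = np.broadcast_to(np.float32(1.0), y.shape)  # ~ tensor(1.0).broadcast(y.shape)
    assert dy.shape == y.shape and (dy == 1.0).all()
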
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index d21209375..c9d93e462 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -178,8 +178,6 @@ class Grad:
         assert len(ys) == len(dys)
 
         ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
-        if len(ids) == 0:
-            return
 
         ys = [y for i, y in enumerate(ys) if i in ids]
         dys = [dy for i, dy in enumerate(dys) if i in ids]

diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index 2d3b64c6e..693da1e8d 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -18,9 +18,9 @@ from megengine.device import get_default_device, get_device_count
 
 from ..functional.param_pack import get_offsets, pack_allreduce_split
 from ..functional.utils import copy
+from ..utils.future import Future
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
-from .util import Future
 
 
 class FakeTensor(Future):
@@ -77,7 +77,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._gm_set = set()
+        self._marked_gm = set()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
 
@@ -87,6 +87,7 @@ class AllreduceCallback:
         self._futures_dict = dict()
         self._packing_list = defaultdict(list)
         self._packing_size = defaultdict(int)
+        self._grad_origin_device = dict()
 
     def _pack(self, dtype):
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
@@ -102,27 +103,28 @@ class AllreduceCallback:
     def __call__(self, param, grad):
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
-        if gm not in self._gm_set:
+        if gm not in self._marked_gm:
             gm.register_after_backward_callback(self._flush)
-            self._gm_set.add(gm)
+            self._marked_gm.add(gm)
         self._params.append(param)
         self._futures_dict[param] = FakeTensor(ack=False)
         self._gradients_dict[param] = grad
-
-        self._packing_list[param.dtype].append(param)
-        self._packing_size[param.dtype] += (
-            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
-        )
-        if self._packing_size[param.dtype] > self._param_pack_thd:
-            self._pack(param.dtype)
+        self._grad_origin_device[param] = str(grad.device)
+
+        dtype_str = str(np.dtype(param.dtype))
+        dtype_size = np.dtype(param.dtype).itemsize
+        self._packing_list[dtype_str].append(param)
+        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
+        if self._packing_size[dtype_str] > self._param_pack_thd:
+            self._pack(dtype_str)
         return self._futures_dict[param]
 
     def _flush(self):
-        for dtype in self._packing_list.keys():
+        for dtype in sorted(self._packing_list.keys()):
             self._pack(dtype)
         for param in self._params:
             grad = self._gradients_dict[param]
-            grad = copy(grad, get_default_device())
+            grad = copy(grad, self._grad_origin_device[param])
             self._futures_dict[param].set(grad)
         self._reset()
 
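
The __call__/_flush hunks above carry the "copy stream" half of the fix:
each gradient's origin device is recorded as str(grad.device), and after
the fused all-reduce the result is copied back to that device instead of
to get_default_device(). Bucketing by the dtype *string* and flushing the
buckets in sorted order also gives every rank one deterministic, shared
collective order. A standalone sketch of the threshold-based bucketing
policy (hypothetical names, not the MegEngine API):

    import numpy as np
    from collections import defaultdict

    PACK_THRESHOLD = 10 * 1024 * 1024  # bytes, mirroring _param_pack_thd

    packing_list = defaultdict(list)  # dtype string -> grads awaiting one fused all-reduce
    packing_size = defaultdict(int)   # dtype string -> accumulated byte size

    def flush(dtype_str):
        # One fused collective stands in for many small per-tensor ones.
        print("fused all-reduce over", packing_list.pop(dtype_str, []))
        packing_size.pop(dtype_str, None)

    def on_grad(name, shape, dtype):
        dtype_str = str(np.dtype(dtype))  # hashable, sortable bucket key, e.g. "float32"
        packing_list[dtype_str].append(name)
        packing_size[dtype_str] += int(np.prod(shape)) * np.dtype(dtype).itemsize
        if packing_size[dtype_str] > PACK_THRESHOLD:
            flush(dtype_str)

    on_grad("w", (4096, 4096), np.float32)  # 64 MiB, crosses the 10 MiB threshold at once
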
diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py
index d6e9ba445..d8f199a6c 100644
--- a/imperative/python/megengine/distributed/server.py
+++ b/imperative/python/megengine/distributed/server.py
@@ -16,7 +16,8 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer
 
 from ..core._imperative_rt.utils import create_mm_server
-from .util import Future, get_free_ports
+from ..utils.future import Future
+from .util import get_free_ports
 
 
 class Methods:

diff --git a/imperative/python/megengine/distributed/util.py b/imperative/python/megengine/distributed/util.py
index 9f5be3fab..b3a0a2aa1 100644
--- a/imperative/python/megengine/distributed/util.py
+++ b/imperative/python/megengine/distributed/util.py
@@ -8,28 +8,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
-import threading
 from typing import List
 
 
-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value
-
-
 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """

diff --git a/imperative/python/megengine/functional/param_pack.py b/imperative/python/megengine/functional/param_pack.py
index 0b25e08b2..0ad3a11bf 100644
--- a/imperative/python/megengine/functional/param_pack.py
+++ b/imperative/python/megengine/functional/param_pack.py
@@ -8,8 +8,8 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
 
-from ..functional.distributed import all_reduce_sum
 from ..tensor import Tensor
+from .distributed import all_reduce_sum
 from .tensor import param_pack_concat, param_pack_split
 
 
@@ -29,6 +29,6 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
-        packed_grads /= dist_group.size
+        packed_grads /= group.size
     grads = param_pack_split(packed_grads, offsets_val, shapes)
     return grads

diff --git a/imperative/python/megengine/utils/future.py b/imperative/python/megengine/utils/future.py
new file mode 100644
index 000000000..23ac3c4cf
--- /dev/null
+++ b/imperative/python/megengine/utils/future.py
@@ -0,0 +1,26 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import threading
+
+
+class Future:
+    def __init__(self, ack=True):
+        self.ready = threading.Event()
+        self.ack = threading.Event() if ack else None
+
+    def set(self, value):
+        self.value = value
+        self.ready.set()
+        if self.ack:
+            self.ack.wait()
+
+    def get(self):
+        self.ready.wait()
+        if self.ack:
+            self.ack.set()
+        return self.value
--
GitLab
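
Postscript on the new utils/future.py: relocating Future out of
distributed.util is what appears to break the import cycle named in the
subject line. autodiff.grad_manager previously imported Future from the
distributed package, which itself imports from the functional package
(see helper.py above), while utils.future is a dependency-free leaf.
A minimal usage sketch of the ack handshake (runnable once this patch
is applied):

    import threading

    from megengine.utils.future import Future

    # With ack=True (the default), set() blocks until a consumer calls
    # get(), a rendezvous; FakeTensor above opts out via ack=False so the
    # producing thread never stalls.
    f = Future(ack=True)

    def producer():
        f.set(42)  # sets `ready`, then waits on the `ack` event
        print("producer: value was consumed")

    t = threading.Thread(target=producer)
    t.start()
    print("consumer received:", f.get())  # waits for `ready`, acks, returns 42
    t.join()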