diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 81cfc77b0b50a268bef6b1ef2f74139a0f303e39..05db40c16e07b3bc62510e78e9aa73ef9678b870 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -88,7 +88,6 @@ class AllreduceCallback: self._futures_dict = dict() self._packing_list = defaultdict(list) self._packing_size = defaultdict(int) - self._grad_origin_device = dict() def _pack(self, dtype): grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] @@ -110,7 +109,6 @@ class AllreduceCallback: self._params.append(param) self._futures_dict[param] = TensorFuture(ack=False) self._gradients_dict[param] = grad - self._grad_origin_device[param] = str(grad.device) dtype_str = str(np.dtype(param.dtype)) dtype_size = np.dtype(param.dtype).itemsize @@ -125,7 +123,6 @@ class AllreduceCallback: self._pack(dtype) for param in self._params: grad = self._gradients_dict[param] - grad = copy(grad, self._grad_origin_device[param]) self._futures_dict[param].set(grad) self._reset() diff --git a/imperative/python/megengine/functional/param_pack.py b/imperative/python/megengine/functional/param_pack.py index 0ad3a11bf4cf36261c70eec93d2990bb6bc4a78a..d7d52085de256de4529574bc758f5b45c41ddd9e 100644 --- a/imperative/python/megengine/functional/param_pack.py +++ b/imperative/python/megengine/functional/param_pack.py @@ -27,7 +27,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method): offsets_val = get_offsets(shapes) offsets = Tensor(offsets_val) packed_grads = param_pack_concat(pack_list, offsets, offsets_val) - packed_grads = all_reduce_sum(packed_grads, group, group.comp_node) + packed_grads = all_reduce_sum(packed_grads, group) if reduce_method == "mean": packed_grads /= group.size grads = param_pack_split(packed_grads, offsets_val, shapes)