提交 1a24fb29，作者：Megvii Engine Team

perf(mge/allreduce): put allreduce on another cuda stream

GitOrigin-RevId: 2e778dfa0444ac2c2870b9dcfa72cfe7271fbc1a
上级 4a5e3170
...@@ -88,6 +88,7 @@ class AllreduceCallback: ...@@ -88,6 +88,7 @@ class AllreduceCallback:
self._futures_dict = dict() self._futures_dict = dict()
self._packing_list = defaultdict(list) self._packing_list = defaultdict(list)
self._packing_size = defaultdict(int) self._packing_size = defaultdict(int)
self._grad_origin_device = dict()
def _pack(self, dtype): def _pack(self, dtype):
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
...@@ -109,6 +110,7 @@ class AllreduceCallback: ...@@ -109,6 +110,7 @@ class AllreduceCallback:
self._params.append(param) self._params.append(param)
self._futures_dict[param] = TensorFuture(ack=False) self._futures_dict[param] = TensorFuture(ack=False)
self._gradients_dict[param] = grad self._gradients_dict[param] = grad
self._grad_origin_device[param] = str(grad.device)
dtype_str = str(np.dtype(param.dtype)) dtype_str = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize dtype_size = np.dtype(param.dtype).itemsize
...@@ -123,6 +125,7 @@ class AllreduceCallback: ...@@ -123,6 +125,7 @@ class AllreduceCallback:
self._pack(dtype) self._pack(dtype)
for param in self._params: for param in self._params:
grad = self._gradients_dict[param] grad = self._gradients_dict[param]
grad = copy(grad, self._grad_origin_device[param])
self._futures_dict[param].set(grad) self._futures_dict[param].set(grad)
self._reset() self._reset()
......
...@@ -27,7 +27,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method): ...@@ -27,7 +27,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
offsets_val = get_offsets(shapes) offsets_val = get_offsets(shapes)
offsets = Tensor(offsets_val) offsets = Tensor(offsets_val)
packed_grads = param_pack_concat(pack_list, offsets, offsets_val) packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
packed_grads = all_reduce_sum(packed_grads, group) packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
if reduce_method == "mean": if reduce_method == "mean":
packed_grads /= group.size packed_grads /= group.size
grads = param_pack_split(packed_grads, offsets_val, shapes) grads = param_pack_split(packed_grads, offsets_val, shapes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册