提交 aac8de55 编写于 作者: M Megvii Engine Team

fix(mge/parampacksplit): fix parampacksplit refcnt error

GitOrigin-RevId: c9644655963012eb06bbeb9617945a29a62ce7ac
上级 a7b9ece4
...@@ -88,7 +88,6 @@ class AllreduceCallback: ...@@ -88,7 +88,6 @@ class AllreduceCallback:
self._futures_dict = dict() self._futures_dict = dict()
self._packing_list = defaultdict(list) self._packing_list = defaultdict(list)
self._packing_size = defaultdict(int) self._packing_size = defaultdict(int)
self._grad_origin_device = dict()
def _pack(self, dtype): def _pack(self, dtype):
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
...@@ -110,7 +109,6 @@ class AllreduceCallback: ...@@ -110,7 +109,6 @@ class AllreduceCallback:
self._params.append(param) self._params.append(param)
self._futures_dict[param] = TensorFuture(ack=False) self._futures_dict[param] = TensorFuture(ack=False)
self._gradients_dict[param] = grad self._gradients_dict[param] = grad
self._grad_origin_device[param] = str(grad.device)
dtype_str = str(np.dtype(param.dtype)) dtype_str = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize dtype_size = np.dtype(param.dtype).itemsize
...@@ -125,7 +123,6 @@ class AllreduceCallback: ...@@ -125,7 +123,6 @@ class AllreduceCallback:
self._pack(dtype) self._pack(dtype)
for param in self._params: for param in self._params:
grad = self._gradients_dict[param] grad = self._gradients_dict[param]
grad = copy(grad, self._grad_origin_device[param])
self._futures_dict[param].set(grad) self._futures_dict[param].set(grad)
self._reset() self._reset()
......
...@@ -27,7 +27,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method): ...@@ -27,7 +27,7 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
offsets_val = get_offsets(shapes) offsets_val = get_offsets(shapes)
offsets = Tensor(offsets_val) offsets = Tensor(offsets_val)
packed_grads = param_pack_concat(pack_list, offsets, offsets_val) packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
packed_grads = all_reduce_sum(packed_grads, group, group.comp_node) packed_grads = all_reduce_sum(packed_grads, group)
if reduce_method == "mean": if reduce_method == "mean":
packed_grads /= group.size packed_grads /= group.size
grads = param_pack_split(packed_grads, offsets_val, shapes) grads = param_pack_split(packed_grads, offsets_val, shapes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册