提交 b9762d71 编写于 作者: M Megvii Engine Team

fix(mge): make parampack run with tensor symbolic shape

GitOrigin-RevId: 6fc313785d3cd926db9ab7a4872e9a1e18653511
上级 4d75f691
......@@ -14,7 +14,6 @@ import numpy as np
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply
from .._trace_option import use_symbolic_shape
from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const
......
......@@ -218,7 +218,7 @@ class AllreduceCallback:
if len(self._packing_list[dtype]) == 0:
return
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
shapes = [p.shape for p in self._packing_list[dtype]]
shapes = [p._tuple_shape for p in self._packing_list[dtype]]
reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method
)
......@@ -241,7 +241,7 @@ class AllreduceCallback:
dtype_str = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize
self._packing_list[dtype_str].append(param)
self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
if self._packing_size[dtype_str] > self._param_pack_thd:
self._pack(dtype_str)
return self._futures_dict[param]
......
......@@ -194,7 +194,7 @@ def run_test(
worker(max_err)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册