提交 b9762d71 编写于 作者: M Megvii Engine Team

fix(mge): make parampack run with tensor symbolic shape

GitOrigin-RevId: 6fc313785d3cd926db9ab7a4872e9a1e18653511
上级 4d75f691
...@@ -14,7 +14,6 @@ import numpy as np ...@@ -14,7 +14,6 @@ import numpy as np
from .._imperative_rt.common import CompNode from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import Tensor, apply from .._imperative_rt.core2 import Tensor, apply
- from .._trace_option import use_symbolic_shape
from ..ops import builtin from ..ops import builtin
from ..ops.builtin import Elemwise, GetVarShape from ..ops.builtin import Elemwise, GetVarShape
from ..ops.special import Const from ..ops.special import Const
......
...@@ -218,7 +218,7 @@ class AllreduceCallback: ...@@ -218,7 +218,7 @@ class AllreduceCallback:
if len(self._packing_list[dtype]) == 0: if len(self._packing_list[dtype]) == 0:
return return
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
- shapes = [p.shape for p in self._packing_list[dtype]]
+ shapes = [p._tuple_shape for p in self._packing_list[dtype]]
reduced_grads = pack_allreduce_split( reduced_grads = pack_allreduce_split(
grad_list, shapes, self._group, self._reduce_method grad_list, shapes, self._group, self._reduce_method
) )
...@@ -241,7 +241,7 @@ class AllreduceCallback: ...@@ -241,7 +241,7 @@ class AllreduceCallback:
dtype_str = str(np.dtype(param.dtype)) dtype_str = str(np.dtype(param.dtype))
dtype_size = np.dtype(param.dtype).itemsize dtype_size = np.dtype(param.dtype).itemsize
self._packing_list[dtype_str].append(param) self._packing_list[dtype_str].append(param)
- self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
+ self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
if self._packing_size[dtype_str] > self._param_pack_thd: if self._packing_size[dtype_str] > self._param_pack_thd:
self._pack(dtype_str) self._pack(dtype_str)
return self._futures_dict[param] return self._futures_dict[param]
......
...@@ -194,7 +194,7 @@ def run_test( ...@@ -194,7 +194,7 @@ def run_test(
worker(max_err) worker(max_err)
- @pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
+ @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册