From b9762d714c7fe0e1c5918fccad5836b1cbb09469 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 30 Dec 2020 20:02:53 +0800
Subject: [PATCH] fix(mge): make parampack run with tensor symbolic shape

GitOrigin-RevId: 6fc313785d3cd926db9ab7a4872e9a1e18653511
---
 imperative/python/megengine/core/tensor/tensor_wrapper.py | 1 -
 imperative/python/megengine/distributed/helper.py | 4 ++--
 imperative/python/test/integration/test_dp_correctness.py | 2 +-
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py
index 845792c98..eb417a75b 100644
--- a/imperative/python/megengine/core/tensor/tensor_wrapper.py
+++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py
@@ -14,7 +14,6 @@ import numpy as np
 
 from .._imperative_rt.common import CompNode
 from .._imperative_rt.core2 import Tensor, apply
-from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.builtin import Elemwise, GetVarShape
 from ..ops.special import Const
diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index f94724a6e..2d0131bb6 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -218,7 +218,7 @@ class AllreduceCallback:
         if len(self._packing_list[dtype]) == 0:
             return
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
-        shapes = [p.shape for p in self._packing_list[dtype]]
+        shapes = [p._tuple_shape for p in self._packing_list[dtype]]
         reduced_grads = pack_allreduce_split(
             grad_list, shapes, self._group, self._reduce_method
         )
@@ -241,7 +241,7 @@ class AllreduceCallback:
         dtype_str = str(np.dtype(param.dtype))
         dtype_size = np.dtype(param.dtype).itemsize
         self._packing_list[dtype_str].append(param)
-        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
+        self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
         if self._packing_size[dtype_str] > self._param_pack_thd:
             self._pack(dtype_str)
         return self._futures_dict[param]
diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py
index 7e39c2966..85dad2b4a 100644
--- a/imperative/python/test/integration/test_dp_correctness.py
+++ b/imperative/python/test/integration/test_dp_correctness.py
@@ -194,7 +194,7 @@ def run_test(
     worker(max_err)
 
 
-@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
-- 
GitLab