diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 845792c98c7870b8cd7148e7926b3090aa75c575..eb417a75beeeca16165b49a570db09bf8f2d1e82 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -14,7 +14,6 @@ import numpy as np from .._imperative_rt.common import CompNode from .._imperative_rt.core2 import Tensor, apply -from .._trace_option import use_symbolic_shape from ..ops import builtin from ..ops.builtin import Elemwise, GetVarShape from ..ops.special import Const diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index f94724a6e52026a9059d229032c149161d5f7995..2d0131bb6c215339d23a6a76dca10812a9a29198 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -218,7 +218,7 @@ class AllreduceCallback: if len(self._packing_list[dtype]) == 0: return grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] - shapes = [p.shape for p in self._packing_list[dtype]] + shapes = [p._tuple_shape for p in self._packing_list[dtype]] reduced_grads = pack_allreduce_split( grad_list, shapes, self._group, self._reduce_method ) @@ -241,7 +241,7 @@ class AllreduceCallback: dtype_str = str(np.dtype(param.dtype)) dtype_size = np.dtype(param.dtype).itemsize self._packing_list[dtype_str].append(param) - self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size + self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size if self._packing_size[dtype_str] > self._param_pack_thd: self._pack(dtype_str) return self._futures_dict[param] diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py index 7e39c296638d53b45724ae321a2ef3bdc2287413..85dad2b4a3a2439cf095e9bcef2f5c28f1cca43c 100644 --- a/imperative/python/test/integration/test_dp_correctness.py +++ b/imperative/python/test/integration/test_dp_correctness.py @@ -194,7 +194,7 @@ def run_test( worker(max_err) -@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device") +@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") @pytest.mark.isolated_distributed @pytest.mark.skipif( platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"