diff --git a/python_module/megengine/core/tensor_nn.py b/python_module/megengine/core/tensor_nn.py
index 9c25df715e1085316aa7b7f5fad27ee2193ad08a..e2bbc927dd8db283401bf0c210d328506ce72f9e 100644
--- a/python_module/megengine/core/tensor_nn.py
+++ b/python_module/megengine/core/tensor_nn.py
@@ -31,6 +31,9 @@ class Parameter(Tensor):
         t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
         self.__dict__.update(t.__dict__)
 
+        # broadcast and allreduce will not be performed in optimizer if replica_mode is False
+        self.replica_mode = True
+
     @property
     def shape(self):
         r"""Return shape of parameter.
diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py
index 64a53ad2f0407153e40ac8c09f2de50443457935..2596d26ae0e7574a8ddbf871ccf70d76ffe73562 100644
--- a/python_module/megengine/optimizer/optimizer.py
+++ b/python_module/megengine/optimizer/optimizer.py
@@ -178,7 +178,7 @@ class Optimizer(metaclass=ABCMeta):
 
         assert len(grads) == len(params)
         for param, grad in zip(params, grads):
-            if is_distributed():
+            if is_distributed() and param.replica_mode:
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
@@ -230,10 +230,14 @@ class Optimizer(metaclass=ABCMeta):
         key = 0
         for group in self.param_groups:
            for param in group["params"]:
-                bcast_param(
-                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
-                )
-                key += 1
+                if param.replica_mode:
+                    bcast_param(
+                        param,
+                        "bcast_param_" + str(key),
+                        get_world_size(),
+                        get_rank() == 0,
+                    )
+                    key += 1
 
     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
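
A minimal usage sketch of the new flag (this assumes the MegEngine 0.x megengine.module / megengine.optimizer APIs from the same era as this patch; the model and the choice of parameter below are purely illustrative and not part of the change):

import megengine.module as M
import megengine.optimizer as optim

net = M.Linear(10, 10)

# Parameters default to replica_mode=True, so existing distributed training is
# unchanged.  Setting replica_mode=False on a parameter makes the optimizer
# skip both the gradient all_reduce in backward() and the initial bcast_param
# broadcast for it, leaving that parameter rank-local.
net.bias.replica_mode = False

opt = optim.SGD(net.parameters(), lr=0.01)

Since the flag defaults to True on every Parameter, the patch is backward compatible; only parameters explicitly marked with replica_mode = False opt out of synchronization.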