From 2b3a88d01129324b71f96270501d17aefa3cfb97 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 23 Jun 2020 18:54:23 +0800
Subject: [PATCH] feat(mge/distributed): add parameter replica_mode

GitOrigin-RevId: 244e4ca437e3427d65d395f4be01b4fe6ed92e91
---
 python_module/megengine/core/tensor_nn.py      |  3 +++
 python_module/megengine/optimizer/optimizer.py | 14 +++++++++-----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/python_module/megengine/core/tensor_nn.py b/python_module/megengine/core/tensor_nn.py
index 9c25df715..e2bbc927d 100644
--- a/python_module/megengine/core/tensor_nn.py
+++ b/python_module/megengine/core/tensor_nn.py
@@ -31,6 +31,9 @@ class Parameter(Tensor):
         t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
         self.__dict__.update(t.__dict__)
 
+        # broadcast and allreduce will not be performed in optimizer if replica_mode is False
+        self.replica_mode = True
+
     @property
     def shape(self):
         r"""Return shape of parameter.
diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py
index 64a53ad2f..2596d26ae 100644
--- a/python_module/megengine/optimizer/optimizer.py
+++ b/python_module/megengine/optimizer/optimizer.py
@@ -178,7 +178,7 @@ class Optimizer(metaclass=ABCMeta):
         assert len(grads) == len(params)
 
         for param, grad in zip(params, grads):
-            if is_distributed():
+            if is_distributed() and param.replica_mode:
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
@@ -230,10 +230,14 @@ class Optimizer(metaclass=ABCMeta):
         key = 0
         for group in self.param_groups:
             for param in group["params"]:
-                bcast_param(
-                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
-                )
-                key += 1
+                if param.replica_mode:
+                    bcast_param(
+                        param,
+                        "bcast_param_" + str(key),
+                        get_world_size(),
+                        get_rank() == 0,
+                    )
+                key += 1
 
     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
-- 
GitLab
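
A minimal usage sketch of the new flag (assumptions: a MegEngine 0.x-style API as in this patch, with Parameter importable from the package top level; the attribute name replica_mode comes from the diff above, everything else is illustrative). With replica_mode left at its default of True a Parameter is synchronized as before; setting it to False keeps the parameter worker-local, so the optimizer skips both the rank-0 broadcast and the all_reduce_mean on its gradient.

    import numpy as np
    from megengine import Parameter

    # Every Parameter is replicated across workers by default (replica_mode=True).
    w = Parameter(np.random.randn(4, 4).astype("float32"))
    assert w.replica_mode is True

    # Assumption for illustration: opt this parameter out of distributed sync.
    # With replica_mode=False the optimizer neither broadcasts it from rank 0
    # nor all-reduces its gradient (the two hunks above guard both paths).
    w.replica_mode = False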