Commit 2b3a88d0 authored by Megvii Engine Team

feat(mge/distributed): add parameter replica_mode

GitOrigin-RevId: 244e4ca437e3427d65d395f4be01b4fe6ed92e91
Parent 44c381b6
@@ -31,6 +31,9 @@ class Parameter(Tensor):
         t = tensor(value, dtype=dtype, device=device, requires_grad=requires_grad)
         self.__dict__.update(t.__dict__)
 
+        # broadcast and allreduce will not be performed in optimizer if replica_mode is False
+        self.replica_mode = True
+
     @property
     def shape(self):
         r"""Return shape of parameter.
...
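The hunk above gives every Parameter a replica_mode attribute that defaults to True. A minimal usage sketch, assuming a MegEngine build that already contains this commit (the top-level megengine.Parameter export and NumPy-array construction used below are assumptions based on the snippet above):

import numpy as np
import megengine as mge

# replica_mode defaults to True, so the optimizer keeps its usual distributed
# behaviour for this parameter (all_reduce_mean on its gradient, bcast_param on init).
w = mge.Parameter(np.random.randn(4, 4).astype("float32"))
assert w.replica_mode is True

# Setting it to False opts this particular parameter out of synchronization;
# the optimizer checks the flag per parameter, as shown in the hunks below.
w.replica_mode = False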
@@ -178,7 +178,7 @@ class Optimizer(metaclass=ABCMeta):
         assert len(grads) == len(params)
 
         for param, grad in zip(params, grads):
-            if is_distributed():
+            if is_distributed() and param.replica_mode:
                 with opr_priority_scope(cg, -(2 ** 30)):
                     # always run all_reduce_mean first except add_update
                     grad = (
@@ -230,10 +230,14 @@ class Optimizer(metaclass=ABCMeta):
         key = 0
         for group in self.param_groups:
             for param in group["params"]:
-                bcast_param(
-                    param, "bcast_param_" + str(key), get_world_size(), get_rank() == 0,
-                )
-                key += 1
+                if param.replica_mode:
+                    bcast_param(
+                        param,
+                        "bcast_param_" + str(key),
+                        get_world_size(),
+                        get_rank() == 0,
+                    )
+                    key += 1
 
     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
...
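Taken together, the two optimizer hunks mean that a parameter whose replica_mode is False is neither all-reduced during gradient computation nor broadcast from rank 0 when the optimizer sets up its parameters. A hedged end-to-end sketch under distributed training; the SGD constructor call and the rank-local "bias" parameter below are illustrative assumptions, not part of this commit:

import numpy as np
import megengine as mge
from megengine.optimizer import SGD

# "shared" keeps the default replica_mode=True: under is_distributed() the
# optimizer all-reduces its gradient and broadcasts its value from rank 0.
shared = mge.Parameter(np.ones((4, 4), dtype="float32"))

# "local_bias" opts out: it is skipped by both all_reduce_mean and bcast_param,
# so each rank keeps and updates its own private copy.
local_bias = mge.Parameter(np.zeros((4,), dtype="float32"))
local_bias.replica_mode = False

opt = SGD([shared, local_bias], lr=0.01)
# After backward() and opt.step(), only "shared" has been synchronized across
# ranks; "local_bias" is updated purely from its local gradient.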