Commit e50fa074 authored by Megvii Engine Team

fix(mge/imperative): remove backward from optimizer

GitOrigin-RevId: ad6ad444faf0e39ba6a5936c5f3a57a3b2406c75
Parent 60702667
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable as Iter
from typing import Optional, Union
from ..device import get_default_device
from ..distributed.group import get_client, is_distributed
from ..functional import add_update
from ..functional.distributed import WORLD, Group, all_reduce_sum, broadcast
from ..functional.utils import copy
from ..tensor import Tensor, TensorDict
from ..tensor_nn import Parameter
from .optimizer import Optimizer
from .param_pack import get_pack_list, pack_allreduce_split
class DistributedOptimizer(Optimizer):
r"""Add Distributed Func for distributed training.
:param params: specifies what Tensors should be optimized.
:param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
:param reduce_method: use all_reduce_sum or all_reduce_mean to reduce gradients
:param bcast_period: broadcasts params every *bcast_period* iterations.
if it equals to 0, it will broadcast params only at the beginning. Default: 500
:param param_pack: whether to pack gradients to avoid small packages send/recv. Default: False
:param param_pack_thd: max size of packed gradients by bytes. Default: 10 * 1024 * 1024
"""
def __init__(
self,
params: Union[Iter[Parameter], dict],
defaults: dict,
reduce_method: Optional[str] = None,
dist_group: Optional[Group] = WORLD,
bcast_period: int = 0,
param_pack: bool = False,
param_pack_thd: int = 10 * 1024 * 1024,
):
if is_distributed():
assert reduce_method in ["sum", "mean"], "reduce_method must be 'sum' or 'mean'"
defaults["orders"] = []
defaults["dist_group"] = dist_group
super().__init__(params, defaults)
self._bcast_period = bcast_period
self._param_pack = param_pack
self._param_pack_thd = param_pack_thd
self._reduce_method = reduce_method
self.add_save_load_state_ignore_keys(
{"grads", "orders", "pack_list", "shape_list", "dist_group"}
)
if is_distributed() and bcast_period != -1:
self.bcast_param()
def grad_callback(self, grad, i, group):
if is_distributed() and group["dist_group"] is not None:
dist_group = group["dist_group"]
if self._param_pack and "pack_list" in group:
for pack, shapes in zip(group["pack_list"], group["shape_list"]):
if i == pack[-1]:
pack_allreduce_split(group, pack, shapes, self._reduce_method)
else:
group["orders"].append(i)
group["grads"][i] = all_reduce_sum(
grad, dist_group, dist_group.comp_node
)
if self._reduce_method == "mean":
group["grads"][i] /= dist_group.size
def _gen_pack_list(self, group):
if "pack_list" not in group:
dist_group = group["dist_group"]
if dist_group.rank == 0:
pack_list, shape_list = get_pack_list(group, self._param_pack_thd)
get_client().set_pack_list(dist_group.key, (pack_list, shape_list))
else:
pack_list, shape_list = get_client().get_pack_list(dist_group.key)
group["pack_list"] = pack_list
group["shape_list"] = shape_list
def backward(self, loss: Tensor):
ret = super().backward(loss)
if is_distributed():
for group in self.param_groups:
if self._param_pack and group["dist_group"] is not None:
self._gen_pack_list(group)
return ret
def step(self):
if is_distributed():
for group in self.param_groups:
device = get_default_device()
for param in group["params"]:
if param.__wrapped__ not in self._grad_skip:
if param.grad.device != device:
param.grad = copy(param.grad, device)
if self._bcast_period > 0:
self._bcast_iter += 1
if self._bcast_iter == self._bcast_period:
self.bcast_param()
self._bcast_iter = 0
super().step()
def bcast_param(self):
device = get_default_device()
for group in self.param_groups:
for param in group["params"]:
dist_group = group["dist_group"]
new_param = broadcast(param, dist_group)
if new_param.device != device:
new_param = copy(new_param, device)
add_update(param, new_param, alpha=0)
param._reset(new_param)
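# A per-iteration flow sketch for distributed training, using only methods
# defined in this module and its Optimizer base; `opt`, `model`, `loss_fn`,
# `data` and `label` are assumed to exist and are illustrative names:
#
#     with opt.record():                 # registers per-parameter grad callbacks
#         loss = loss_fn(model(data), label)
#     opt.backward(loss)                 # grad_callback all_reduces each gradient
#     opt.step()                         # moves grads to the default device and
#                                        # periodically re-broadcasts params
#     opt.clear_grad()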
@@ -11,22 +11,13 @@ from collections import Iterable
from contextlib import contextmanager
from typing import Dict
from typing import Iterable as Iter
from typing import Set, Union
from typing import Union
import numpy as np
from ..core.autodiff.grad import Grad
from ..device import get_default_device
from ..distributed.group import get_client, is_distributed
from ..functional import add_update
from ..functional.distributed import all_reduce_sum, broadcast
from ..functional.utils import copy
from ..logger import get_logger
from ..tensor import Tensor, TensorDict
from ..tensor_nn import Buffer, Parameter
logger = get_logger(__name__)
class _RequiredParameter:
def __repr__(self):
@@ -43,10 +34,6 @@ class Optimizer(metaclass=ABCMeta):
:param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
"""
_recording = None
_grad = None
_gradients = None
def __init__( # pylint: disable=too-many-branches
self, params: Union[Iter[Parameter], dict], defaults: dict,
):
@@ -63,7 +50,6 @@ class Optimizer(metaclass=ABCMeta):
)
self.param_groups = [] # type: list
self.save_load_state_ignore_keys = set()
param_groups = list(params)
if len(param_groups) == 0:
@@ -154,100 +140,6 @@ class Optimizer(metaclass=ABCMeta):
params.append(param)
return params
def grad_callback(self, grad, i, group):
pass
def record(self):
@contextmanager
def recorder():
params = self._get_params()
grad = Grad()
gradients = [None] * len(params)
if self._recording:
raise RuntimeError("already recording!")
try:
self._recording = True
self._grad = grad
for group in self.param_groups:
group["grads"] = [None] * len(group["params"])
for i, param in enumerate(group["params"]):
def callback(tensor, grad, i=i, group=group, self=self):
group["grads"][i] = grad
self.grad_callback(grad, i, group)
grad.wrt(param, callback=callback)
with grad:
yield
finally:
self._recording = False
self._grad = None
for group in self.param_groups:
group["grads"] = []
return recorder()
def _calculate_gradients(self, loss: Tensor):
if not self._recording:
raise RuntimeError(
"no computation history. "
"did you forget record() or "
"call a method that clears the history?"
)
assert self._grad is not None
if len(loss.__wrapped__._extra_data) == 0: # in case loss depends on no tensor
self._grad = None
return
one = Tensor([1.0], dtype=loss.dtype, device=loss.device)
one = one.reshape(loss.shape)
try:
self._grad(loss, one)
finally:
self._grad = None
def minimize(self, loss: Tensor):
self.backward(loss)
self.step()
def backward(self, loss: Tensor):
"""Computes the back-propagation of the network given loss.
:param loss: The obtained loss tensor
"""
rst = []
self._calculate_gradients(loss)
# _grad_skip records the parameters that are not on the backward path
self._grad_skip = set()
for group in self.param_groups:
# _grad_skip is consumed in optimizer.step()
# XXX: assumptions
# 1. Assume the same execution sequence for all GPUs in data parallel
# 2. If backward is called multiple times to accumulate grads,
#    the same _grad_skip is also assumed for all backward() calls
#    (see the commented accumulation sketch after this method)
# Please change the code if any assumption is invalid
for param, grad in zip(group["params"], group["grads"]):
if grad is None:
self._grad_skip.add(param.__wrapped__)
continue
grad = Buffer(grad)
if getattr(param, "grad", None) is None:
param.grad = grad
else:
assert isinstance(param.grad, Buffer)
param.grad += grad
rst.append(param.grad)
if len(self._grad_skip) > 0:
get_logger(__name__).warning(
"{} parameters have no grad! "
"Make sure you pass the right parameters list".format(
len(self._grad_skip)
)
)
return rst
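# A gradient-accumulation sketch under the assumptions above: every
# record()/backward() pair adds into param.grad, and a single step() then
# consumes the accumulated gradients; `opt`, `model` and `batches` are
# illustrative names, not part of this module:
#
#     for data, label in batches:
#         with opt.record():
#             loss = model(data, label)
#         opt.backward(loss)     # param.grad += grad, i.e. accumulation
#     opt.step()
#     opt.clear_grad()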
def step(self):
r"""Performs a single optimization step.
@@ -261,8 +153,8 @@ class Optimizer(metaclass=ABCMeta):
)
self._updates(group)
def zero_grad(self):
r"""Reset the grad to zeros.
def clear_grad(self):
r"""Clear the grad buffer.
"""
for param_group in self.param_groups:
@@ -270,9 +162,6 @@ class Optimizer(metaclass=ABCMeta):
if getattr(param, "grad", None) is not None:
param.grad = None
def add_save_load_state_ignore_keys(self, keys: Set[str]):
self.save_load_state_ignore_keys |= keys
def state_dict(self) -> Dict:
r"""Export the optimizer state.
@@ -293,11 +182,7 @@ class Optimizer(metaclass=ABCMeta):
state[param2id[param]] = st
for group in self.param_groups:
param_group = {
k: v
for k, v in group.items()
if k != "params" and k not in self.save_load_state_ignore_keys
}
param_group = {k: v for k, v in group.items() if k != "params"}
param_group["params"] = [param2id[param] for param in group["params"]]
param_groups.append(param_group)
@@ -329,14 +214,12 @@ class Optimizer(metaclass=ABCMeta):
if isinstance(v, Buffer):
self._state[p][k] = Buffer(v.numpy())
new_keys = set(group_new.keys()) - self.save_load_state_ignore_keys
saved_keys = set(group_saved.keys()) - self.save_load_state_ignore_keys
if new_keys != saved_keys:
if set(group_new.keys()) != set(group_saved.keys()):
raise ValueError(
"loaded state dict contains a parameter group that "
"doesn't match the keys of optimizer's group"
)
for key in saved_keys:
for key in group_new.keys():
if key != "params":
group_new[key] = group_saved[key]