diff --git a/imperative/python/megengine/optimizer/distributed_optimizer.py b/imperative/python/megengine/optimizer/distributed_optimizer.py
deleted file mode 100644
index 86168c9ad844f9d378f277f1721b77521b021746..0000000000000000000000000000000000000000
--- a/imperative/python/megengine/optimizer/distributed_optimizer.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from typing import Iterable as Iter
-from typing import Optional, Union
-
-from ..device import get_default_device
-from ..distributed.group import get_client, is_distributed
-from ..functional import add_update
-from ..functional.distributed import WORLD, Group, all_reduce_sum, broadcast
-from ..functional.utils import copy
-from ..tensor import Tensor, TensorDict
-from ..tensor_nn import Parameter
-from .optimizer import Optimizer
-from .param_pack import get_pack_list, pack_allreduce_split
-
-
-class DistributedOptimizer(Optimizer):
-    r"""Add Distributed Func for distributed training.
-
-    :param params: specifies what Tensors should be optimized.
-    :param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
-    :param reduce_method: use all_reduce_sum or all_reduce_mean to reduce gradients
-    :param bcast_period: broadcasts params every *bcast_period* iterations.
-        if it equals to 0, it will broadcast params only at the beginning. Default: 500
-    :param param_pack: whether to pack gradients to avoid small packages send/recv. Default: False
-    :param param_pack_thd: max size of packed gradients by bytes.
-        Default: 10 * 1024 * 1024
-    """
-
-    def __init__(
-        self,
-        params: Union[Iter[Parameter], dict],
-        defaults: dict,
-        reduce_method: Optional[str] = None,
-        dist_group: Optional[Group] = WORLD,
-        bcast_period: int = 0,
-        param_pack: bool = False,
-        param_pack_thd: int = 10 * 1024 * 1024,
-    ):
-        if is_distributed():
-            assert reduce_method in ["sum", "mean"], "reduce_method must be specified"
-        defaults["orders"] = []
-        defaults["dist_group"] = dist_group
-        super().__init__(params, defaults)
-        self._bcast_period = bcast_period
-        self._param_pack = param_pack
-        self._param_pack_thd = param_pack_thd
-        self._reduce_method = reduce_method
-
-        self.add_save_load_state_ignore_keys(
-            {"grads", "orders", "pack_list", "shape_list", "dist_group"}
-        )
-
-        if is_distributed() and bcast_period != -1:
-            self.bcast_param()
-
-    def grad_callback(self, grad, i, group):
-        if is_distributed() and group["dist_group"] is not None:
-            dist_group = group["dist_group"]
-            if self._param_pack and "pack_list" in group:
-                for pack, shapes in zip(group["pack_list"], group["shape_list"]):
-                    if i == pack[-1]:
-                        pack_allreduce_split(group, pack, shapes, self._reduce_method)
-            else:
-                group["orders"].append(i)
-                group["grads"][i] = all_reduce_sum(
-                    grad, dist_group, dist_group.comp_node
-                )
-                if self._reduce_method == "mean":
-                    group["grads"][i] /= dist_group.size
-
-    def _gen_pack_list(self, group):
-        if "pack_list" not in group:
-            dist_group = group["dist_group"]
-            if dist_group.rank == 0:
-                pack_list, shape_list = get_pack_list(group, self._param_pack_thd)
-                get_client().set_pack_list(dist_group.key, (pack_list, shape_list))
-            else:
-                pack_list, shape_list = get_client().get_pack_list(dist_group.key)
-            group["pack_list"] = pack_list
-            group["shape_list"] = shape_list
-
-    def backward(self, loss: Tensor):
-        ret = super().backward(loss)
-        if is_distributed():
-            for group in self.param_groups:
-                if self._param_pack and group["dist_group"] is not None:
-                    self._gen_pack_list(group)
-        return ret
-
-    def step(self):
-        if is_distributed():
-            for group in self.param_groups:
-                device = get_default_device()
-                for param in group["params"]:
-                    if param.__wrapped__ not in self._grad_skip:
-                        if param.grad.device != device:
-                            param.grad = copy(param.grad, device)
-            if self._bcast_period > 0:
-                self._bcast_iter += 1
-                if self._bcast_iter == self._bcast_period:
-                    self.bcast_param()
-                    self._bcast_iter = 0
-        super().step()
-
-    def bcast_param(self):
-        device = get_default_device()
-        for group in self.param_groups:
-            for param in group["params"]:
-                dist_group = group["dist_group"]
-                new_param = broadcast(param, dist_group)
-                if new_param.device != device:
-                    new_param = copy(new_param, device)
-                add_update(param, new_param, alpha=0)
-                param._reset(new_param)
diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py
index f5bf18b0797f2cf739b7a41f0f2d33bf09635bf5..e4205bde92c32b41ec361e5d7c15e3545b25ddcf 100644
--- a/imperative/python/megengine/optimizer/optimizer.py
+++ b/imperative/python/megengine/optimizer/optimizer.py
@@ -11,22 +11,13 @@ from collections import Iterable
 from contextlib import contextmanager
 from typing import Dict
 from typing import Iterable as Iter
-from typing import Set, Union
+from typing import Union
 
 import numpy as np
 
-from ..core.autodiff.grad import Grad
-from ..device import get_default_device
-from ..distributed.group import get_client, is_distributed
-from ..functional import add_update
-from ..functional.distributed import all_reduce_sum, broadcast
-from ..functional.utils import copy
-from ..logger import get_logger
 from ..tensor import Tensor, TensorDict
 from ..tensor_nn import Buffer, Parameter
 
-logger = get_logger(__name__)
-
 
 class _RequiredParameter:
     def __repr__(self):
@@ -43,10 +34,6 @@ class Optimizer(metaclass=ABCMeta):
     :param defaults: a dict of default parameters of Optimizer, like learning rate or momentum.
     """
 
-    _recording = None
-    _grad = None
-    _gradients = None
-
     def __init__(  # pylint: disable=too-many-branches
         self, params: Union[Iter[Parameter], dict], defaults: dict,
     ):
@@ -63,7 +50,6 @@
             )
 
         self.param_groups = []  # type: list
-        self.save_load_state_ignore_keys = set()
 
         param_groups = list(params)
         if len(param_groups) == 0:
@@ -154,100 +140,6 @@
                 params.append(param)
         return params
 
-    def grad_callback(self, grad, i, group):
-        pass
-
-    def record(self):
-        @contextmanager
-        def recorder():
-            params = self._get_params()
-            grad = Grad()
-            gradients = [None] * len(params)
-            if self._recording:
-                raise RuntimeError("already recording!")
-            try:
-                self._recording = True
-                self._grad = grad
-                for group in self.param_groups:
-                    group["grads"] = [None] * len(group["params"])
-                    for i, param in enumerate(group["params"]):
-
-                        def callback(tensor, grad, i=i, group=group, self=self):
-                            group["grads"][i] = grad
-                            self.grad_callback(grad, i, group)
-
-                        grad.wrt(param, callback=callback)
-                with grad:
-                    yield
-            finally:
-                self._recording = False
-                self._grad = None
-                for group in self.param_groups:
-                    group["grads"] = []
-
-        return recorder()
-
-    def _calculate_gradients(self, loss: Tensor):
-        if not self._recording:
-            raise RuntimeError(
-                "no computation history. "
-                "did you forget record() or "
-                "call a method that clears the history?"
-            )
-        assert self._grad is not None
-
-        if len(loss.__wrapped__._extra_data) == 0:  # in case loss depends on no tensor
-            self._grad = None
-            return
-
-        one = Tensor([1.0], dtype=loss.dtype, device=loss.device)
-        one = one.reshape(loss.shape)
-        try:
-            self._grad(loss, one)
-        finally:
-            self._grad = None
-
-    def minimize(self, loss: Tensor):
-        self.backward(loss)
-        self.step()
-
-    def backward(self, loss: Tensor):
-        """Computes the back-propagation of the network given loss.
-
-        :param loss: The obtained loss tensor
-        """
-        rst = []
-        self._calculate_gradients(loss)
-
-        # _grad_skip records the parameters which are not in the path of backward
-        self._grad_skip = set()
-        for group in self.param_groups:
-            # _grad_skip is consumed in optimizer.step()
-            # XXX: assumptions
-            # 1. Assume the same execution sequence for all GPUs in data parallel
-            # 2. If backward is called by multiple times to accumulate grad,
-            #    it's also assumed same _grad_skip for all backward() calls
-            # Please change the code if any assumption is invalid
-            for param, grad in zip(group["params"], group["grads"]):
-                if grad is None:
-                    self._grad_skip.add(param.__wrapped__)
-                    continue
-                grad = Buffer(grad)
-                if getattr(param, "grad", None) is None:
-                    param.grad = grad
-                else:
-                    assert isinstance(param.grad, Buffer)
-                    param.grad += grad
-                rst.append(param.grad)
-        if len(self._grad_skip) > 0:
-            get_logger(__name__).warning(
-                "{} parameters have no grad! "
-                "Make sure you pass the right parameters list".format(
-                    len(self._grad_skip)
-                )
-            )
-        return rst
-
     def step(self):
         r"""Performs a single optimization step.
 
@@ -261,8 +153,8 @@
             )
             self._updates(group)
 
-    def zero_grad(self):
-        r"""Reset the grad to zeros.
+    def clear_grad(self):
+        r"""Clear the grad buffer.
         """
         for param_group in self.param_groups:
@@ -270,9 +162,6 @@
                 if getattr(param, "grad", None) is not None:
                     param.grad = None
 
-    def add_save_load_state_ignore_keys(self, keys: Set[str]):
-        self.save_load_state_ignore_keys |= keys
-
     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
 
@@ -293,11 +182,7 @@
             state[param2id[param]] = st
 
         for group in self.param_groups:
-            param_group = {
-                k: v
-                for k, v in group.items()
-                if k != "params" and k not in self.save_load_state_ignore_keys
-            }
+            param_group = {k: v for k, v in group.items() if k != "params"}
             param_group["params"] = [param2id[param] for param in group["params"]]
             param_groups.append(param_group)
 
@@ -329,14 +214,12 @@
                     if isinstance(v, Buffer):
                         self._state[p][k] = Buffer(v.numpy())
 
-            new_keys = set(group_new.keys()) - self.save_load_state_ignore_keys
-            saved_keys = set(group_saved.keys()) - self.save_load_state_ignore_keys
-            if new_keys != saved_keys:
+            if set(group_new.keys()) != set(group_saved.keys()):
                 raise ValueError(
                     "loaded state dict contains a parameter group that "
                     "doesn't match the keys of optimizer's group"
                 )
-            for key in saved_keys:
+            for key in group_new.keys():
                 if key != "params":
                     group_new[key] = group_saved[key]
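
Net effect on a training loop (not part of the patch): Optimizer.record()/backward()/minimize() and the built-in distributed hooks are removed, and zero_grad() is renamed to clear_grad(), so gradients must now be produced outside the optimizer and the loop reduces to step() plus clear_grad(). A minimal sketch of the new-style loop, assuming megengine.autodiff.GradManager is the external recorder that fills param.grad (it does not appear in this diff); net, data and dataset are placeholder objects:

    from megengine.autodiff import GradManager   # assumption: external recorder, not shown in this diff
    from megengine.optimizer import SGD

    opt = SGD(net.parameters(), lr=0.01)         # net is a placeholder Module
    gm = GradManager()
    gm.attach(net.parameters())                  # record gradients w.r.t. these parameters

    for data in dataset:                         # dataset is a placeholder iterable
        with gm:                                 # record the forward pass
            loss = net(data).sum()               # placeholder loss
            gm.backward(loss)                    # fills param.grad for attached params
        opt.step()                               # consumes param.grad (was opt.backward(loss) + opt.step())
        opt.clear_grad()                         # formerly opt.zero_grad()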