From 3f2eac2fe1e35ab2719815929541fa86f2cd8230 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 4 Sep 2020 17:10:00 +0800 Subject: [PATCH] fix(mge/imperative): move functional/distributed.py to distributed/functional.py GitOrigin-RevId: 30cf2f514b9abc5e863e1fb26382008391cd607a --- .../megengine/distributed/functional.py | 295 +++++++++++++++++ .../megengine/functional/distributed.py | 309 +----------------- 2 files changed, 310 insertions(+), 294 deletions(-) create mode 100644 imperative/python/megengine/distributed/functional.py diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py new file mode 100644 index 00000000..c6162c53 --- /dev/null +++ b/imperative/python/megengine/distributed/functional.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from typing import Optional, Tuple + +from ..core._imperative_rt.ops import CollectiveCommMode +from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn +from ..core.autodiff.grad import ( + Tracer, + check_backward_allow_noinput, + get_grad_managers, + get_op_has_grad_fn, + tracer_apply, +) +from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend +from ..core.tensor.core import apply +from ..core.tensor.tensor import Tensor, tensor_apply +from ..tensor import tensor +from ..device import get_default_device +from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank + +__all__ = [ + "reduce_sum", + "broadcast", + "all_gather", + "reduce_scatter_sum", + "all_reduce_sum", + "all_reduce_max", + "all_reduce_min", + "gather", + "scatter", + "all_to_all", + "remote_send", + "remote_recv", +] + + +@apply.add +def _(op: RemoteSend, *args: Tensor): + ret = tensor_apply(op, *args) + + # set extra information + tracer_set = dict() + for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): + tracer_set[k.name] = True + + # check tracer_set in remote_recv + get_client().set_remote_tracer(op.key, tracer_set) + return ret + + +@builtin_op_get_backward_fn.register(RemoteSend) +def _(op: RemoteSend, inputs, outputs, input_requires_grad): + def backward(*args): + return [ + remote_recv( + op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device) + ) + ] + + return backward, [True] + + +@get_op_has_grad_fn.register(RemoteSend) +def _(op: RemoteSend): + def has_grad(opnode, reached): + return get_client().check_is_grad(op.key) + + return has_grad + + +@check_backward_allow_noinput.register(RemoteSend) +def _(op: RemoteSend): + return True + + +@builtin_op_get_backward_fn.register(RemoteRecv) +def _(op: RemoteRecv, inputs, outputs, input_requires_grad): + def backward(*output_grads): + return [remote_send(output_grads[0], op.rank_from)] + + return backward, [True] + + +@get_op_has_grad_fn.register(RemoteRecv) +def _(op: RemoteRecv): + def has_grad(opnode, reached): + ret = False + for v in opnode.outputs: + if v() in reached: + ret = True + break + get_client().set_is_grad(op.key, ret) + return ret + + return has_grad + + +def collective_comm(inp, mode, group, device): + """Helper function for applying collective communication functions""" + 
assert isinstance(group, Group) + if group is None: + return inp + op = CollectiveComm() + op.key = group.key + op.nr_devices = group.size + op.rank = group.rank + op.is_root = op.rank == 0 + op.local_grad = False + op.addr, op.port = get_mm_server_addr() + op.mode = mode + op.dtype = inp.dtype + op.backend = get_backend() + op.comp_node = device + return apply(op, inp)[0] + + +def reduce_sum( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create reduce_sum operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.REDUCE_SUM + return collective_comm(inp, mode, group, device) + + +def broadcast( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create broadcast operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.BROADCAST + return collective_comm(inp, mode, group, device) + + +def all_gather( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create all_gather operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.ALL_GATHER + return collective_comm(inp, mode, group, device) + + +def reduce_scatter_sum( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create reduce_scatter_sum operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.REDUCE_SCATTER_SUM + return collective_comm(inp, mode, group, device) + + +def all_reduce_sum( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create all_reduce_sum operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.ALL_REDUCE_SUM + return collective_comm(inp, mode, group, device) + + +def all_reduce_max( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create all_reduce_max operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.ALL_REDUCE_MAX + return collective_comm(inp, mode, group, device) + + +def all_reduce_min( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create all_reduce_min operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.ALL_REDUCE_MIN + return collective_comm(inp, mode, group, device) + + +def gather( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create gather operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.GATHER + return collective_comm(inp, mode, group, device) + + +def scatter( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create scatter operator for collective communication + + :param inp: input tensor + :param 
group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.SCATTER + return collective_comm(inp, mode, group, device) + + +def all_to_all( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + """Create all_to_all operator for collective communication + + :param inp: input tensor + :param group: communication group + :param device: execute placement + """ + mode = CollectiveCommMode.ALL_TO_ALL + return collective_comm(inp, mode, group, device) + + +def remote_send(inp: Tensor, dest_rank: int) -> Tensor: + """Send a Tensor to a remote process + + :param inp: tensor to send + :param dest_rank: destination process rank + """ + op = RemoteSend() + op.key = "{}->{}".format(get_rank(), dest_rank) + op.addr, op.port = get_mm_server_addr() + op.rank_to = dest_rank + return apply(op, inp)[0] + + +def remote_recv( + src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None +) -> Tensor: + """Receive a Tensor from a remote process + + :param src_rank: source process rank + :param shape: the shape of the tensor to receive + :param dtype: the data type of the tensor to receive + :param device: the device to place the received tensor + """ + key = "{}->{}".format(src_rank, get_rank()) + + if device is None: + device = get_default_device() + # dummpy input + inp = tensor([0]) + tracer_set = get_client().check_remote_tracer(key) + for grad_manager in get_grad_managers(): + if grad_manager.name in tracer_set: + grad_manager.wrt(inp) + + op = RemoteRecv() + op.key = key + op.cn = device + op.shape = shape + op.dtype = dtype + op.addr, op.port = get_mm_server_addr() + op.rank_from = src_rank + + return apply(op, inp)[0] diff --git a/imperative/python/megengine/functional/distributed.py b/imperative/python/megengine/functional/distributed.py index e10cef11..da28c194 100644 --- a/imperative/python/megengine/functional/distributed.py +++ b/imperative/python/megengine/functional/distributed.py @@ -6,298 +6,19 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-from typing import Optional, Tuple - -from ..core._imperative_rt.ops import CollectiveCommMode -from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn -from ..core.autodiff.grad import ( - Tracer, - check_backward_allow_noinput, - get_grad_managers, - get_op_has_grad_fn, - tracer_apply, +# pylint: disable=redefined-builtin +from ..distributed.functional import ( + all_gather, + all_reduce_max, + all_reduce_min, + all_reduce_sum, + all_to_all, + broadcast, + collective_comm, + gather, + reduce_scatter_sum, + reduce_sum, + remote_recv, + remote_send, + scatter, ) -from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend -from ..core.tensor.core import apply -from ..core.tensor.tensor import Tensor -from ..device import get_default_device -from ..distributed.group import ( - WORLD, - Group, - get_backend, - get_client, - get_mm_server_addr, - get_rank, -) -from ..tensor import tensor - -__all__ = [ - "reduce_sum", - "broadcast", - "all_gather", - "reduce_scatter_sum", - "all_reduce_sum", - "all_reduce_max", - "all_reduce_min", - "gather", - "scatter", - "all_to_all", - "remote_send", - "remote_recv", -] - - -@apply.register() -def _(op: RemoteSend, *args: Tensor): - ret = apply.super(op, *args) - - # set extra information - tracer_set = dict() - for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))): - tracer_set[k.name] = True - - # check tracer_set in remote_recv - get_client().set_remote_tracer(op.key, tracer_set) - return ret - - -@builtin_op_get_backward_fn.register(RemoteSend) -def _(op: RemoteSend, inputs, outputs, input_requires_grad): - def backward(*args): - return [ - remote_recv( - op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device) - ) - ] - - return backward, [True] - - -@get_op_has_grad_fn.register(RemoteSend) -def _(op: RemoteSend): - def has_grad(opnode, reached): - return get_client().check_is_grad(op.key) - - return has_grad - - -@check_backward_allow_noinput.register(RemoteSend) -def _(op: RemoteSend): - return True - - -@builtin_op_get_backward_fn.register(RemoteRecv) -def _(op: RemoteRecv, inputs, outputs, input_requires_grad): - def backward(*output_grads): - return [remote_send(output_grads[0], op.rank_from)] - - return backward, [True] - - -@get_op_has_grad_fn.register(RemoteRecv) -def _(op: RemoteRecv): - def has_grad(opnode, reached): - ret = False - for v in opnode.outputs: - if v() in reached: - ret = True - break - get_client().set_is_grad(op.key, ret) - return ret - - return has_grad - - -def collective_comm(inp, mode, group, device): - """Helper function for applying collective communication functions""" - assert isinstance(group, Group) - if group is None: - return inp - op = CollectiveComm() - op.key = group.key - op.nr_devices = group.size - op.rank = group.rank - op.is_root = op.rank == 0 - op.local_grad = False - op.addr, op.port = get_mm_server_addr() - op.mode = mode - op.dtype = inp.dtype - op.backend = get_backend() - op.comp_node = device - return apply(op, inp)[0] - - -def reduce_sum( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create reduce_sum operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.REDUCE_SUM - return collective_comm(inp, mode, group, device) - - -def broadcast( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create broadcast operator for collective 
communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.BROADCAST - return collective_comm(inp, mode, group, device) - - -def all_gather( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create all_gather operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.ALL_GATHER - return collective_comm(inp, mode, group, device) - - -def reduce_scatter_sum( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create reduce_scatter_sum operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.REDUCE_SCATTER_SUM - return collective_comm(inp, mode, group, device) - - -def all_reduce_sum( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create all_reduce_sum operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.ALL_REDUCE_SUM - return collective_comm(inp, mode, group, device) - - -def all_reduce_max( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create all_reduce_max operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.ALL_REDUCE_MAX - return collective_comm(inp, mode, group, device) - - -def all_reduce_min( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create all_reduce_min operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.ALL_REDUCE_MIN - return collective_comm(inp, mode, group, device) - - -def gather( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create gather operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.GATHER - return collective_comm(inp, mode, group, device) - - -def scatter( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create scatter operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.SCATTER - return collective_comm(inp, mode, group, device) - - -def all_to_all( - inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" -) -> Tensor: - """Create all_to_all operator for collective communication - - :param inp: input tensor - :param group: communication group - :param device: execute placement - """ - mode = CollectiveCommMode.ALL_TO_ALL - return collective_comm(inp, mode, group, device) - - -def remote_send(inp: Tensor, dest_rank: int) -> Tensor: - """Send a Tensor to a remote process - - :param inp: tensor to send - :param dest_rank: destination process rank - """ - op = RemoteSend() - op.key = "{}->{}".format(get_rank(), dest_rank) - op.addr, op.port = get_mm_server_addr() - op.rank_to = dest_rank - return apply(op, inp)[0] - - -def remote_recv( - 
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None -) -> Tensor: - """Receive a Tensor from a remote process - - :param src_rank: source process rank - :param shape: the shape of the tensor to receive - :param dtype: the data type of the tensor to receive - :param device: the device to place the received tensor, - if None, use default device - """ - key = "{}->{}".format(src_rank, get_rank()) - if device is None: - device = get_default_device() - - # dummpy input - inp = tensor([0]) - tracer_set = get_client().check_remote_tracer(key) - for grad_manager in get_grad_managers(): - if grad_manager.name in tracer_set: - grad_manager.wrt(inp) - - op = RemoteRecv() - op.key = key - op.cn = device - op.shape = shape - op.dtype = dtype - op.addr, op.port = get_mm_server_addr() - op.rank_from = src_rank - - return apply(op, inp)[0] -- GitLab
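
Usage sketch (appended after the patch trailer, so `git am` ignores it): the snippet below illustrates how the relocated helpers in megengine.distributed.functional are intended to be called after this move. It is an illustration only, not part of the patch: it assumes the process group and ranks have already been initialized by the usual multi-process launcher utilities (not shown here), and the shape/dtype values are arbitrary examples. The import paths mirror the ones the patch itself uses.

    # Illustrative sketch only: assumes ranks and the WORLD group were set up
    # elsewhere (e.g. by a multi-process launcher); values are arbitrary examples.
    import numpy as np

    from megengine.tensor import tensor
    from megengine.distributed.functional import all_reduce_sum, remote_send, remote_recv
    from megengine.distributed.group import get_rank

    def worker():
        x = tensor([float(get_rank())])
        # collective: every rank ends up with the sum of x over the default WORLD group
        total = all_reduce_sum(x)
        # point-to-point: rank 0 sends its tensor to rank 1, which receives it
        if get_rank() == 0:
            remote_send(x, dest_rank=1)
        elif get_rank() == 1:
            y = remote_recv(0, shape=(1,), dtype=np.float32)

Existing code that imports these names from megengine.functional.distributed keeps working, since that module now simply re-exports them from megengine.distributed.functional.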