# collective_comm.oprdecl
# Declaration of the `collective_comm` operator wrapper.
#
# NOTE(review): each Doc(...) entry appears to be
# (name, description[, type[, default]]) -- confirm against the Doc helper
# used by the oprdecl generator.
# NOTE(review): the `body` entries are lines of Python source, presumably
# spliced into the generated wrapper function; `config` is not declared in
# `inputs`, so it is presumably supplied by the generator -- verify.
decl_raw_opr(
    'collective_comm',
    inputs = [
        Doc('input', 'Input var.'),
        Doc('key', 'The key to NCCL cliques. Operators with same key belong '
            'to the same NCCL operation.', 'str'),
        Doc('nr_devices', 'Total number of devices involved in the NCCL '
            'operation to which this operator belongs.', 'int'),
        Doc('is_root', 'whether this node is root node', 'bool'),
        Doc('rank', 'rank of this node, if is -1, generate one', 'int'),
        Doc('local_grad', 'whether use local grad', 'bool'),
        Doc('server_addr', 'rpc server ip address'),
        Doc('port', 'server rpc listening port'),
        Doc('param', 'The only component of *param* is *mode*, which refers to '
            'a specific NCCL operation type.',
            ':class:`~megbrain.opr_param_defs.CollectiveComm`'),
        Doc('dtype', 'Data type of inputs and outputs. Currently this is '
            'required by BROADCAST and optional to other operations. If '
            'specified, it must be consistent with the *dtype* of inputs (if '
            'any).', ':class:`~megbrain.opr_param_defs.DType`', 'None'),
        Doc('backend', 'Backend for collective communication, nccl or ucx',
            'str', '\'nccl\''),
        Doc('output_buffer', 'The external dev buffer reserving output result',
            ':class:`.SharedND`', 'None'),
        Doc('disable', 'If true, the execution will return directly and the output '
            'is a random value. All the disable should be same in one collective '
            'communication group.', ':class:`.SharedScalar`', '_mgb.SharedScalar(0)')
    ],
    body = [
        # `input` is either a SymbolVar (operator has a real input var) or a
        # CompGraph (no input var; only the owning graph is given). The two
        # cases dispatch to different _mgb._Opr entry points with otherwise
        # identical arguments.
        'if isinstance(input, _mgb.SymbolVar):',
        ('    output = _mgb._Opr.collective_comm_with_input(input, key, '
         'nr_devices, is_root, rank, local_grad, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)'),
        'else:',
        '    assert isinstance(input, _mgb.CompGraph)',
        ('    output = _mgb._Opr.collective_comm_without_input(input, key, '
         'nr_devices, is_root, rank, local_grad, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)')
    ],
    desc = ('collective communication between multiple CompNodes on multiple '
            'machines')
)

# vim: ft=python