# Declaration of the raw operator ``collective_comm``.
#
# ``inputs`` documents each argument of the generated wrapper via ``Doc``
# entries; as used below, the positional fields are the argument name, its
# description, and optionally a type string and a default-value string.
#
# ``body`` is a list of source lines for the generated wrapper. It dispatches
# on the type of ``input``: a ``_mgb.SymbolVar`` goes to
# ``collective_comm_with_input``; otherwise ``input`` must be a
# ``_mgb.CompGraph`` and goes to ``collective_comm_without_input``.
# Both calls receive the same remaining arguments.
#
# NOTE(review): ``config`` is referenced in ``body`` but is not declared in
# ``inputs``; presumably it is supplied by the ``decl_raw_opr`` template --
# confirm against the generator.
decl_raw_opr(
    'collective_comm',
    inputs = [
        Doc('input', 'Input var.'),
        Doc('key',
            'The key to NCCL cliques. Operators with same key belong '
            'to the same NCCL operation.',
            'str'),
        Doc('nr_devices',
            'Total number of devices involved in the NCCL '
            'operation to which this operator belongs.',
            'int'),
        Doc('is_root', 'whether this node is root node', 'bool'),
        Doc('rank', 'rank of this node, if is -1, generate one', 'int'),
        Doc('local_grad', 'whether use local grad', 'bool'),
        Doc('server_addr', 'rpc server ip address'),
        Doc('port', 'server rpc listening port'),
        Doc('param',
            'The only component of *param* is *mode*, which refers to '
            'a specific NCCL operation type.',
            ':class:`~megbrain.opr_param_defs.CollectiveComm`'),
        Doc('dtype',
            'Data type of inputs and outputs. Currently this is '
            'required by BROADCAST and optional to other operations. If '
            'specified, it must be consistent with the *dtype* of inputs (if '
            'any).',
            ':class:`~megbrain.opr_param_defs.DType`', 'None'),
        Doc('backend',
            'Backend for collective communication, nccl or ucx',
            'str', '\'nccl\''),
        Doc('output_buffer',
            'The external dev buffer reserving output result',
            ':class:`.SharedND`', 'None'),
        # The description is assembled from adjacent string literals; a raw
        # line break inside a single-quoted literal is a syntax error.
        Doc('disable',
            'If true, the execution will return directly and the output '
            'is a random value. '
            'All the disable should be same in one collective '
            'communication group.',
            ':class:`.SharedScalar`', '_mgb.SharedScalar(0)')
    ],
    body = [
        # Lines of generated code; the leading whitespace inside each string
        # is part of the emitted source and must stay consistent per branch.
        'if isinstance(input, _mgb.SymbolVar):',
        (' output = _mgb._Opr.collective_comm_with_input(input, key, '
         'nr_devices, is_root, rank, local_grad, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)'),
        'else:',
        ' assert isinstance(input, _mgb.CompGraph)',
        (' output = _mgb._Opr.collective_comm_without_input(input, key, '
         'nr_devices, is_root, rank, local_grad, server_addr, port, '
         '[param.serialize()], dtype, backend, output_buffer, config, disable)')
    ],
    desc = ('collective communication between multiple CompNodes on multiple '
            'machines')
)

# vim: ft=python