Unverified commit c47bafc6 authored by lilong12, committed by GitHub

add send/recv api (#32504)

* add sendrecv, test=develop
Parent a7be32cc
......@@ -44,6 +44,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
{"assign", {"X"}},
{"send_v2", {"X"}},
{"reshape2", {"X", "Shape"}},
{"expand", {"X", "ExpandTimes"}},
{"slice", {"Input", "StartsTensor", "EndsTensor"}},
......@@ -123,6 +124,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"sync_batch_norm", {"MeanOut", "VarianceOut"}},
{"accuracy", {"Correct", "Total"}},
{"fill_constant", {"Out"}},
{"recv_v2", {"Out"}},
{"matmul", {"Out"}},
{"c_broadcast", {"Out"}},
{"c_sync_calc_stream", {"Out"}},
......
......@@ -37,6 +37,8 @@ __all__ = [
'barrier',
'split',
'ReduceOp',
'send',
'recv',
]
......@@ -1170,3 +1172,103 @@ def split(x,
name=name,
group=None)
return linear_out
def send(tensor, dst=0, group=None, use_calc_stream=True):
"""
Send a tensor to the receiver.
Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
dst (int): The destination rank id.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream or the communication stream.
Returns:
None.
Examples:
.. code-block:: python
import paddle
#from paddle.distributed import init_parallel_env
#init_parallel_env()
#if paddle.distributed.ParallelEnv().rank == 0:
# data = paddle.to_tensor([7, 8, 9])
# paddle.distributed.send(data, dst=1)
#else:
# data = paddle.to_tensor([1,2,3])
# paddle.distributed.recv(data, src=0)
#out = data.numpy()
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
op_type = 'send_v2'
if in_dygraph_mode():
return core.ops.send_v2(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', dst)
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'send')
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
attrs={
'ring_id': ring_id,
'peer': dst,
'use_calc_stream': use_calc_stream,
})
def recv(tensor, src=0, group=None, use_calc_stream=True):
"""
Receive a tensor from the sender.
Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank id.
group (Group): The group instance returned by new_group, or None for the global default group.
use_calc_stream (bool): Whether to use the calculation stream or the communication stream.
Returns:
None.
Examples:
.. code-block:: python
import paddle
#from paddle.distributed import init_parallel_env
#init_parallel_env()
#if paddle.distributed.ParallelEnv().rank == 0:
# data = paddle.to_tensor([7, 8, 9])
# paddle.distributed.send(data, dst=1)
#else:
# data = paddle.to_tensor([1,2,3])
# paddle.distributed.recv(data, src=0)
#out = data.numpy()
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
op_type = 'recv_v2'
if in_dygraph_mode():
return core.ops.recv_v2(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'peer', src, 'dtype',
tensor.dtype, 'out_shape', tensor.shape)
check_variable_and_dtype(
tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'],
'recv')
helper = LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
outputs={'Out': [tensor]},
attrs={
'ring_id': ring_id,
'peer': src,
'out_shape': tensor.shape,
'dtype': tensor.dtype,
'use_calc_stream': use_calc_stream,
})
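Taken together, the two new APIs pair up as in the (commented-out) docstring examples above. Below is a minimal dygraph usage sketch, assuming two processes/GPUs started by a distributed launcher such as python -m paddle.distributed.launch; it is an illustration, not part of this change.

import paddle
from paddle.distributed import init_parallel_env

init_parallel_env()
if paddle.distributed.ParallelEnv().rank == 0:
    # rank 0 sends its tensor to rank 1
    data = paddle.to_tensor([7, 8, 9])
    paddle.distributed.send(data, dst=1)
else:
    # rank 1's tensor is overwritten in place with the received values
    data = paddle.to_tensor([1, 2, 3])
    paddle.distributed.recv(data, src=0)
print(data.numpy())  # on rank 1 this prints the values received from rank 0: [7 8 9]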
......@@ -96,6 +96,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_new_group_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_sendrecv_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
endif()
......@@ -871,6 +872,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
endif()
if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_new_group_api PROPERTIES TIMEOUT 120)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
paddle.enable_static()
class TestCollectiveSendRecvAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank):
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata",
shape=[10, 1000],
dtype='float32',
append_batch_size=False)
if rank == 0:
paddle.distributed.send(tindata, dst=1)
else:
paddle.distributed.recv(tindata, src=0)
return [tindata]
if __name__ == "__main__":
runtime_main(TestCollectiveSendRecvAPI, "sendrecv")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
class TestCollectiveSendRecvAPI(TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
tindata = paddle.to_tensor(indata)
if rank == 0:
paddle.distributed.send(tindata, dst=1)
else:
paddle.distributed.recv(tindata, src=0)
return [tindata.numpy()]
if __name__ == "__main__":
runtime_main(TestCollectiveSendRecvAPI, "sendrecv")
......@@ -33,7 +33,7 @@ from paddle.fluid import core
class TestCollectiveAPIRunnerBase(object):
def get_model(self, train_prog, startup_prog, rank):
def get_model(self, train_prog, startup_prog, rank, indata=None):
raise NotImplementedError(
"get model should be implemented by child class.")
......@@ -44,7 +44,6 @@ class TestCollectiveAPIRunnerBase(object):
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
result = self.get_model(train_prog, startup_prog, rank)
paddle.distributed.init_parallel_env()
if args['backend'] == 'nccl':
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
......@@ -55,16 +54,21 @@ class TestCollectiveAPIRunnerBase(object):
place = fluid.XPUPlace(device_id)
else:
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
np.random.seed(os.getpid())
indata = np.random.random((10, 1000)).astype("float32")
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=fetch_list)
if args['static_mode']:
result = self.get_model(train_prog, startup_prog, rank)
exe = fluid.Executor(place)
exe.run(startup_prog)
fetch_list = []
for elem in result:
fetch_list.append(elem.name)
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=fetch_list)
else:
out = self.get_model(train_prog, startup_prog, rank, indata)
#print(out, sys.stderr)
if six.PY2:
print(pickle.dumps(out))
else:
......@@ -81,6 +85,7 @@ def runtime_main(test_class, col_type):
args["col_type"] = col_type
args["backend"] = os.getenv("BACKEND")
args["path_id"] = int(os.getenv("PATH_ID"))
args["static_mode"] = int(os.getenv("STATIC_MODE"))
model.run_trainer(args)
......@@ -186,6 +191,7 @@ class TestDistBase(unittest.TestCase):
col_type,
backend="nccl",
path_id="0",
static_mode="1",
check_error_log=False,
need_envs={}):
if backend == "nccl" or backend == "bkcl":
......@@ -199,8 +205,10 @@ class TestDistBase(unittest.TestCase):
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "0",
"FLAGS_call_stack_level": "2",
"GLOG_v": "3",
"NCCL_P2P_DISABLE": "1",
"STATIC_MODE": static_mode,
"PADDLE_WITH_GLOO": with_gloo,
"BACKEND": backend,
"PATH_ID": path_id
......@@ -269,5 +277,10 @@ class TestDistBase(unittest.TestCase):
self.assertTrue(
np.allclose(
result_data, need_result, rtol=1e-05, atol=1e-05))
elif col_type == "sendrecv":
result_data = tr1_out[0]
self.assertTrue(
np.allclose(
input1, result_data, rtol=1e-05, atol=1e-05))
else:
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from test_collective_api_base import TestDistBase
paddle.enable_static()
class TestCollectiveSendRecvAPI(TestDistBase):
def _setup_config(self):
pass
#def test_sendrecv_nccl(self):
# if paddle.fluid.core.is_compiled_with_cuda():
# self.check_with_place("collective_sendrecv_api.py", "sendrecv",
# "nccl")
def test_sendrecv_nccl_dygraph(self):
if paddle.fluid.core.is_compiled_with_cuda():
self.check_with_place(
"collective_sendrecv_api_dygraph.py",
"sendrecv",
"nccl",
static_mode='0')
if __name__ == '__main__':
unittest.main()