diff --git a/paddle/fluid/operators/collective/recv_v2_op.cc b/paddle/fluid/operators/collective/recv_v2_op.cc
index 39a9ed0c74ef59d8520147572b9ab0da8c567da2..daf123a6df5bf9575cd358bc553b358d95c5aea2 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cc
@@ -34,21 +34,26 @@ class RecvOpV2 : public framework::OperatorWithKernel {
         ring_id, 0,
         platform::errors::InvalidArgument(
             "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
-    auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
-    PADDLE_ENFORCE_GE(out_shape.size(), 1,
-                      platform::errors::InvalidArgument(
-                          "The size of the output shape must be greater than 0 "
-                          "but the value given is %d.",
-                          out_shape.size()));
-    for (size_t i = 0; i < out_shape.size(); ++i) {
-      PADDLE_ENFORCE_GE(out_shape[i], 1,
-                        platform::errors::InvalidArgument(
-                            "The shape attribute for recv_v2 must be set "
-                            "explicitly, but the %dth element is %d which "
-                            "is less than 1.",
-                            i, out_shape[i]));
+
+    if (ctx->GetOutputsVarType("Out").front() ==
+        framework::proto::VarType::LOD_TENSOR) {
+      auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
+      PADDLE_ENFORCE_GE(
+          out_shape.size(), 1,
+          platform::errors::InvalidArgument(
+              "The size of the output shape must be greater than 0 "
+              "but the value given is %d.",
+              out_shape.size()));
+      for (size_t i = 0; i < out_shape.size(); ++i) {
+        PADDLE_ENFORCE_GE(out_shape[i], 1,
+                          platform::errors::InvalidArgument(
+                              "The shape attribute for recv_v2 must be set "
+                              "explicitly, but the %dth element is %d which "
+                              "is less than 1.",
+                              i, out_shape[i]));
+      }
+      ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
     }
-    ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
   }
 
  protected:
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 7912733fa50cca735cbd76e22ed0124f79d0a61c..df94fee5223c6c13a2a6957390ac64b58f765cd5 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -40,13 +40,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
         peer, 0,
         platform::errors::InvalidArgument(
             "The peer (%d) for recv_v2 op must be non-negative.", peer));
-    auto out = ctx.Output<framework::LoDTensor>("Out");
-    auto out_dims = out->dims();
-    auto numel = out->numel();
-    int data_type = ctx.Attr<int>("dtype");
-    framework::proto::VarType::Type type =
-        framework::proto::VarType::Type(data_type);
-
     gpuStream_t stream = nullptr;
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
@@ -56,14 +49,40 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     } else {
       stream = comm->stream();
     }
-
     PADDLE_ENFORCE_LT(
         peer, comm->nranks(),
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
-    out->mutable_data<T>(out_dims, place);
+
+    int data_type = ctx.Attr<int>("dtype");
+    framework::proto::VarType::Type type =
+        framework::proto::VarType::Type(data_type);
     ncclDataType_t dtype = platform::ToNCCLDataType(type);
+
+    auto *out_var = ctx.OutputVar("Out");
+    if (out_var->IsType<framework::LoDTensorArray>()) {
+      auto out_array = out_var->GetMutable<framework::LoDTensorArray>();
+      for (size_t idx = 0; idx < out_array->size(); ++idx) {
+        VLOG(3) << "LodTensorArray: idx(" << idx << ")";
+        auto out = &out_array->at(idx);
+        auto out_dims = out->dims();
+        out->mutable_data<T>(out_dims, place, 0);
+        auto numel = out->numel();
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+            out->data<T>(), numel, dtype, peer, comm->comm(), stream));
+        VLOG(3) << "rank " << comm->rank() << " recv "
+                << framework::product(out_dims) << " from " << peer;
+      }
+      return;
+    }
+
+    auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    auto out_dims = out->dims();
+    auto numel = out->numel();
+
+    out->mutable_data<T>(out_dims, place);
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
         out->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " recv "
diff --git a/paddle/fluid/operators/collective/send_v2_op.cc b/paddle/fluid/operators/collective/send_v2_op.cc
index c60d560e43baed37d1fc4392e8afc356ffdbd949..753a33268cc958f0c738fbbe5f32e11b3c1cb5cc 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cc
@@ -38,6 +38,16 @@ class SendOpV2 : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
+    const framework::Variable* var = ctx.InputVar("X");
+    if (var->IsType<framework::LoDTensorArray>()) {
+      auto t_arr = var->Get<framework::LoDTensorArray>();
+      // NOTE(sandyhouse): Support an empty tensor array as Input,
+      // and set the kernel type to float.
+      if (t_arr.size() == 0) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       ctx.device_context());
+      }
+    }
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index c4f5d05e68fa8b7014559fb70e5b037a71e3f3d6..dc28910e9ec9cb12814dc07a9f6a37e3f272f126 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -28,9 +28,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
     NCCL_VERSION_CODE >= 2703
-    auto x = ctx.Input<framework::LoDTensor>("X");
-    int numel = x->numel();
-
     int rid = ctx.Attr<int>("ring_id");
     PADDLE_ENFORCE_GE(
         rid, 0,
@@ -56,6 +53,25 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
+
+    auto* x_var = ctx.InputVar("X");
+    if (x_var->IsType<framework::LoDTensorArray>()) {
+      auto& x_array = x_var->Get<framework::LoDTensorArray>();
+      for (size_t idx = 0; idx < x_array.size(); idx++) {
+        VLOG(3) << "LodTensorArray: idx(" << idx << ")";
+        auto& x = x_array.at(idx);
+        int numel = x.numel();
+        ncclDataType_t dtype = platform::ToNCCLDataType(x.type());
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
+            x.data<T>(), numel, dtype, peer, comm->comm(), stream));
+        VLOG(3) << "rank " << comm->rank() << " send "
+                << framework::product(x.dims()) << " to " << peer;
+      }
+      return;
+    }
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    int numel = x->numel();
+    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
         x->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " send "
diff --git a/python/paddle/fluid/tests/unittests/collective_sendrecv_op_array.py b/python/paddle/fluid/tests/unittests/collective_sendrecv_op_array.py
new file mode 100644
index 0000000000000000000000000000000000000000..6876a70ce91bc0bf0e7ae1d3d27c321ff2b2a7e6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_sendrecv_op_array.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+paddle.enable_static()
+
+
+class TestCollectiveSendRecv(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program):
+        ring_id = self.global_ring_id
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata",
+                shape=[10, 1000],
+                dtype='float64',
+                append_batch_size=False)
+            if self.rank == 0:
+                data1 = fluid.layers.assign(
+                    np.array(
+                        [[0, 1, 2]], dtype='float32'))
+                data2 = fluid.layers.assign(
+                    np.array(
+                        [[3, 4, 5]], dtype='float32'))
+            elif self.rank == 1:
+                data1 = fluid.layers.assign(
+                    np.array(
+                        [[3, 4, 5]], dtype='float32'))
+                data2 = fluid.layers.assign(
+                    np.array(
+                        [[0, 1, 2]], dtype='float32'))
+            tensor_array = fluid.layers.create_array(dtype='float32')
+            i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
+            fluid.layers.array_write(data1, i, tensor_array)
+            fluid.layers.array_write(data2, i + 1, tensor_array)
+            if self.rank == 0:
+                main_prog.global_block().append_op(
+                    type="send_v2",
+                    inputs={'X': tensor_array},
+                    attrs={
+                        'ring_id': ring_id,
+                        'peer': 1,
+                        'use_calc_stream': True
+                    })
+            else:
+                main_prog.global_block().append_op(
+                    type="recv_v2",
+                    outputs={'Out': tensor_array},
+                    attrs={
+                        'peer': 0,
+                        'ring_id': ring_id,
+                        'dtype': data1.dtype,
+                        'out_shape': [1, 3],
+                        'use_calc_stream': True,
+                    })
+            return tensor_array
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveSendRecv, "sendrecv_array", 0)
diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py
index 0c278f96dd555f5d5eee4e1a2d60edf3c356d2aa..31b8bafd16d1983797bc18c85b046216b8ce85eb 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_base.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_base.py
@@ -215,7 +215,7 @@ class TestDistBase(unittest.TestCase):
             "PYTHONPATH": os.getenv("PYTHONPATH", ""),
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
-            "GLOG_v": "0",
+            "GLOG_v": "3",
             "NCCL_P2P_DISABLE": "1"
         }
         required_envs.update(need_envs)
@@ -300,5 +300,14 @@ class TestDistBase(unittest.TestCase):
             self.assertTrue(
                 np.allclose(
                     tr1_out, need_result2, rtol=1e-05, atol=1e-05))
+        elif col_type == "sendrecv_array":
+            need_result1 = np.array([[0, 1, 2]])
+            need_result2 = np.array([[3, 4, 5]])
+            self.assertTrue(
+                np.allclose(
+                    tr1_out[0][0], need_result1, rtol=1e-05, atol=1e-05))
+            self.assertTrue(
+                np.allclose(
+                    tr1_out[0][1], need_result2, rtol=1e-05, atol=1e-05))
         else:
             pass
diff --git a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
index 67c84a71bb3351eb0ff0d89b7bb93af38ea3f75d..40bacaf59d2f30e2e55c20c76e7da62adea71065 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_sendrecv.py
@@ -29,6 +29,10 @@ class TestSendRecvOp(TestDistBase):
     def test_sendrecv(self):
         self.check_with_place("collective_sendrecv_op.py", "sendrecv")
 
+    def test_sendrecv_array(self):
+        self.check_with_place("collective_sendrecv_op_array.py",
+                              "sendrecv_array")
+
 
 if __name__ == '__main__':
     unittest.main()
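
A note on the receive-side contract implied by the kernels above: when "Out" is
a LoDTensorArray, the new InferShape branch skips the out_shape attribute, and
the CUDA kernel takes each element's dims from the array as it already exists
(out->dims() on out_array->at(idx)). The receiving rank therefore has to
pre-populate the array with correctly shaped tensors before recv_v2 runs, which
is why the test writes tensors into the array on both ranks. Below is a minimal
sketch of just that pre-population step, assuming the same two-element [1, 3]
float32 array as the test; the placeholder variable and the ring_id/peer values
are illustrative, not part of the patch:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # recv_v2 sizes each array slot from the tensor already stored there,
        # so write correctly shaped placeholders before appending the op.
        placeholder = fluid.layers.assign(np.zeros((1, 3), dtype='float32'))
        tensor_array = fluid.layers.create_array(dtype='float32')
        i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
        fluid.layers.array_write(placeholder, i, tensor_array)
        fluid.layers.array_write(placeholder, i + 1, tensor_array)
        main_prog.global_block().append_op(
            type="recv_v2",
            outputs={'Out': tensor_array},
            attrs={
                'peer': 0,
                'ring_id': 0,
                'dtype': placeholder.dtype,
                'out_shape': [1, 3],
                'use_calc_stream': True,
            })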