diff --git a/paddle/fluid/operators/collective/gather_op_v2.cc b/paddle/fluid/operators/collective/gather_op_v2.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8434f63586d951b1b46e282affd0e2df1c993d36
--- /dev/null
+++ b/paddle/fluid/operators/collective/gather_op_v2.cc
@@ -0,0 +1,97 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/gather_op_v2.h"
+
+namespace paddle {
+namespace operators {
+
+class GatherOpV2 : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GatherV2");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GatherV2");
+    int root_id = ctx->Attrs().Get<int>("root");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    int nranks = ctx->Attrs().Get<int>("nranks");
+    PADDLE_ENFORCE_GE(nranks, 2,
+                      platform::errors::InvalidArgument(
+                          "The number of ranks (%d) must be greater than 1 "
+                          "to use collective op (gather_op_v2).",
+                          nranks));
+    PADDLE_ENFORCE_GE(
+        root_id, 0,
+        platform::errors::InvalidArgument(
+            "The root_id (%d) for gather_op_v2 must be non-negative.",
+            root_id));
+    PADDLE_ENFORCE_LT(
+        root_id, nranks,
+        platform::errors::InvalidArgument(
+            "The root_id (%d) for gather_op_v2 must be less than nranks (%d).",
+            root_id, nranks));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for gather_op_v2 must be non-negative.",
+            ring_id));
+    framework::DDim dim = ctx->GetInputDim("X");
+    dim[0] = dim[0] * nranks;
+    if (dim[0] < 0) dim[0] = -1;
+    ctx->SetOutputDim("Out", dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class GatherOpV2Maker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) tensor to be gathered.");
+    AddOutput("Out", "(Tensor) the result of gather.");
+    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
+        .SetDefault(0);
+    AddAttr<int>("root", "(int default 0) root id for gathering.")
+        .SetDefault(0);
+    AddAttr<int>("nranks", "(int default 0) number of ranks.").SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) eject CUDA operations to calculation stream.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Gather Operator
+Gather tensors from all participators.
+)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(gather_v2, ops::GatherOpV2, ops::GatherOpV2Maker); + +REGISTER_OP_CPU_KERNEL(gather_v2, ops::CGatherOpV2CPUKernel, + ops::GatherOpV2CPUKernel, + ops::GatherOpV2CPUKernel, + ops::GatherOpV2CPUKernel, + ops::GatherOpV2CPUKernel); diff --git a/paddle/fluid/operators/collective/gather_op_v2.cu.cc b/paddle/fluid/operators/collective/gather_op_v2.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..76f24eb3dd1f014d0096bb9edd0409337e171136 --- /dev/null +++ b/paddle/fluid/operators/collective/gather_op_v2.cu.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/gather_op_v2.h" + +#if defined(PADDLE_WITH_NCCL) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class GatherOpV2CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int send_numel = x->numel(); + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int nranks = ctx.Attr("nranks"); + int root_id = ctx.Attr("root"); + int ring_id = ctx.Attr("ring_id"); + auto place = ctx.GetPlace(); + auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place); + PADDLE_ENFORCE_EQ(nranks, comm->nranks(), + platform::errors::InvalidArgument( + "The number of ranks (%d) you set of must " + "be equal to comm->nranks (%d).", + nranks, comm->nranks())); + PADDLE_ENFORCE_GE( + root_id, 0, + platform::errors::InvalidArgument( + "The root_id (%d) for gather_op_v2 must be non-negative.", + root_id)); + PADDLE_ENFORCE_GE( + ring_id, 0, + platform::errors::InvalidArgument( + "The ring_id (%d) for gather_op_v2 must be non-negative.", + ring_id)); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + framework::DDim x_dims = x->dims(); + framework::DDim out_dims(x_dims); + out_dims[0] *= nranks; + auto send_buf = x->data(); + auto offset = 0; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart()); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( + send_buf, send_numel, dtype, root_id, comm->comm(), stream)); + if (root_id == comm->rank()) { + auto recv_buf = out->mutable_data(out_dims, place); + for (auto i = 0; i < nranks; ++i) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( + recv_buf + offset, send_numel, dtype, i, comm->comm(), stream)); + offset += send_numel; + } + } + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd()); +#else + PADDLE_THROW( + 
+        platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(gather_v2, ops::GatherOpV2CUDAKernel<float>,
+                        ops::GatherOpV2CUDAKernel<double>,
+                        ops::GatherOpV2CUDAKernel<int>,
+                        ops::GatherOpV2CUDAKernel<int64_t>,
+                        ops::GatherOpV2CUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/gather_op_v2.h b/paddle/fluid/operators/collective/gather_op_v2.h
new file mode 100644
index 0000000000000000000000000000000000000000..f600ea9245978804341091bcd63190346ea94fe2
--- /dev/null
+++ b/paddle/fluid/operators/collective/gather_op_v2.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+#if defined(PADDLE_WITH_GLOO)
+#include <gloo/gather.h>
+#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class GatherOpV2CPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#if defined(PADDLE_WITH_GLOO)
+    auto in = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.Output<framework::Tensor>("Out");
+    auto root_id = ctx.Attr<int>("root");
+    auto nranks = ctx.Attr<int>("nranks");
+    auto place = ctx.GetPlace();
+
+    auto gloo = paddle::framework::GlooWrapper::GetInstance();
+    PADDLE_ENFORCE_EQ(
+        gloo->IsInitialized(), true,
+        platform::errors::PreconditionNotMet(
+            "You must initialize the gloo environment first to use it."));
+
+    PADDLE_ENFORCE_EQ(nranks, gloo->Size(),
+                      platform::errors::InvalidArgument(
+                          "The number of ranks (%d) you set must "
+                          "be equal to gloo->Size() (%d).",
+                          nranks, gloo->Size()));
+    int64_t send_numel = in->numel();
+    int64_t recv_numel = out->numel();
+    auto in_dim = in->dims();
+    auto out_dim = framework::DDim(in_dim);
+    out_dim[0] *= nranks;
+    auto rank = gloo->Rank();
+    gloo::GatherOptions opts(gloo->GetContext());
+    if (root_id == rank) {
+      T* recv_buff = out->mutable_data<T>(out_dim, place);
+      opts.setOutput(recv_buff, recv_numel);
+    }
+    opts.setInput(const_cast<T*>(in->data<T>()), send_numel);
+    opts.setRoot(root_id);
+
+    gloo::gather(opts);
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
+#endif
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cc b/paddle/fluid/operators/collective/scatter_op_v2.cc
similarity index 69%
rename from paddle/fluid/operators/collective/c_scatter_op.cc
rename to paddle/fluid/operators/collective/scatter_op_v2.cc
index 908708e6e328f54466d4bb69b30fd607e14d1fe9..0a3fa8f8bc128b111ef576fd519e0cfa83dfb3e9 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cc
+++ b/paddle/fluid/operators/collective/scatter_op_v2.cc
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/collective/c_scatter_op.h"
+#include "paddle/fluid/operators/collective/scatter_op_v2.h"
 
 namespace paddle {
 namespace operators {
 
-class CScatterOp : public framework::OperatorWithKernel {
+class ScatterOpV2 : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
@@ -30,18 +30,23 @@ class CScatterOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_GE(nranks, 2,
                       platform::errors::InvalidArgument(
                           "The number of ranks (%d) must be greater than 1 "
-                          "to use collective op (c_scatter op).",
+                          "to use collective op (scatter_op_v2).",
                           nranks));
     PADDLE_ENFORCE_GE(
         root_id, 0,
         platform::errors::InvalidArgument(
-            "The root_id (%d) for c_scatter_op must be non-negative.",
+            "The root_id (%d) for scatter_op_v2 must be non-negative.",
             root_id));
+    PADDLE_ENFORCE_LT(root_id, nranks,
+                      platform::errors::InvalidArgument(
+                          "The root_id (%d) for scatter_op_v2 must be less "
+                          "than the number of ranks (%d).",
+                          root_id, nranks));
     PADDLE_ENFORCE_GE(
         ring_id, 0,
         platform::errors::InvalidArgument(
-            "The ring_id (%d) for c_scatter_op must be non-negative.",
-            root_id));
+            "The ring_id (%d) for scatter_op_v2 must be non-negative.",
+            ring_id));
     framework::DDim dim = ctx->GetInputDim("X");
     dim[0] = dim[0] / nranks;
     if (dim[0] < 0) dim[0] = -1;
@@ -56,7 +61,7 @@ class CScatterOp : public framework::OperatorWithKernel {
   }
 };
 
-class CScatterOpMaker : public framework::OpProtoAndCheckerMaker {
+class ScatterOpV2Maker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("X", "(Tensor) tensor to be broadcasted.");
@@ -71,7 +76,7 @@ class CScatterOpMaker : public framework::OpProtoAndCheckerMaker {
         "(bool default false) eject CUDA operations to calculation stream.")
         .SetDefault(false);
     AddComment(R"DOC(
-CScatter Operator
+Scatter Operator
 Scatter the source to all participators.
 )DOC");
   }
@@ -83,10 +88,11 @@ Scatter the source to all participators.
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_WITHOUT_GRADIENT(c_scatter, ops::CScatterOp, ops::CScatterOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(scatter_v2, ops::ScatterOpV2,
+                             ops::ScatterOpV2Maker);
 
-REGISTER_OP_CPU_KERNEL(c_scatter, ops::CScatterOpCPUKernel<float>,
-                       ops::CScatterOpCPUKernel<double>,
-                       ops::CScatterOpCPUKernel<int>,
-                       ops::CScatterOpCPUKernel<int64_t>,
-                       ops::CScatterOpCPUKernel<plat::float16>);
+REGISTER_OP_CPU_KERNEL(scatter_v2, ops::ScatterOpV2CPUKernel<float>,
+                       ops::ScatterOpV2CPUKernel<double>,
+                       ops::ScatterOpV2CPUKernel<int>,
+                       ops::ScatterOpV2CPUKernel<int64_t>,
+                       ops::ScatterOpV2CPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/scatter_op_v2.cu.cc
similarity index 60%
rename from paddle/fluid/operators/collective/c_scatter_op.cu.cc
rename to paddle/fluid/operators/collective/scatter_op_v2.cu.cc
index 8d9e6b4b7d99044f584e9e21062a786252d60f76..f733f381b6532077bd17c75d8f7558ac1b6d480a 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/scatter_op_v2.cu.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/collective/c_scatter_op.h"
+#include "paddle/fluid/operators/collective/scatter_op_v2.h"
 
 #if defined(PADDLE_WITH_NCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
@@ -23,7 +23,7 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-class CScatterOpCUDAKernel : public framework::OpKernel<T> {
+class ScatterOpV2CUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_NCCL)
@@ -39,7 +39,7 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     PADDLE_ENFORCE_EQ(nranks, comm->nranks(),
                       platform::errors::InvalidArgument(
-                          "The number of ranks (%d) you set of must "
+                          "The number of ranks (%d) you set must "
                           "be equal to comm->nranks (%d).",
                           nranks, comm->nranks()));
     PADDLE_ENFORCE_GE(
@@ -63,33 +63,25 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
     framework::DDim x_dims = x->dims();
     framework::DDim out_dims(x_dims);
-    framework::Tensor temp;
-    auto out_ptr = temp.mutable_data<T>(out_dims, place);
+    out_dims[0] /= nranks;
+    auto send_buf = x->data<T>();
+    auto send_numel = numel / nranks;
+    auto recv_buf = out->mutable_data<T>(out_dims, place);
+    auto offset = 0;
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupStart());
     if (root_id == comm->rank()) {
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
-          reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
-          root_id, comm->comm(), stream));
-
-      framework::TensorCopy(*static_cast<const framework::Tensor*>(x), place,
-                            *platform::DeviceContextPool::Instance().Get(place),
-                            static_cast<framework::Tensor*>(&temp));
-    } else {
-      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
-          out_ptr, numel, dtype, root_id, comm->comm(), stream));
+      for (auto i = 0; i < nranks; ++i) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(
+            platform::dynload::ncclSend(send_buf + offset, send_numel, dtype,
+                                        root_id, comm->comm(), stream));
+        offset += send_numel;
+      }
     }
-
-    out_dims[0] = out_dims[0] / nranks;
-    auto start_index = out_dims[0] * comm->rank();
-    auto end_index = start_index + out_dims[0];
-    temp = temp.Slice(start_index, end_index);
-    temp.Resize(out_dims);
-    out->mutable_data<T>(out_dims, place);
-    framework::TensorCopySync(*static_cast<const framework::Tensor*>(&temp),
-                              place, static_cast<framework::Tensor*>(out));
-    out->Resize(out_dims);
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
+        recv_buf, send_numel, dtype, root_id, comm->comm(), stream));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
 #else
-    PADDLE_ENFORCE_EQ(
-        true, false,
+    PADDLE_THROW(
         platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
 #endif
   }
@@ -101,8 +93,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_CUDA_KERNEL(c_scatter, ops::CScatterOpCUDAKernel<float>,
-                        ops::CScatterOpCUDAKernel<double>,
-                        ops::CScatterOpCUDAKernel<int>,
-                        ops::CScatterOpCUDAKernel<int64_t>,
-                        ops::CScatterOpCUDAKernel<plat::float16>);
+REGISTER_OP_CUDA_KERNEL(scatter_v2, ops::ScatterOpV2CUDAKernel<float>,
+                        ops::ScatterOpV2CUDAKernel<double>,
+                        ops::ScatterOpV2CUDAKernel<int>,
+                        ops::ScatterOpV2CUDAKernel<int64_t>,
+                        ops::ScatterOpV2CUDAKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_scatter_op.h b/paddle/fluid/operators/collective/scatter_op_v2.h
similarity index 97%
rename from paddle/fluid/operators/collective/c_scatter_op.h
rename to paddle/fluid/operators/collective/scatter_op_v2.h
index 71a5f488ebc11a93cece9b85f6af288a4662b2d8..4799a5fb741dd5e72dc851679922271ef1d82c50 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.h
+++ b/paddle/fluid/operators/collective/scatter_op_v2.h
@@ -31,7 +31,7 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-class CScatterOpCPUKernel : public framework::OpKernel<T> {
+class ScatterOpV2CPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_GLOO)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 97a3ebc2135a0649fff88e1a1c14d02dfb7850b1..76d8ed4089cc5d83725060723b9086f6f8f2b303 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -62,6 +62,7 @@ if(NOT WITH_GPU OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
     LIST(REMOVE_ITEM TEST_OPS test_collective_scatter)
+    LIST(REMOVE_ITEM TEST_OPS test_collective_gather)
    LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
    LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
diff --git a/python/paddle/fluid/tests/unittests/collective_gather_op.py b/python/paddle/fluid/tests/unittests/collective_gather_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..4793feb02723dc779879ca1ed7ef0b4e432c3e8a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_gather_op.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import argparse
+import os
+import sys
+import signal
+import time
+import socket
+from contextlib import closing
+from six import string_types
+import math
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.profiler as profiler
+import paddle.fluid.unique_name as nameGen
+from paddle.fluid import core
+import unittest
+from multiprocessing import Process
+import paddle.fluid.layers as layers
+from functools import reduce
+from test_collective_base import TestCollectiveRunnerBase, runtime_main
+
+
+class TestCollectiveGather(TestCollectiveRunnerBase):
+    def __init__(self):
+        self.global_ring_id = 0
+
+    def get_model(self, main_prog, startup_program, rank=None):
+        ring_id = 0
+        with fluid.program_guard(main_prog, startup_program):
+            tindata = layers.data(
+                name="tindata", shape=[10, 1000], dtype='float32')
+            toutdata = layers.data(
+                name="toutdata", shape=[20, 1000], dtype='float32')
+            main_prog.global_block().append_op(
+                type="gather_v2",
+                inputs={'X': tindata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id,
+                       'nranks': 2,
+                       'root': 1})
+            main_prog.global_block().append_op(
+                type="c_sync_comm_stream",
+                inputs={'X': toutdata},
+                outputs={'Out': toutdata},
+                attrs={'ring_id': ring_id})
+        return toutdata
+
+
+if __name__ == "__main__":
+    runtime_main(TestCollectiveGather, "gather", 0)
diff --git a/python/paddle/fluid/tests/unittests/collective_scatter_op.py b/python/paddle/fluid/tests/unittests/collective_scatter_op.py
index 7afa4aec63990372d69f1d16c133e6698aef4dc9..095715d45bdcf34aa87c4267883ceeb104112c00 100644
--- a/python/paddle/fluid/tests/unittests/collective_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/collective_scatter_op.py
@@ -49,13 +49,13 @@ class TestCollectiveScatter(TestCollectiveRunnerBase):
             tindata = layers.data(
                 name="tindata", shape=[10, 1000], dtype='float32')
             toutdata = main_prog.current_block().create_var(
-                name="outofreduce",
+                name="tinout",
                 dtype='float32',
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=False)
             main_prog.global_block().append_op(
-                type="c_scatter",
+                type="scatter_v2",
                 inputs={'X': tindata},
                 attrs={'ring_id': ring_id,
                        'root': rootid,
diff --git a/python/paddle/fluid/tests/unittests/test_collective_gather.py b/python/paddle/fluid/tests/unittests/test_collective_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ce1cb8e7b8282c62389d5f4970dbc49b2c4cd1c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_collective_gather.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+
+from test_collective_base import TestDistBase
+
+
+class TestCGatherOp(TestDistBase):
+    def _setup_config(self):
+        pass
+
+    def test_gather(self):
+        self.check_with_place("collective_gather_op.py", "gather")
+
+
+if __name__ == '__main__':
+    unittest.main()
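
Note (reviewer sketch, not part of the patch): the semantics the new gather_v2 op and collective_gather_op.py exercise can be stated in plain NumPy, assuming the two-rank, root=1, [10, 1000]-per-rank setup used above. A TestDistBase-style check could compare the root's "toutdata" against this rank-ordered concatenation; the variable names below are illustrative only.

import numpy as np

# Hypothetical stand-ins for the "tindata" tensor fed by rank 0 and rank 1.
np.random.seed(0)
rank0_input = np.random.random((10, 1000)).astype("float32")
rank1_input = np.random.random((10, 1000)).astype("float32")

# gather_v2 with nranks=2 and root=1 concatenates the per-rank tensors along
# axis 0 on the root, so rank 1's output should equal this array.
expected_on_root = np.concatenate([rank0_input, rank1_input], axis=0)
assert expected_on_root.shape == (20, 1000)  # matches the "toutdata" shape in the test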