From 2b108a04ab31793dd47d0b65395ac7cd20d23336 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Fri, 23 Apr 2021 14:03:04 +0800 Subject: [PATCH] add c_concat and c_split ops (#32486) * add c_concat op --- .../fluid/operators/collective/c_concat_op.cc | 112 ++++++++++++++++++ .../operators/collective/c_concat_op.cu.cc | 110 +++++++++++++++++ .../fluid/operators/collective/c_concat_op.h | 38 ++++++ .../fluid/operators/collective/c_split_op.cc | 112 ++++++++++++++++++ .../operators/collective/c_split_op.cu.cc | 80 +++++++++++++ .../fluid/operators/collective/c_split_op.h | 38 ++++++ .../fluid/tests/unittests/CMakeLists.txt | 4 + .../tests/unittests/collective_concat_op.py | 69 +++++++++++ .../tests/unittests/collective_split_op.py | 69 +++++++++++ .../fluid/tests/unittests/test_c_concat.py | 34 ++++++ .../fluid/tests/unittests/test_c_split.py | 34 ++++++ .../tests/unittests/test_collective_base.py | 17 +++ 12 files changed, 717 insertions(+) create mode 100644 paddle/fluid/operators/collective/c_concat_op.cc create mode 100644 paddle/fluid/operators/collective/c_concat_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_concat_op.h create mode 100644 paddle/fluid/operators/collective/c_split_op.cc create mode 100644 paddle/fluid/operators/collective/c_split_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_split_op.h create mode 100644 python/paddle/fluid/tests/unittests/collective_concat_op.py create mode 100644 python/paddle/fluid/tests/unittests/collective_split_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_c_concat.py create mode 100644 python/paddle/fluid/tests/unittests/test_c_split.py diff --git a/paddle/fluid/operators/collective/c_concat_op.cc b/paddle/fluid/operators/collective/c_concat_op.cc new file mode 100644 index 0000000000..551fde2116 --- /dev/null +++ b/paddle/fluid/operators/collective/c_concat_op.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/operators/collective/c_concat_op.h"
+
+namespace paddle {
+namespace operators {
+
+class CConcatOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "c_concat");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "c_concat");
+    int nranks = ctx->Attrs().Get<int>("nranks");
+    int rank = ctx->Attrs().Get<int>("rank");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(nranks, 2, platform::errors::InvalidArgument(
+                                     "The number of ranks (%d) for c_concat "
+                                     "must be greater than 1.",
+                                     nranks));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for c_concat must be non-negative.", ring_id));
+    PADDLE_ENFORCE_GE(
+        rank, 0, platform::errors::InvalidArgument(
+                     "The rank (%d) for c_concat must be non-negative.", rank));
+    PADDLE_ENFORCE_LT(rank, nranks,
+                      platform::errors::InvalidArgument(
+                          "The value of rank (%d) for c_concat must "
+                          "be less than that of nranks (%d).",
+                          rank, nranks));
+
+    framework::DDim dim = ctx->GetInputDim("X");
+    dim[dim.size() - 1] = dim[dim.size() - 1] * nranks;
+    if (dim[dim.size() - 1] < 0) dim[dim.size() - 1] = -1;
+    ctx->SetOutputDim("Out", dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class CConcatOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("c_split");
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
+  }
+};
+
+class CConcatOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) tensor to be concatenated.");
+    AddOutput("Out", "(Tensor) the result of concat.");
+    AddAttr<int>("rank", "(int default 0) rank id.").SetDefault(0);
+    AddAttr<int>("nranks", "(int default 1) number of ranks.").SetDefault(1);
+    AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default true) eject CUDA operations to calculation stream.")
+        .SetDefault(true);
+    AddAttr<bool>("use_model_parallel",
+                  "(bool default true) use this op with model parallel.")
+        .SetDefault(true);
+    AddComment(R"DOC(
+CConcat Operator
+AllGather the tensors on different trainers and concat them along the last dimension.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OPERATOR(c_concat, ops::CConcatOp,
+                  ops::CConcatOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CConcatOpGradMaker<paddle::imperative::OpBase>,
+                  ops::CConcatOpMaker);
+
+REGISTER_OP_CPU_KERNEL(c_concat, ops::CConcatOpCPUKernel<float>,
+                       ops::CConcatOpCPUKernel<double>,
+                       ops::CConcatOpCPUKernel<int>,
+                       ops::CConcatOpCPUKernel<int64_t>,
+                       ops::CConcatOpCPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc
new file mode 100644
index 0000000000..bfdc49c440
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -0,0 +1,110 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+
+#include "paddle/fluid/operators/collective/c_concat_op.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CConcatOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
+
+    int nranks = ctx.Attr<int>("nranks");
+    int rank = ctx.Attr<int>("rank");
+    int rid = ctx.Attr<int>("ring_id");
+    auto place = ctx.GetPlace();
+    PADDLE_ENFORCE_GE(rank, 0,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_concat must be "
+                          "greater than or equal to 0.",
+                          rank));
+    PADDLE_ENFORCE_GE(nranks, 2,
+                      platform::errors::PreconditionNotMet(
+                          "The value of nranks (%d) for c_concat must be "
+                          "greater than or equal to 2.",
+                          nranks));
+    PADDLE_ENFORCE_LT(rank, nranks,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_concat must be "
+                          "less than that of nranks (%d).",
+                          rank, nranks));
+
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+    PADDLE_ENFORCE_EQ(
+        nranks, comm->nranks(),
+        platform::errors::InvalidArgument("nranks: %s should equal to %s",
+                                          nranks, comm->nranks()));
+
+    framework::Tensor temp_out;
+    framework::DDim temp_out_dims = x->dims();
+    temp_out_dims[0] *= nranks;
+    temp_out.mutable_data<T>(temp_out_dims, place);
+    int64_t send_numel = x->numel();
+    const T* send_buff = x->data<T>();
+    T* recv_buff = temp_out.data<T>();
+    gpuStream_t stream = nullptr;
+    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+    stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
+
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
+        send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
+        comm->comm(), stream));
+
+    std::vector<framework::Tensor> inputs;
+    int axis = x->dims().size() - 1;
+    auto out_dims = x->dims();
+    out_dims[out_dims.size() - 1] *= nranks;
+    int rows_per_tensor = x->dims()[0];
+    int offset = 0;
+    for (int i = 0; i < nranks; i++) {
+      framework::Tensor temp =
+          temp_out.Slice(offset, offset + rows_per_tensor);
+      inputs.emplace_back(temp);
+      offset += rows_per_tensor;
+    }
+
+    math::ConcatFunctor<platform::CUDADeviceContext, T> functor;
+    out->mutable_data<T>(out_dims, place);
+    auto& dev_ctx2 =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    functor(dev_ctx2, inputs, axis, out);
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(c_concat, ops::CConcatOpCUDAKernel<float>,
+                        ops::CConcatOpCUDAKernel<double>,
+                        ops::CConcatOpCUDAKernel<int>,
+                        ops::CConcatOpCUDAKernel<int64_t>,
+                        ops::CConcatOpCUDAKernel<plat::float16>);
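Reading note on the CUDA kernel above: ncclAllGather lays the per-rank shards out along dim 0, so the kernel then re-slices the gathered buffer per rank and concatenates along the last axis to get the layout c_concat promises. A minimal numpy sketch of that layout transformation, with illustrative shapes (not part of the patch):

    import numpy as np

    # Two ranks, each holding a [rows, cols] shard (shapes are illustrative).
    nranks, rows, cols = 2, 10, 1000
    shards = [np.random.rand(rows, cols).astype('float32') for _ in range(nranks)]

    # ncclAllGather stacks the shards along dim 0: [nranks * rows, cols].
    gathered = np.concatenate(shards, axis=0)

    # Slice one block per rank, then concatenate along the last axis,
    # giving the promised output layout: [rows, nranks * cols].
    out = np.concatenate(np.split(gathered, nranks, axis=0), axis=-1)
    assert out.shape == (rows, nranks * cols)
    assert np.array_equal(out[:, :cols], shards[0])
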
diff --git a/paddle/fluid/operators/collective/c_concat_op.h b/paddle/fluid/operators/collective/c_concat_op.h
new file mode 100644
index 0000000000..55a5799e37
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_concat_op.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CConcatOpCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Do not support c_concat for cpu kernel now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/collective/c_split_op.cc b/paddle/fluid/operators/collective/c_split_op.cc
new file mode 100644
index 0000000000..03046d571d
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_split_op.cc
@@ -0,0 +1,112 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/collective/c_split_op.h"
+
+namespace paddle {
+namespace operators {
+
+class CSplitOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "c_split");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "c_split");
+    int nranks = ctx->Attrs().Get<int>("nranks");
+    int rank = ctx->Attrs().Get<int>("rank");
+    int ring_id = ctx->Attrs().Get<int>("ring_id");
+    PADDLE_ENFORCE_GE(nranks, 2, platform::errors::InvalidArgument(
+                                     "The number of ranks (%d) for c_split "
+                                     "must be greater than 1.",
+                                     nranks));
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for c_split must be non-negative.", ring_id));
+    PADDLE_ENFORCE_GE(
+        rank, 0, platform::errors::InvalidArgument(
+                     "The rank (%d) for c_split must be non-negative.", rank));
+    PADDLE_ENFORCE_LT(rank, nranks,
+                      platform::errors::InvalidArgument(
+                          "The value of rank (%d) for c_split must "
+                          "be less than that of nranks (%d).",
+                          rank, nranks));
+
+    framework::DDim dim = ctx->GetInputDim("X");
+    dim[dim.size() - 1] = dim[dim.size() - 1] / nranks;
+    if (dim[0] < 0) dim[0] = -1;
+    ctx->SetOutputDim("Out", dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class CSplitOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("c_allgather");
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
+  }
+};
+
+class CSplitOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() {
+    AddInput("X", "(Tensor) tensor to be split.");
+    AddOutput("Out", "(Tensor) the result of split.");
+    AddAttr<int>("rank", "(int default 0) rank id.").SetDefault(0);
+    AddAttr<int>("nranks", "(int default 1) number of ranks.").SetDefault(1);
+    AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
+    AddAttr<bool>(
+        "use_calc_stream",
+        "(bool default false) eject CUDA operations to calculation stream.")
+        .SetDefault(false);
+    AddAttr<bool>("use_model_parallel",
+                  "(bool default true) use this op with model parallel.")
+        .SetDefault(true);
+    AddComment(R"DOC(
+CSplit Operator
+Split the tensor evenly according to its rank.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OPERATOR(c_split, ops::CSplitOp,
+                  ops::CSplitOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CSplitOpGradMaker<paddle::imperative::OpBase>,
+                  ops::CSplitOpMaker);
+
+REGISTER_OP_CPU_KERNEL(c_split, ops::CSplitOpCPUKernel<float>,
+                       ops::CSplitOpCPUKernel<double>,
+                       ops::CSplitOpCPUKernel<int>,
+                       ops::CSplitOpCPUKernel<int64_t>,
+                       ops::CSplitOpCPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/collective/c_split_op.cu.cc b/paddle/fluid/operators/collective/c_split_op.cu.cc
new file mode 100644
index 0000000000..92a7f5e41b
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_split_op.cu.cc
@@ -0,0 +1,80 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+
+#include "paddle/fluid/operators/collective/c_split_op.h"
+#include "paddle/fluid/operators/math/concat_and_split.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CSplitOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+
+    int nranks = ctx.Attr<int>("nranks");
+    int rank = ctx.Attr<int>("rank");
+    auto place = ctx.GetPlace();
+
+    PADDLE_ENFORCE_GE(rank, 0, platform::errors::PreconditionNotMet(
+                                   "The value of rank (%d) for c_split must be "
+                                   "greater than or equal to 0.",
+                                   rank));
+    PADDLE_ENFORCE_GE(nranks, 2,
+                      platform::errors::PreconditionNotMet(
+                          "The value of nranks (%d) for c_split must be "
+                          "greater than or equal to 2.",
+                          nranks));
+    PADDLE_ENFORCE_LT(rank, nranks,
+                      platform::errors::PreconditionNotMet(
+                          "The value of rank (%d) for c_split must be "
+                          "less than that of nranks (%d).",
+                          rank, nranks));
+
+    auto& dev_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    std::vector<const framework::Tensor*> shape_refer;
+    std::vector<framework::Tensor*> results;
+    size_t numel = x->numel();
+    auto dims = x->dims();
+    numel /= nranks;
+    int axis = dims.size() - 1;
+    dims[dims.size() - 1] /= nranks;
+    for (int i = 0; i < nranks; i++) {
+      framework::Tensor* out = new framework::Tensor();
+      out->mutable_data<T>(dims, place);
+      shape_refer.emplace_back(out);
+      results.emplace_back(out);
+    }
+
+    math::SplitFunctor<platform::CUDADeviceContext, T> functor;
+    functor(dev_ctx, *x, shape_refer, axis, &results);
+    out->mutable_data<T>(dims, place);
+    paddle::framework::TensorCopySync(*results[rank], out->place(), out);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_CUDA_KERNEL(c_split, ops::CSplitOpCUDAKernel<float>,
+                        ops::CSplitOpCUDAKernel<double>,
+                        ops::CSplitOpCUDAKernel<int>,
+                        ops::CSplitOpCUDAKernel<int64_t>,
+                        ops::CSplitOpCUDAKernel<plat::float16>);
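Unlike c_concat, the c_split kernel above performs no communication: every rank splits its input along the last dimension and keeps only its own slice. A rough numpy sketch of that behavior and of the round trip with c_concat (illustrative shapes and rank, not from the patch):

    import numpy as np

    nranks, rank = 2, 1
    x = np.random.rand(10, 1000).astype('float32')

    # c_split keeps this rank's slice of the last dimension.
    out = np.split(x, nranks, axis=-1)[rank]
    assert out.shape == (10, 1000 // nranks)

    # Concatenating every rank's slice along the last axis restores x,
    # which is why c_concat and c_split serve as each other's gradients.
    assert np.array_equal(np.concatenate(np.split(x, nranks, axis=-1), axis=-1), x)
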
diff --git a/paddle/fluid/operators/collective/c_split_op.h b/paddle/fluid/operators/collective/c_split_op.h
new file mode 100644
index 0000000000..ea0c7fc45c
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_split_op.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class CSplitOpCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Do not support c_split for cpu kernel now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index fc8073e716..1db84665ca 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -71,6 +71,8 @@ endforeach()
 
 if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
+    LIST(REMOVE_ITEM TEST_OPS test_c_concat)
+    LIST(REMOVE_ITEM TEST_OPS test_c_split)
     LIST(REMOVE_ITEM TEST_OPS test_allgather)
     LIST(REMOVE_ITEM TEST_OPS test_allreduce)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
@@ -873,6 +875,8 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
         set_tests_properties(test_collective_reduce_api PROPERTIES TIMEOUT 120)
         set_tests_properties(test_collective_reduce PROPERTIES TIMEOUT 120)
         set_tests_properties(test_allreduce PROPERTIES TIMEOUT 120)
+        set_tests_properties(test_c_concat PROPERTIES TIMEOUT 120)
+        set_tests_properties(test_c_split PROPERTIES TIMEOUT 120)
         set_tests_properties(test_allgather PROPERTIES TIMEOUT 120)
         set_tests_properties(test_collective_scatter_api PROPERTIES TIMEOUT 120)
         set_tests_properties(test_collective_barrier_api PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/collective_concat_op.py b/python/paddle/fluid/tests/unittests/collective_concat_op.py
new file mode 100644
index 0000000000..c9de1713e7
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/collective_concat_op.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + +paddle.enable_static() + + +class TestCollectiveConcat(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + nranks = 2 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofconcat", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_concat", + inputs={'X': tindata}, + attrs={ + 'ring_id': ring_id, + 'rank': self.rank, + 'nranks': nranks + }, + outputs={'Out': toutdata}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveConcat, "concat", 0) diff --git a/python/paddle/fluid/tests/unittests/collective_split_op.py b/python/paddle/fluid/tests/unittests/collective_split_op.py new file mode 100644 index 0000000000..553955354f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_split_op.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + +paddle.enable_static() + + +class TestCollectiveAllGather(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + nranks = 2 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofsplit", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_split", + inputs={'X': tindata}, + attrs={ + 'ring_id': ring_id, + 'rank': self.rank, + 'nranks': nranks + }, + outputs={'Out': toutdata}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveAllGather, "split", 0) diff --git a/python/paddle/fluid/tests/unittests/test_c_concat.py b/python/paddle/fluid/tests/unittests/test_c_concat.py new file mode 100644 index 0000000000..20f166af14 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_c_concat.py @@ -0,0 +1,34 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_base import TestDistBase + +paddle.enable_static() + + +class TestConcatOp(TestDistBase): + def _setup_config(self): + pass + + def test_concat(self, col_type="concat"): + self.check_with_place("collective_concat_op.py", col_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_c_split.py b/python/paddle/fluid/tests/unittests/test_c_split.py new file mode 100644 index 0000000000..0a5d91e062 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_c_split.py @@ -0,0 +1,34 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import unittest +import numpy as np +import paddle + +from test_collective_base import TestDistBase + +paddle.enable_static() + + +class TestSplitOp(TestDistBase): + def _setup_config(self): + pass + + def test_split(self, col_type="split"): + self.check_with_place("collective_split_op.py", col_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py index fc267ed914..0d592ec185 100644 --- a/python/paddle/fluid/tests/unittests/test_collective_base.py +++ b/python/paddle/fluid/tests/unittests/test_collective_base.py @@ -284,5 +284,22 @@ class TestDistBase(unittest.TestCase): need_result2 = np.concatenate((tmp20, tmp21), axis=1) self.assertTrue(np.allclose(tr0_out, need_result1)) self.assertTrue(np.allclose(tr1_out, need_result2)) + elif col_type == "concat": + need_result = np.concatenate((input1, input2), axis=1) + self.assertTrue( + np.allclose( + tr0_out, need_result, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out, need_result, rtol=1e-05, atol=1e-05)) + elif col_type == "split": + need_result1 = np.split(input1, 2, axis=1)[0] + need_result2 = np.split(input2, 2, axis=1)[1] + self.assertTrue( + np.allclose( + tr0_out, need_result1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out, need_result2, rtol=1e-05, atol=1e-05)) else: pass -- GitLab
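The new "concat" and "split" branches in test_collective_base.py can be read as plain numpy: with two trainers feeding input1 and input2 of shape [10, 1000] (the shape used by the run scripts above), the expected results are computed as in the following sketch (illustrative only, not part of the patch):

    import numpy as np

    input1 = np.random.rand(10, 1000).astype('float32')
    input2 = np.random.rand(10, 1000).astype('float32')

    # c_concat: both trainers end up with the same [10, 2000] tensor.
    concat_expected = np.concatenate((input1, input2), axis=1)

    # c_split: trainer 0 keeps the first half of its input's last dimension,
    # trainer 1 keeps the second half of its own input.
    split_expected_tr0 = np.split(input1, 2, axis=1)[0]
    split_expected_tr1 = np.split(input2, 2, axis=1)[1]

    print(concat_expected.shape, split_expected_tr0.shape, split_expected_tr1.shape)
    # (10, 2000) (10, 500) (10, 500)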