diff --git a/paddle/fluid/operators/collective/c_allgather_op.cc b/paddle/fluid/operators/collective/c_allgather_op.cc index 18c8f5d642332d96f6e76cf7f2e70b554cacbb89..8d6dd66f99538e31a470e030412e0171760b7f70 100644 --- a/paddle/fluid/operators/collective/c_allgather_op.cc +++ b/paddle/fluid/operators/collective/c_allgather_op.cc @@ -29,6 +29,7 @@ class CAllGatherOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2"); framework::DDim dim = ctx->GetInputDim("X"); dim[0] = dim[0] * nranks; + if (dim[0] < 0) dim[0] = -1; ctx->SetOutputDim("Out", dim); } }; diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc index da92b65aa9ed2c90cefaf61a785566c4609935da..0115946141276845a44b750f13a17ccf50506d03 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc @@ -36,6 +36,11 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel { int nranks = comm->nranks(); auto out_dims = in->dims(); + PADDLE_ENFORCE_EQ(out_dims[0] % nranks, 0, + platform::errors::InvalidArgument( + "The input tensor X's " + "dim[0] (%d) should be divisible by nranks(%d)", + out_dims[0], nranks)); out_dims[0] = out_dims[0] / nranks; out->mutable_data(out_dims, place); diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py index 9e96624cf7c70f24a7f65a91c7ee41af45ddeb6c..43eb436f65e78114fe4a4c9bf7450faca0d87b38 100644 --- a/python/paddle/fluid/layers/collective.py +++ b/python/paddle/fluid/layers/collective.py @@ -134,7 +134,7 @@ def _c_reducescatter(x, nranks, ring_id=0, use_calc_stream=False): if not isinstance(x, Variable): raise TypeError('x must be a Variable') - if x.shape[0] % nranks != 0: + if x.shape[0] > 0 and x.shape[0] % nranks != 0: raise ValueError('x.shape[0](%d) cannot be evenly divided by nranks(%d)' % (x.shape[0], nranks)) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index fdd7c984a3fd6eb14e36dbf2dd8b2c085d324a4f..7e37652829b789303ad4a6e8e2a65274ca63a74f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -26,6 +26,7 @@ if(NOT WITH_GPU OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_allreduce) LIST(REMOVE_ITEM TEST_OPS test_broadcast) LIST(REMOVE_ITEM TEST_OPS test_reducescatter) + LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api) endif() if(WIN32) diff --git a/python/paddle/fluid/tests/unittests/collective_reducescatter.py b/python/paddle/fluid/tests/unittests/collective_reducescatter.py new file mode 100644 index 0000000000000000000000000000000000000000..2f14277ae1e549b0b8dc075694752c18b395d230 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_reducescatter.py @@ -0,0 +1,53 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveReduceScatter(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + nranks = 2 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveReduceScatter, "reducescatter", 0) diff --git a/python/paddle/fluid/tests/unittests/test_reducescatter_api.py b/python/paddle/fluid/tests/unittests/test_reducescatter_api.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa75cc3effe37197195da7555a1a3266e30754b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_reducescatter_api.py @@ -0,0 +1,40 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle.fluid as fluid + +from test_collective_base import TestDistBase + + +class TestReduceScatterAPI(TestDistBase): + def _setup_config(self): + pass + + def test_reducescatter(self, col_type="reduce_scatter"): + self.check_with_place("collective_reducescatter.py", col_type) + + def test_reducescatter_with_error(self): + nranks = 2 + tindata = fluid.data(name="tindata", shape=[5, 1000], dtype='float32') + try: + toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks) + except ValueError: + pass + + +if __name__ == '__main__': + unittest.main()