diff --git a/paddle/fluid/operators/reduce_ops/CMakeLists.txt b/paddle/fluid/operators/reduce_ops/CMakeLists.txt
index ebd07d90ebe6b0ba008ac89c01c4f054f96a6da9..3da481a142aa2282aade661de7679cf4edf597a0 100644
--- a/paddle/fluid/operators/reduce_ops/CMakeLists.txt
+++ b/paddle/fluid/operators/reduce_ops/CMakeLists.txt
@@ -22,3 +22,7 @@ if(WITH_GPU)
         endif()
     endforeach()
 endif()
+
+if(WITH_GPU)
+    nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor cub)
+endif()
diff --git a/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7efdff934239a0643da8ff57911492139bac3a9c
--- /dev/null
+++ b/paddle/fluid/operators/reduce_ops/check_reduce_rank_test.cu
@@ -0,0 +1,53 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
+
+namespace paddle {
+namespace operators {
+namespace detail {
+
+TEST(test_reduce_rank_check, all) {
+  using EnforceNotMet = paddle::platform::EnforceNotMet;
+  constexpr int kMaxRank = framework::DDim::kMaxRank;
+
+  for (int rank = 0; rank < kMaxRank; rank++) {
+    for (int reduce_rank = 0; reduce_rank <= rank; reduce_rank++) {
+      bool is_valid = false;
+      if (rank % 2 == 0) {
+        is_valid = (reduce_rank == rank / 2);
+      } else {
+        if (reduce_rank == (rank - 1) / 2) {
+          is_valid = true;
+        } else if (reduce_rank == (rank + 1) / 2) {
+          is_valid = true;
+        } else {
+          is_valid = false;
+        }
+      }
+
+      if (is_valid) {
+        CheckReduceRankIsValid(reduce_rank, rank);
+      } else {
+        ASSERT_THROW(CheckReduceRankIsValid(reduce_rank, rank),
+                     paddle::platform::EnforceNotMet);
+      }
+    }
+  }
+}
+
+}  // namespace detail
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h
index af56e85e9c6f5e0cfb5e03587fbace4665d9e5fb..876118245f1ab63de41f7d87db8d3ce4eeea57ba 100644
--- a/paddle/fluid/operators/reduce_ops/cub_reduce.h
+++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h
@@ -153,6 +153,18 @@ static inline int GetDesiredBlockDim(int block_dim) {
              : (1 << static_cast<int>(std::log2(block_dim)));
 }
 
+static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
+  if (rank % 2 == 0) {
+    PADDLE_ENFORCE_EQ(reduce_rank, rank / 2);
+  } else {
+    auto lower_rank = (rank - 1) / 2;
+    auto upper_rank = (rank + 1) / 2;
+    PADDLE_ENFORCE(reduce_rank == lower_rank || reduce_rank == upper_rank,
+                   "When rank = %d, reduce_rank must be %d or %d, but got %d",
+                   rank, lower_rank, upper_rank, reduce_rank);
+  }
+}
+
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim>
 static void TensorReduceImpl(
@@ -211,33 +223,36 @@ static void TensorReduceImpl(
   }
 */
 
+  /**
+   * Since we have combined the adjacent reduce dimensions inside TensorReduce,
+   * The reduce ranks and non-reduce ranks must be interleaving. That is to say,
+   * the rank of Tensor must be `1010...` or `0101...` where 1 represents that
+   * the dimension is about to be reduced.
+   *
+   * Therefore,
+   * If rank is odd, only need to switch-case (rank - 1)/2 and (rank + 1)/2.
+   * If rank is even, only need to switch-case rank/2.
+   *
+   * The total switch-case numbers reduce from 1+2+3+...+8=36 to (1+2)*4=12,
+   * it would speed up compiling and make the binary size lower.
+   */
+  CheckReduceRankIsValid(reduce_rank, rank);
   switch (rank) {
     CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
 
     CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
 
-    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
-                  CUB_REDUCE_RANK_CASE(3););
+    CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
 
-    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
-                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
+    CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
 
-    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
-                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
-                  CUB_REDUCE_RANK_CASE(5););
+    CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
 
-    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
-                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
-                  CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6););
+    CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
 
-    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
-                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
-                  CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6););
+    CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
 
-    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
-                  CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
-                  CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6);
-                  CUB_REDUCE_RANK_CASE(7); CUB_REDUCE_RANK_CASE(8););
+    CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
   }
 
 #undef CUB_REDUCE_RANK_CASE