未验证 提交 34e3adae 编写于 作者: Z Zeng Jinle 提交者: GitHub

Refine reduce codes to save compiling time and binary size (#20676)

* refine reduce code to save compiling time and binary sizes, test=develop

* add reduce rank check to avoid bug, test=develop
上级 1d925440
......@@ -22,3 +22,7 @@ if(WITH_GPU)
endif()
endforeach()
endif()
# Register the unit test for the reduce-rank validity check. The test source is
# a .cu file registered through nv_test, so it is only added when WITH_GPU is
# enabled.
if(WITH_GPU)
nv_test(check_reduce_rank_test SRCS check_reduce_rank_test.cu DEPS tensor cub)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gtest/gtest.h"
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
namespace paddle {
namespace operators {
namespace detail {
// Exhaustively checks CheckReduceRankIsValid for every (rank, reduce_rank)
// pair with 0 <= reduce_rank <= rank <= DDim::kMaxRank.
//
// Expected contract:
//   - even rank: only reduce_rank == rank / 2 is accepted;
//   - odd rank:  only reduce_rank == (rank - 1) / 2 or (rank + 1) / 2 is
//     accepted.
// Accepted pairs must not throw; every other pair must throw EnforceNotMet.
TEST(test_reduce_rank_check, all) {
  using EnforceNotMet = paddle::platform::EnforceNotMet;
  constexpr int kMaxRank = framework::DDim::kMaxRank;

  // The upper bound is inclusive so that rank == kMaxRank is exercised too.
  for (int rank = 0; rank <= kMaxRank; rank++) {
    for (int reduce_rank = 0; reduce_rank <= rank; reduce_rank++) {
      // Reference predicate, computed independently of the function under
      // test.
      const bool is_valid =
          (rank % 2 == 0)
              ? (reduce_rank == rank / 2)
              : (reduce_rank == (rank - 1) / 2 ||
                 reduce_rank == (rank + 1) / 2);

      if (is_valid) {
        // Report an unexpected throw as an explicit assertion failure.
        ASSERT_NO_THROW(CheckReduceRankIsValid(reduce_rank, rank));
      } else {
        ASSERT_THROW(CheckReduceRankIsValid(reduce_rank, rank), EnforceNotMet);
      }
    }
  }
}
} // namespace detail
} // namespace operators
} // namespace paddle
......@@ -153,6 +153,18 @@ static inline int GetDesiredBlockDim(int block_dim) {
: (1 << static_cast<int>(std::log2(block_dim)));
}
// Verifies that `reduce_rank` is one of the values allowed for a tensor of the
// given `rank`:
//   - even rank: reduce_rank must equal rank / 2;
//   - odd rank:  reduce_rank must equal (rank - 1) / 2 or (rank + 1) / 2.
//
// Returns normally when the pair is valid. Otherwise the PADDLE_ENFORCE check
// fails; the accompanying unit test expects this to surface as
// paddle::platform::EnforceNotMet.
static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
  if (rank % 2 == 0) {
    const int expected_rank = rank / 2;
    // Carry the same kind of diagnostic as the odd-rank branch so a failure
    // reports which rank/reduce_rank combination was rejected.
    PADDLE_ENFORCE_EQ(reduce_rank, expected_rank,
                      "When rank = %d, reduce_rank must be %d, but got %d",
                      rank, expected_rank, reduce_rank);
  } else {
    const int lower_rank = (rank - 1) / 2;
    const int upper_rank = (rank + 1) / 2;
    PADDLE_ENFORCE(reduce_rank == lower_rank || reduce_rank == upper_rank,
                   "When rank = %d, reduce_rank must be %d or %d, but got %d",
                   rank, lower_rank, upper_rank, reduce_rank);
  }
}
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp>
static void TensorReduceImpl(
......@@ -211,33 +223,36 @@ static void TensorReduceImpl(
}
*/
/**
* Since adjacent reduce dimensions have already been merged inside
* TensorReduce, the reduced and non-reduced dimensions must alternate. That
* is to say, the dimension pattern of the Tensor must be `1010...` or
* `0101...`, where 1 marks a dimension that is about to be reduced.
*
* Therefore,
* if rank is odd, only reduce_rank = (rank - 1)/2 and (rank + 1)/2 need a case;
* if rank is even, only reduce_rank = rank/2 needs a case.
*
* The total number of switch cases drops from 1+2+3+...+8=36 to (1+2)*4=12,
* which speeds up compilation and shrinks the binary.
*/
CheckReduceRankIsValid(reduce_rank, rank);
switch (rank) {
CUB_RANK_CASE(2, CUB_REDUCE_RANK_CASE(1););
CUB_RANK_CASE(3, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2););
CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
CUB_REDUCE_RANK_CASE(3););
CUB_RANK_CASE(4, CUB_REDUCE_RANK_CASE(2););
CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
CUB_RANK_CASE(5, CUB_REDUCE_RANK_CASE(2); CUB_REDUCE_RANK_CASE(3););
CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
CUB_REDUCE_RANK_CASE(5););
CUB_RANK_CASE(6, CUB_REDUCE_RANK_CASE(3););
CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6););
CUB_RANK_CASE(7, CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4););
CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6););
CUB_RANK_CASE(8, CUB_REDUCE_RANK_CASE(4););
CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(1); CUB_REDUCE_RANK_CASE(2);
CUB_REDUCE_RANK_CASE(3); CUB_REDUCE_RANK_CASE(4);
CUB_REDUCE_RANK_CASE(5); CUB_REDUCE_RANK_CASE(6);
CUB_REDUCE_RANK_CASE(7); CUB_REDUCE_RANK_CASE(8););
CUB_RANK_CASE(9, CUB_REDUCE_RANK_CASE(4); CUB_REDUCE_RANK_CASE(5););
}
#undef CUB_REDUCE_RANK_CASE
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册