diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h
index 7634c2462738b2b3bdb622e851723aef23045dfd..10216f80c00d4fded6184f108ca8ab20bc103493 100644
--- a/paddle/phi/kernels/funcs/broadcast_function.h
+++ b/paddle/phi/kernels/funcs/broadcast_function.h
@@ -31,13 +31,14 @@ struct DimensionsTransform {
   using DimVector = std::vector;
   typedef void (*MergeFunctor)(
       bool &, std::vector &, DimVector &, int, int);
+  int64_t N;
   int64_t dim_size;
   DimVector out_dims;
   std::vector in_dims;
 
  private:
-  // To compensate the lackage of input_tensors` dimension with input variable
-  // 'axis'
+  // To compensate for the lack of input_tensors` dimension with input
+  // variable 'axis'.
   void InputDimensionsExtend(int N, int axis) {
     for (auto &in_dim : in_dims) {
       int64_t in_idx = 0;
@@ -82,6 +83,8 @@ struct DimensionsTransform {
     std::reverse(out_dims.begin(), out_dims.end());
   }
 
+  // Merge sequential dimensions to shrink the calculation cost of
+  // offset computation in CUDA Kernel.
   template
   __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
     auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
@@ -120,11 +123,44 @@ struct DimensionsTransform {
     }
   }
 
+  // Measure the run of sequential 1-value dimensions in each input shape
+  // (last run per shape) and move the input with the longest one to slot 0.
+  int GetSequentialOneDimLength(int *swap_index) {
+    int index = 0;
+    int max_one_length = 0;
+    for (int j = 0; j < N; ++j) {
+      int seq_one_length = 0;
+      bool active_seq = false;
+
+      for (int i = 0; i < dim_size; ++i) {
+        if (!active_seq && in_dims[j][i] == 1) {
+          seq_one_length = 1;
+          active_seq = true;
+        } else if (active_seq) {
+          if (in_dims[j][i] == 1) {
+            seq_one_length++;
+          } else {
+            active_seq = false;
+          }
+        }
+      }
+      // Pick the owner before raising the maximum, or index never changes.
+      index = seq_one_length > max_one_length ? j : index;
+      max_one_length = std::max(seq_one_length, max_one_length);
+    }
+
+    if (max_one_length > 1) {
+      std::swap(in_dims[0], in_dims[index]);
+      *swap_index = index;
+    }
+    return max_one_length;
+  }
+
  public:
   explicit DimensionsTransform(const std::vector &ins,
                                const phi::DDim &dims,
                                int axis) {
-    const int N = std::max(static_cast(ins.size()), 2);
+    N = std::max(static_cast(ins.size()), 2);
     dim_size = dims.size();
     out_dims = phi::vectorize(dims);
     in_dims.resize(N);
@@ -140,6 +176,11 @@ struct DimensionsTransform {
     }
     InputDimensionsExtend(N, axis);
 
+    // To merge the dimensions of input_tensors when consecutive
+    // equal dimensions appear. Example below :
+    //   in_1.shape = [2, 3, 4, 5]    in_1.shape = [2, 12, 5]
+    //   in_2.shape = [1, 3, 4, 5] -> in_2.shape = [1, 12, 5]
+    //   in_3.shape = [2, 3, 4, 1]    in_3.shape = [2, 12, 1]
     auto merge_sequential_dims = [](bool &equal,
                                     std::vector &in_dims,
                                     DimVector &out,
@@ -149,6 +190,17 @@ struct DimensionsTransform {
         equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
       }
     };
+    MergeFunctor merge_ptr = merge_sequential_dims;
+    MergeDimensions(merge_ptr, N);
+
+    // To merge the dimensions of input_tensors when sequential
+    // 1-value dimensions appear. Example below :
+    //   in_1.shape = [2, 1, 1, 5]    in_1.shape = [2, 1, 5]
+    //   in_2.shape = [2, 3, 4, 5] -> in_2.shape = [2, 12, 5]
+    //   in_3.shape = [2, 3, 4, 1]    in_3.shape = [2, 12, 1]
+    // Caution: Once 1-value dimensions appear, the corresponding
+    // shape position of other input tensors must be same with the
+    // output tensor's shape, or incorrect merge may occur.
     auto merge_sequential_one_dims = [](bool &equal,
                                         std::vector &in_dims,
                                         DimVector &out,
@@ -161,27 +213,13 @@ struct DimensionsTransform {
         }
       }
     };
-    // To Merge the dimensions of input_tensors while the consequtive
-    // equal-dimensions appears.
-    MergeFunctor merge_ptr = merge_sequential_dims;
-    MergeDimensions(merge_ptr, N);
-
-    int min_idx = 0;
-    int min_val = std::accumulate(
-        in_dims[0].begin(), in_dims[0].end(), 1, std::multiplies());
-    for (int j = 1; j < N; ++j) {
-      int temp = std::accumulate(
-          in_dims[j].begin(), in_dims[j].end(), 1, std::multiplies());
-      min_val = min_val > temp ? temp : min_val;
-      min_idx = min_val == temp ? j : min_idx;
+    int swap_idx = 0;
+    int max_one_length = GetSequentialOneDimLength(&swap_idx);
+    if (max_one_length > 1) {
+      merge_ptr = merge_sequential_one_dims;
+      MergeDimensions(merge_ptr, N);
+      std::swap(in_dims[swap_idx], in_dims[0]);
     }
-    std::swap(in_dims[0], in_dims[min_idx]);
-
-    // To Merge the dimension of input_tensors while the consequtive
-    // 1-value-dimensions appears.
-    merge_ptr = merge_sequential_one_dims;
-    MergeDimensions(merge_ptr, N);
-    std::swap(in_dims[min_idx], in_dims[0]);
   }
 };
 