未验证 提交 2562ad5a 编写于 作者: L limingshu 提交者: GitHub

Fix dimension merge bug in broadcast (#42143)

* change sequential logic

* change some quotes

* add some notations

* change wrong note style.
上级 e52e6d01
......@@ -31,13 +31,14 @@ struct DimensionsTransform {
using DimVector = std::vector<int64_t>;
typedef void (*MergeFunctor)(
bool &, std::vector<DimVector> &, DimVector &, int, int);
int64_t N;
int64_t dim_size;
DimVector out_dims;
std::vector<DimVector> in_dims;
private:
// To compensate the lackage of input_tensors` dimension with input variable
// 'axis'
// To compensate the lackage of input_tensors` dimension with input
// variable 'axis'.
void InputDimensionsExtend(int N, int axis) {
for (auto &in_dim : in_dims) {
int64_t in_idx = 0;
......@@ -82,6 +83,8 @@ struct DimensionsTransform {
std::reverse(out_dims.begin(), out_dims.end());
}
// Merge sequential dimension to shrink calculation cost for
// offset computation in CUDA Kernel.
template <typename MergeFunctor>
__inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
......@@ -120,11 +123,44 @@ struct DimensionsTransform {
}
}
// To judge whether shape of any input tensors is sequential
// 1-value-dimensions, and metric the length of it.
int GetSequentialOneDimLength(int *swap_index) {
int index = 0;
int max_one_length = 0;
for (int j = 0; j < N; ++j) {
int seq_one_length = 0;
bool active_seq = false;
for (int i = 0; i < dim_size; ++i) {
if (!active_seq && in_dims[j][i] == 1) {
seq_one_length = 1;
active_seq = true;
} else if (active_seq) {
if (in_dims[j][i] == 1) {
seq_one_length++;
} else {
active_seq = false;
}
}
}
max_one_length =
seq_one_length > max_one_length ? seq_one_length : max_one_length;
index = seq_one_length > max_one_length ? j : index;
}
if (max_one_length > 1) {
std::swap(in_dims[0], in_dims[index]);
*swap_index = index;
}
return max_one_length;
}
public:
explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
const phi::DDim &dims,
int axis) {
const int N = std::max(static_cast<int>(ins.size()), 2);
N = std::max(static_cast<int>(ins.size()), 2);
dim_size = dims.size();
out_dims = phi::vectorize<int64_t>(dims);
in_dims.resize(N);
......@@ -140,6 +176,11 @@ struct DimensionsTransform {
}
InputDimensionsExtend(N, axis);
// To Merge the dimensions of input_tensors while the consequtive
// equal-dimensions appears. Example below :
// in_1.shape = [2, 3, 4, 5] in_1.shape = [2, 12, 5]
// in_2.shape = [1, 3, 4, 5] -> in_2.shape = [1, 12, 5]
// in_3.shape = [2, 3, 4, 1] in_3.shape = [2, 12, 1]
auto merge_sequential_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out,
......@@ -149,6 +190,17 @@ struct DimensionsTransform {
equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
}
};
MergeFunctor merge_ptr = merge_sequential_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
// To Merge the dimension of input_tensors while the sequential
// 1-value-dimensions appears. Example below :
// in_1.shape = [2, 1, 1, 5] in_1.shape = [2, 1, 5]
// in_2.shape = [2, 3, 4, 5] -> in_2.shape = [1, 12, 5]
// in_3.shape = [2, 3, 4, 1] in_3.shape = [2, 12, 1]
// Caution: Once 1-value-dimensions appears, the corresponding
// shape position of other input tensors must be same with the
// output tensor`s shape, or incorrect merge may occur.
auto merge_sequential_one_dims = [](bool &equal,
std::vector<DimVector> &in_dims,
DimVector &out,
......@@ -161,27 +213,13 @@ struct DimensionsTransform {
}
}
};
// To Merge the dimensions of input_tensors while the consequtive
// equal-dimensions appears.
MergeFunctor merge_ptr = merge_sequential_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
int min_idx = 0;
int min_val = std::accumulate(
in_dims[0].begin(), in_dims[0].end(), 1, std::multiplies<int64_t>());
for (int j = 1; j < N; ++j) {
int temp = std::accumulate(
in_dims[j].begin(), in_dims[j].end(), 1, std::multiplies<int64_t>());
min_val = min_val > temp ? temp : min_val;
min_idx = min_val == temp ? j : min_idx;
int swap_idx = 0;
int max_one_length = GetSequentialOneDimLength(&swap_idx);
if (max_one_length > 1) {
merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[swap_idx], in_dims[0]);
}
std::swap(in_dims[0], in_dims[min_idx]);
// To Merge the dimension of input_tensors while the consequtive
// 1-value-dimensions appears.
merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[min_idx], in_dims[0]);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册