未验证 提交 e097a748 编写于 作者: L limingshu 提交者: GitHub

[BugFix]: Elementwise branch selection and Broadcast dimension merge (#38204)

* fix_bugs_for_elementwise_branch_selection

* fix merge_dims bugs

* fix all influenced file
上级 3a0e0b6f
......@@ -125,7 +125,7 @@ struct DimensionsTransform {
std::vector<DimVector> &in_dims,
DimVector &out, int i, int num) {
for (int j = 1; j < num; ++j) {
equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
}
};
auto merge_sequential_one_dims = [](bool &equal,
......@@ -134,7 +134,7 @@ struct DimensionsTransform {
equal = in_dims[0][i] == 1;
if (equal) {
for (int j = 1; j < num; ++j) {
equal = in_dims[j][i] == out[i];
equal &= in_dims[j][i] == out[i];
}
}
};
......
......@@ -29,7 +29,7 @@ void LaunchElementwiseCudaKernel(
std::vector<int> dims_size;
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims();
no_broadcast_flag &= ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size());
}
if (no_broadcast_flag) {
......
......@@ -131,7 +131,7 @@ struct DimensionsTransform {
int i,
int num) {
for (int j = 1; j < num; ++j) {
equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
}
};
auto merge_sequential_one_dims = [](bool &equal,
......@@ -142,7 +142,7 @@ struct DimensionsTransform {
equal = in_dims[0][i] == 1;
if (equal) {
for (int j = 1; j < num; ++j) {
equal = in_dims[j][i] == out[i];
equal &= in_dims[j][i] == out[i];
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册