未验证 提交 e097a748 编写于 作者: L limingshu 提交者: GitHub

[BugFix]: Elementwise branch selection and Broadcast dimension merge (#38204)

* fix_bugs_for_elementwise_branch_selection

* fix merge_dims bugs

* fix all influenced file
上级 3a0e0b6f
...@@ -125,7 +125,7 @@ struct DimensionsTransform { ...@@ -125,7 +125,7 @@ struct DimensionsTransform {
std::vector<DimVector> &in_dims, std::vector<DimVector> &in_dims,
DimVector &out, int i, int num) { DimVector &out, int i, int num) {
for (int j = 1; j < num; ++j) { for (int j = 1; j < num; ++j) {
equal = (in_dims[0][i] == in_dims[j][i]) ? true : false; equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
} }
}; };
auto merge_sequential_one_dims = [](bool &equal, auto merge_sequential_one_dims = [](bool &equal,
...@@ -134,7 +134,7 @@ struct DimensionsTransform { ...@@ -134,7 +134,7 @@ struct DimensionsTransform {
equal = in_dims[0][i] == 1; equal = in_dims[0][i] == 1;
if (equal) { if (equal) {
for (int j = 1; j < num; ++j) { for (int j = 1; j < num; ++j) {
equal = in_dims[j][i] == out[i]; equal &= in_dims[j][i] == out[i];
} }
} }
}; };
......
...@@ -29,7 +29,7 @@ void LaunchElementwiseCudaKernel( ...@@ -29,7 +29,7 @@ void LaunchElementwiseCudaKernel(
std::vector<int> dims_size; std::vector<int> dims_size;
bool no_broadcast_flag = true; bool no_broadcast_flag = true;
for (auto *in : ins) { for (auto *in : ins) {
no_broadcast_flag = ins[0]->dims() == in->dims(); no_broadcast_flag &= ins[0]->dims() == in->dims();
dims_size.emplace_back(in->dims().size()); dims_size.emplace_back(in->dims().size());
} }
if (no_broadcast_flag) { if (no_broadcast_flag) {
......
...@@ -131,7 +131,7 @@ struct DimensionsTransform { ...@@ -131,7 +131,7 @@ struct DimensionsTransform {
int i, int i,
int num) { int num) {
for (int j = 1; j < num; ++j) { for (int j = 1; j < num; ++j) {
equal = (in_dims[0][i] == in_dims[j][i]) ? true : false; equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
} }
}; };
auto merge_sequential_one_dims = [](bool &equal, auto merge_sequential_one_dims = [](bool &equal,
...@@ -142,7 +142,7 @@ struct DimensionsTransform { ...@@ -142,7 +142,7 @@ struct DimensionsTransform {
equal = in_dims[0][i] == 1; equal = in_dims[0][i] == 1;
if (equal) { if (equal) {
for (int j = 1; j < num; ++j) { for (int j = 1; j < num; ++j) {
equal = in_dims[j][i] == out[i]; equal &= in_dims[j][i] == out[i];
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册