From e097a748954381e61155bd77c43086c15b3594da Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Fri, 17 Dec 2021 10:32:56 +0800 Subject: [PATCH] [BugFix]: Elementwise branch selection and Broadcast dimension merge (#38204) * fix_bugs_for_elementwise_branch_selection * fix merge_dims bugs * fix all influenced file --- .../fluid/operators/elementwise/elementwise_op_broadcast.cu.h | 4 ++-- paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h | 2 +- .../hybird/cuda/elementwise/elementwise_broadcast.cu.h | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h index e743d43e47e..30aba42aeee 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h @@ -125,7 +125,7 @@ struct DimensionsTransform { std::vector &in_dims, DimVector &out, int i, int num) { for (int j = 1; j < num; ++j) { - equal = (in_dims[0][i] == in_dims[j][i]) ? true : false; + equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false; } }; auto merge_sequential_one_dims = [](bool &equal, @@ -134,7 +134,7 @@ struct DimensionsTransform { equal = in_dims[0][i] == 1; if (equal) { for (int j = 1; j < num; ++j) { - equal = in_dims[j][i] == out[i]; + equal &= in_dims[j][i] == out[i]; } } }; diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h index 9bcfa1d857b..0ef2ee2fdf1 100644 --- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h +++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h @@ -29,7 +29,7 @@ void LaunchElementwiseCudaKernel( std::vector dims_size; bool no_broadcast_flag = true; for (auto *in : ins) { - no_broadcast_flag = ins[0]->dims() == in->dims(); + no_broadcast_flag &= ins[0]->dims() == in->dims(); dims_size.emplace_back(in->dims().size()); } if (no_broadcast_flag) { diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h index 258e8f410eb..ccdeb70002b 100644 --- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h +++ b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h @@ -131,7 +131,7 @@ struct DimensionsTransform { int i, int num) { for (int j = 1; j < num; ++j) { - equal = (in_dims[0][i] == in_dims[j][i]) ? true : false; + equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false; } }; auto merge_sequential_one_dims = [](bool &equal, @@ -142,7 +142,7 @@ struct DimensionsTransform { equal = in_dims[0][i] == 1; if (equal) { for (int j = 1; j < num; ++j) { - equal = in_dims[j][i] == out[i]; + equal &= in_dims[j][i] == out[i]; } } }; -- GitLab