From 38568519f78f57e4def0dcf44909e430c3e80e64 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 11 Oct 2018 15:25:53 +0800 Subject: [PATCH] optimize code --- paddle/fluid/operators/math/selected_rows_functor.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 95f3c62a5..a11c6461d 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "paddle/fluid/operators/math/math_function.h" @@ -228,6 +229,11 @@ struct MergeAdd { } std::vector merge_rows(merged_row_set.begin(), merged_row_set.end()); + std::map rows_to_id; + for (size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; + } + out.set_rows(merge_rows); out.set_height(input_height); out.mutable_value()->mutable_data( @@ -245,7 +251,7 @@ struct MergeAdd { auto& input_rows = input->rows(); for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = FindPos(merge_rows, input_rows[i]); + size_t out_i = rows_to_id[input_rows[i]]; for (int64_t j = 0; j < input_width; j++) { out_data[out_i * input_width + j] += input_data[i * input_width + j]; } -- GitLab