From 6056d04361977bc2596f7b293230a8c0fa436643 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 15 Oct 2018 16:38:51 +0800 Subject: [PATCH] optimize blas call --- .../fluid/operators/math/selected_rows_functor.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index a4f584623..77864aa7c 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -232,18 +232,18 @@ template typename std::enable_if< std::is_floating_point::value && std::is_same::value>::type -elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in, - T* out) { - auto blas = math::GetBlas(ctx); - blas.AXPY(data_len, 1., in, out); +elementwise_add(const DeviceContext& ctx, BlasT* blas, + size_t data_len, const T* in, T* out) { + // auto blas = math::GetBlas(ctx); + blas->AXPY(data_len, 1., in, out); } template typename std::enable_if< !std::is_floating_point::value && std::is_same::value>::type -elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in, - T* out) { +elementwise_add(const DeviceContext& ctx, BlasT* blas, + size_t data_len, const T* in, T* out) { for (int64_t i = 0; i < data_len; i++) { out[i] += in[i]; } @@ -305,10 +305,11 @@ struct MergeAdd { auto* input_data = input->value().data(); auto& input_rows = input->rows(); + auto blas = math::GetBlas(context); for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = rows_to_id[input_rows[i]]; elementwise_add( - context, static_cast(input_width), + context, &blas, static_cast(input_width), &input_data[i * input_width], &out_data[out_i * input_width]); } } -- GitLab