diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index a4f584623acf99ca2040955514b28670125bb6c0..77864aa7c0d29583ab4aa275db0e440c60e8718c 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -232,18 +232,18 @@ template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_floating_point<T>::value &&
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
-                T* out) {
-  auto blas = math::GetBlas<DeviceContext, T>(ctx);
-  blas.AXPY(data_len, 1., in, out);
+elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
+                size_t data_len, const T* in, T* out) {
+  // auto blas = math::GetBlas<DeviceContext, T>(ctx);
+  blas->AXPY(data_len, 1., in, out);
 }
 
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     !std::is_floating_point<T>::value &&
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
-                T* out) {
+elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
+                size_t data_len, const T* in, T* out) {
   for (int64_t i = 0; i < data_len; i++) {
     out[i] += in[i];
   }
@@ -305,10 +305,11 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
       auto* input_data = input->value().data<T>();
       auto& input_rows = input->rows();
 
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
       for (size_t i = 0; i < input_rows.size(); i++) {
         size_t out_i = rows_to_id[input_rows[i]];
         elementwise_add<platform::CPUDeviceContext, T>(
-            context, static_cast<size_t>(input_width),
+            context, &blas, static_cast<size_t>(input_width),
             &input_data[i * input_width], &out_data[out_i * input_width]);
       }
     }
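
A minimal, self-contained sketch (not part of the diff) of the `std::enable_if` dispatch the patch keeps: one overload handles floating-point element types through a BLAS-style AXPY handle that the caller builds once, outside the per-row loop, while the other overload falls back to a plain accumulation loop. `FakeBlas` and `elementwise_add_sketch` are hypothetical stand-ins for Paddle's `BlasT<DeviceContext, T>` handle and the `elementwise_add` helper above; they are illustrative only, not the Paddle API.

```cpp
// Sketch only: FakeBlas stands in for the handle returned by
// math::GetBlas<DeviceContext, T>(ctx); it is NOT the Paddle API.
#include <cstddef>
#include <iostream>
#include <type_traits>

struct FakeBlas {
  // y[i] += alpha * x[i]; mimics what blas->AXPY(n, 1., x, y) is used for.
  template <typename T>
  void AXPY(std::size_t n, T alpha, const T* x, T* y) const {
    for (std::size_t i = 0; i < n; ++i) y[i] += alpha * x[i];
  }
};

// Selected when T is floating-point: delegate to the (fake) BLAS handle that
// the caller created once and passes in by pointer, as the patched code does.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
elementwise_add_sketch(FakeBlas* blas, std::size_t n, const T* in, T* out) {
  blas->AXPY(n, static_cast<T>(1), in, out);
}

// Selected for non-floating-point element types: a plain loop, since
// BLAS AXPY only covers floating-point data.
template <typename T>
typename std::enable_if<!std::is_floating_point<T>::value>::type
elementwise_add_sketch(FakeBlas* /*blas*/, std::size_t n, const T* in,
                       T* out) {
  for (std::size_t i = 0; i < n; ++i) out[i] += in[i];
}

int main() {
  FakeBlas blas;  // built once, reused for every row -- the point of the PR
  double a[3] = {1.0, 2.0, 3.0}, b[3] = {10.0, 20.0, 30.0};
  int c[3] = {1, 2, 3}, d[3] = {10, 20, 30};
  elementwise_add_sketch(&blas, 3, a, b);  // AXPY path
  elementwise_add_sketch(&blas, 3, c, d);  // fallback loop path
  std::cout << b[0] << " " << d[0] << "\n";  // prints "11 11"
  return 0;
}
```

The design point mirrored here: before the patch, `elementwise_add` called `math::GetBlas` for every row being merged; after it, `MergeAdd` constructs the Blas handle once and the helpers only borrow it, so the per-row work is just the AXPY (or the fallback loop).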