diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index a72bdec05d77f69711d5e1f62cc66474627a8276..757cac4e4ffce442677eac99bc932f08e6b1cac1 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -300,33 +300,26 @@ template struct SelectedRowsAddToTensor -typename std::enable_if::value || - std::is_same::value>::type +typename std::enable_if::value>::type elementwise_add_to(BlasT* blas, size_t data_len, const T* in, T* out) { +#ifdef PADDLE_WITH_MKLDNN onednn_handler_axpy(data_len, T(1.f), in, out); -} - -template -typename std::enable_if::value || - std::is_same>::value || - std::is_same>::value>::type -elementwise_add_to(BlasT* blas, size_t data_len, - const T* in, T* out) { +#else blas->AXPY(data_len, T(1.f), in, out); +#endif } -#else + template -typename std::enable_if::value || +typename std::enable_if::value || + std::is_same::value || std::is_same>::value || std::is_same>::value>::type elementwise_add_to(BlasT* blas, size_t data_len, const T* in, T* out) { blas->AXPY(data_len, T(1.f), in, out); } -#endif template typename std::enable_if::value>::type elementwise_add_to(