未验证 提交 1412d3bc 编写于 作者: A arlesniak 提交者: GitHub

Use CBLAS for SelectedRows elementwise add operation. (#34008)

* Use CBLAS for SelectedRows elementwise add operation. It's faster.

* template compilation fix

* reverted template compilation fix

* slimmed template compilation fix
Co-authored-by: NAdam Osewski <adam.osewski@intel.com>
上级 78ab656c
...@@ -300,33 +300,26 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, ...@@ -300,33 +300,26 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
// add or mul. // add or mul.
namespace scatter { namespace scatter {
#ifdef PADDLE_WITH_MKLDNN
template <typename T> template <typename T>
typename std::enable_if<std::is_same<T, float>::value || typename std::enable_if<std::is_same<T, platform::bfloat16>::value>::type
std::is_same<T, platform::bfloat16>::value>::type
elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len, elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) { const T* in, T* out) {
#ifdef PADDLE_WITH_MKLDNN
onednn_handler_axpy(data_len, T(1.f), in, out); onednn_handler_axpy(data_len, T(1.f), in, out);
} #else
template <typename T>
typename std::enable_if<std::is_same<T, double>::value ||
std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type
elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) {
blas->AXPY(data_len, T(1.f), in, out); blas->AXPY(data_len, T(1.f), in, out);
#endif
} }
#else
template <typename T> template <typename T>
typename std::enable_if<std::is_floating_point<T>::value || typename std::enable_if<std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, platform::complex<float>>::value || std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type std::is_same<T, platform::complex<double>>::value>::type
elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len, elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) { const T* in, T* out) {
blas->AXPY(data_len, T(1.f), in, out); blas->AXPY(data_len, T(1.f), in, out);
} }
#endif
template <typename T> template <typename T>
typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to( typename std::enable_if<std::is_integral<T>::value>::type elementwise_add_to(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册