提交 6056d043 作者: Qiao Longfei

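optimize blas call (obtain the BLAS handle once in MergeAdd and pass it to elementwise_add, instead of calling GetBlas for every row)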

上级 5db75513
...@@ -232,18 +232,18 @@ template <typename DeviceContext, typename T> ...@@ -232,18 +232,18 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in, elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
T* out) { size_t data_len, const T* in, T* out) {
auto blas = math::GetBlas<DeviceContext, T>(ctx); // auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas.AXPY(data_len, 1., in, out); blas->AXPY(data_len, 1., in, out);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value && !std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in, elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
T* out) { size_t data_len, const T* in, T* out) {
for (int64_t i = 0; i < data_len; i++) { for (int64_t i = 0; i < data_len; i++) {
out[i] += in[i]; out[i] += in[i];
} }
...@@ -305,10 +305,11 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -305,10 +305,11 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
auto* input_data = input->value().data<T>(); auto* input_data = input->value().data<T>();
auto& input_rows = input->rows(); auto& input_rows = input->rows();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (size_t i = 0; i < input_rows.size(); i++) { for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]]; size_t out_i = rows_to_id[input_rows[i]];
elementwise_add<platform::CPUDeviceContext, T>( elementwise_add<platform::CPUDeviceContext, T>(
context, static_cast<size_t>(input_width), context, &blas, static_cast<size_t>(input_width),
&input_data[i * input_width], &out_data[out_i * input_width]); &input_data[i * input_width], &out_data[out_i * input_width]);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册