提交 6056d043 编写于 作者: Q Qiao Longfei

optimize blas call

上级 5db75513
......@@ -232,18 +232,18 @@ template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
T* out) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas.AXPY(data_len, 1., in, out);
elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) {
// auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas->AXPY(data_len, 1., in, out);
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in,
T* out) {
elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) {
for (int64_t i = 0; i < data_len; i++) {
out[i] += in[i];
}
......@@ -305,10 +305,11 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
auto* input_data = input->value().data<T>();
auto& input_rows = input->rows();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
elementwise_add<platform::CPUDeviceContext, T>(
context, static_cast<size_t>(input_width),
context, &blas, static_cast<size_t>(input_width),
&input_data[i * input_width], &out_data[out_i * input_width]);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册