提交 02259575 编写于 作者: Q Qiao Longfei

change elementwise_add to elementwise_add_to test=develop

上级 64406706
...@@ -233,8 +233,8 @@ template <typename DeviceContext, typename T> ...@@ -233,8 +233,8 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
std::is_floating_point<T>::value && std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas, elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) { size_t data_len, const T* in, T* out) {
blas->AXPY(data_len, 1., in, out); blas->AXPY(data_len, 1., in, out);
} }
...@@ -242,8 +242,8 @@ template <typename DeviceContext, typename T> ...@@ -242,8 +242,8 @@ template <typename DeviceContext, typename T>
typename std::enable_if< typename std::enable_if<
!std::is_floating_point<T>::value && !std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas, elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) { size_t data_len, const T* in, T* out) {
for (int64_t i = 0; i < data_len; i++) { for (int64_t i = 0; i < data_len; i++) {
out[i] += in[i]; out[i] += in[i];
} }
...@@ -308,7 +308,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> { ...@@ -308,7 +308,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
for (size_t i = 0; i < input_rows.size(); i++) { for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]]; size_t out_i = rows_to_id[input_rows[i]];
elementwise_add<platform::CPUDeviceContext, T>( elementwise_add_to<platform::CPUDeviceContext, T>(
context, &blas, static_cast<size_t>(input_width), context, &blas, static_cast<size_t>(input_width),
&input_data[i * input_width], &out_data[out_i * input_width]); &input_data[i * input_width], &out_data[out_i * input_width]);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册