未验证 提交 d4911594 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

optimize mergeadd for sparse_adam,*test=kunlun (#39966)

* optimize mergeadd for sparse_adam,*test=kunlun

* optimize mergeadd for sparse_adam,*test=kunlun

* optimize mergeadd for sparse_adam, *test=kunlun
上级 e8d45583
...@@ -36,7 +36,7 @@ ENDIF() ...@@ -36,7 +36,7 @@ ENDIF()
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220219") SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220228")
else() else()
SET(XPU_BASE_URL "${XPU_BASE_URL}") SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/operators/mkldnn/axpy_handler.h" #include "paddle/fluid/operators/mkldnn/axpy_handler.h"
...@@ -502,32 +503,29 @@ struct MergeAdd<platform::XPUDeviceContext, T> { ...@@ -502,32 +503,29 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
out.mutable_value()->mutable_data<T>( out.mutable_value()->mutable_data<T>(
phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}), phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace()); context.GetPlace());
int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
std::unordered_map<int64_t, size_t> rows_to_id; std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) { for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i; rows_to_id[merge_rows[i]] = i;
} }
auto* out_data = out.mutable_value()->data<T>(); auto* y_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>(); auto* x_data = input.value().data<T>();
int xm = input_rows.size();
int ym = merge_rows.size();
int n = input_width; int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]]; xpu::ctx_guard RAII_GUARD(context.x_context());
auto r = xpu::add(context.x_context(), &input_data[i * input_width], int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
&out_data[out_i * input_width], int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
&out_data[out_i * input_width], n); memory::Copy(context.GetPlace(), y_rows_data, platform::CPUPlace(),
PADDLE_ENFORCE_EQ( merge_rows.data(), ym * sizeof(int64_t));
r, XPU_SUCCESS, memory::Copy(context.GetPlace(), x_rows_data, platform::CPUPlace(),
platform::errors::External("XPU API return wrong value[%d %s], ", r, input_rows.data(), xm * sizeof(int64_t));
XPUAPIErrorMsg[r])); int r =
} xpu::merge_dup_rows<T, int64_t>(context.x_context(), x_data, y_data,
x_rows_data, y_rows_data, xm, n, ym);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
} }
void operator()(const platform::XPUDeviceContext& context, void operator()(const platform::XPUDeviceContext& context,
...@@ -582,15 +580,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> { ...@@ -582,15 +580,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
{static_cast<int64_t>(merged_row_set.size()), input_width}), {static_cast<int64_t>(merged_row_set.size()), input_width}),
context.GetPlace()); context.GetPlace());
int r = float* y_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
float* out_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());
std::unordered_map<int64_t, size_t> rows_to_id; std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) { for (size_t i = 0; i < merge_rows.size(); ++i) {
...@@ -603,17 +593,22 @@ struct MergeAdd<platform::XPUDeviceContext, T> { ...@@ -603,17 +593,22 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
} }
auto& input_rows = input->rows(); auto& input_rows = input->rows();
auto* x_data = input->value().data<T>();
int xm = input_rows.size();
int ym = merge_rows.size();
int n = input_width; int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]]; xpu::ctx_guard RAII_GUARD(context.x_context());
auto r = xpu::add( int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
context.x_context(), input->value().data<T>() + i * input_width, int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
&out_data[out_i * input_width], &out_data[out_i * input_width], n); memory::Copy(context.GetPlace(), y_rows_data, platform::CPUPlace(),
PADDLE_ENFORCE_EQ( merge_rows.data(), ym * sizeof(int64_t));
r, XPU_SUCCESS, memory::Copy(context.GetPlace(), x_rows_data, platform::CPUPlace(),
platform::errors::External("XPU API return wrong value[%d %s], ", r, input_rows.data(), xm * sizeof(int64_t));
XPUAPIErrorMsg[r])); int r =
} xpu::merge_dup_rows<T, int64_t>(context.x_context(), x_data, y_data,
x_rows_data, y_rows_data, xm, n, ym);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册