未验证 提交 d4911594 编写于 作者: z8hanghuan's avatar z8hanghuan 提交者: GitHub

optimize mergeadd for sparse_adam,*test=kunlun (#39966)

* optimize mergeadd for sparse_adam,*test=kunlun

* optimize mergeadd for sparse_adam,*test=kunlun

* optimize mergeadd for sparse_adam, *test=kunlun
上级 e8d45583
......@@ -36,7 +36,7 @@ ENDIF()
# Pick the XPU base URL (it points at a KL-SDK location) unless XPU_BASE_URL
# is already defined.
if(NOT DEFINED XPU_BASE_URL)
# Prefix without the trailing date component; a date directory is appended below.
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
# NOTE(review): XPU_BASE_URL is assigned twice in a row here, so only the
# second value (20220228) survives. This looks like the pre- and post-change
# lines of a diff shown together -- confirm that just one belongs in the file.
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220219")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220228")
else()
# XPU_BASE_URL is already defined: re-set it to its current value unchanged.
SET(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
......@@ -502,32 +503,29 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
out.mutable_value()->mutable_data<T>(
phi::make_ddim({static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
rows_to_id[merge_rows[i]] = i;
}
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
auto* y_data = out.mutable_value()->data<T>();
auto* x_data = input.value().data<T>();
int xm = input_rows.size();
int ym = merge_rows.size();
int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
auto r = xpu::add(context.x_context(), &input_data[i * input_width],
&out_data[out_i * input_width],
&out_data[out_i * input_width], n);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API return wrong value[%d %s], ", r,
XPUAPIErrorMsg[r]));
}
xpu::ctx_guard RAII_GUARD(context.x_context());
int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
memory::Copy(context.GetPlace(), y_rows_data, platform::CPUPlace(),
merge_rows.data(), ym * sizeof(int64_t));
memory::Copy(context.GetPlace(), x_rows_data, platform::CPUPlace(),
input_rows.data(), xm * sizeof(int64_t));
int r =
xpu::merge_dup_rows<T, int64_t>(context.x_context(), x_data, y_data,
x_rows_data, y_rows_data, xm, n, ym);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
}
void operator()(const platform::XPUDeviceContext& context,
......@@ -582,15 +580,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
{static_cast<int64_t>(merged_row_set.size()), input_width}),
context.GetPlace());
int r =
xpu::constant<T>(context.x_context(), out.mutable_value()->data<T>(),
merge_rows.size() * input_width, static_cast<T>(0.f));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU constant op return"
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
float* out_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());
float* y_data = reinterpret_cast<float*>(out.mutable_value()->data<T>());
std::unordered_map<int64_t, size_t> rows_to_id;
for (size_t i = 0; i < merge_rows.size(); ++i) {
......@@ -603,17 +593,22 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
}
auto& input_rows = input->rows();
auto* x_data = input->value().data<T>();
int xm = input_rows.size();
int ym = merge_rows.size();
int n = input_width;
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id[input_rows[i]];
auto r = xpu::add(
context.x_context(), input->value().data<T>() + i * input_width,
&out_data[out_i * input_width], &out_data[out_i * input_width], n);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API return wrong value[%d %s], ", r,
XPUAPIErrorMsg[r]));
}
xpu::ctx_guard RAII_GUARD(context.x_context());
int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(xm);
int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm<int64_t>(ym);
memory::Copy(context.GetPlace(), y_rows_data, platform::CPUPlace(),
merge_rows.data(), ym * sizeof(int64_t));
memory::Copy(context.GetPlace(), x_rows_data, platform::CPUPlace(),
input_rows.data(), xm * sizeof(int64_t));
int r =
xpu::merge_dup_rows<T, int64_t>(context.x_context(), x_data, y_data,
x_rows_data, y_rows_data, xm, n, ym);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册