diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 415c0fe9bef9eab89e670d8b3f6f7c330b316ed8..45a76fdc1f1a2aab66e7f4972eecbbec03af941a 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -36,7 +36,7 @@ ENDIF() if(NOT DEFINED XPU_BASE_URL) SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") - SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220219") + SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220228") else() SET(XPU_BASE_URL "${XPU_BASE_URL}") endif() diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index fcd5c06a6f310f8a23608a77f2d6b9098e99b33a..5ac39953462b5078aa663a7f39f5eb95c96bae7a 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/device/device_wrapper.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/operators/mkldnn/axpy_handler.h" @@ -502,32 +503,29 @@ struct MergeAdd { out.mutable_value()->mutable_data( phi::make_ddim({static_cast(merge_rows.size()), input_width}), context.GetPlace()); - int r = - xpu::constant(context.x_context(), out.mutable_value()->data(), - merge_rows.size() * input_width, static_cast(0.f)); - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::External("XPU constant op return" - " wrong value[%d %s].", - r, XPUAPIErrorMsg[r])); std::unordered_map rows_to_id; for (size_t i = 0; i < merge_rows.size(); ++i) { rows_to_id[merge_rows[i]] = i; } - auto* out_data = out.mutable_value()->data(); - auto* input_data = input.value().data(); + auto* y_data = out.mutable_value()->data(); + auto* x_data = input.value().data(); + int xm = input_rows.size(); + int ym = merge_rows.size(); int n = input_width; - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = rows_to_id[input_rows[i]]; - auto r = xpu::add(context.x_context(), &input_data[i * input_width], - &out_data[out_i * input_width], - &out_data[out_i * input_width], n); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API return wrong value[%d %s], ", r, - XPUAPIErrorMsg[r])); - } + + xpu::ctx_guard RAII_GUARD(context.x_context()); + int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm(xm); + int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm(ym); + memory::Copy(context.GetPlace(), y_rows_data, platform::CPUPlace(), + merge_rows.data(), ym * sizeof(int64_t)); + memory::Copy(context.GetPlace(), x_rows_data, platform::CPUPlace(), + input_rows.data(), xm * sizeof(int64_t)); + int r = + xpu::merge_dup_rows(context.x_context(), x_data, y_data, + x_rows_data, y_rows_data, xm, n, ym); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows"); } void operator()(const platform::XPUDeviceContext& context, @@ -582,15 +580,7 @@ struct MergeAdd { {static_cast(merged_row_set.size()), input_width}), context.GetPlace()); - int r = - xpu::constant(context.x_context(), out.mutable_value()->data(), - merge_rows.size() * input_width, static_cast(0.f)); - PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, - platform::errors::External("XPU constant op return" - " wrong value[%d %s].", - r, XPUAPIErrorMsg[r])); - - float* out_data = reinterpret_cast(out.mutable_value()->data()); + float* y_data = reinterpret_cast(out.mutable_value()->data()); std::unordered_map rows_to_id; for (size_t i = 0; i < merge_rows.size(); ++i) { @@ -603,17 +593,22 @@ struct MergeAdd { } auto& input_rows = input->rows(); + auto* x_data = input->value().data(); + int xm = input_rows.size(); + int ym = merge_rows.size(); int n = input_width; - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = rows_to_id[input_rows[i]]; - auto r = xpu::add( - context.x_context(), input->value().data() + i * input_width, - &out_data[out_i * input_width], &out_data[out_i * input_width], n); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API return wrong value[%d %s], ", r, - XPUAPIErrorMsg[r])); - } + + xpu::ctx_guard RAII_GUARD(context.x_context()); + int64_t* x_rows_data = RAII_GUARD.alloc_l3_or_gm(xm); + int64_t* y_rows_data = RAII_GUARD.alloc_l3_or_gm(ym); + memory::Copy(context.GetPlace(), y_rows_data, platform::CPUPlace(), + merge_rows.data(), ym * sizeof(int64_t)); + memory::Copy(context.GetPlace(), x_rows_data, platform::CPUPlace(), + input_rows.data(), xm * sizeof(int64_t)); + int r = + xpu::merge_dup_rows(context.x_context(), x_data, y_data, + x_rows_data, y_rows_data, xm, n, ym); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows"); } } };