From df898f8b78b4edf5162daa9ef104125dc4ae2c19 Mon Sep 17 00:00:00 2001 From: taixiurong Date: Tue, 18 Jan 2022 15:24:05 +0800 Subject: [PATCH] fix lookup_table_v2 error in kunlun2 (#38855) --- .../fluid/operators/lookup_table_v2_op_xpu.cc | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/lookup_table_v2_op_xpu.cc b/paddle/fluid/operators/lookup_table_v2_op_xpu.cc index eec957fb8e5..521d3ab571e 100644 --- a/paddle/fluid/operators/lookup_table_v2_op_xpu.cc +++ b/paddle/fluid/operators/lookup_table_v2_op_xpu.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/fluid/platform/device/device_wrapper.h" #ifdef PADDLE_WITH_XPU namespace paddle { namespace operators { @@ -44,11 +45,10 @@ class LookupTableV2XPUKernel : public framework::OpKernel { auto *table_t = context.Input("W"); auto &dev_ctx = context.template device_context(); - // size_t N = table_t->dims()[0]; - size_t D = table_t->dims()[1]; auto *table = table_t->data(); auto *output = output_t->mutable_data(context.GetPlace()); + const int64_t *ids = ids_t->data(); PADDLE_ENFORCE_EQ( @@ -56,14 +56,17 @@ class LookupTableV2XPUKernel : public framework::OpKernel { platform::errors::OutOfRange( "Number of ids greater than int32_t::max , please check " "number of ids in LookupTableV2XPUKernel.")); - int ids_numel_int32 = static_cast(ids_numel); - int r = xpu::embedding(dev_ctx.x_context(), ids_numel_int32, ids, D, - table, output, padding_idx); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External( - "XPU API return wrong value[%d] , please check where " - "Baidu Kunlun Card is properly installed.", - r)); + + int ym = static_cast(ids_numel); + + size_t xm = table_t->dims()[0]; + size_t n = table_t->dims()[1]; + + int r = + xpu::embedding(dev_ctx.x_context(), table, ids, output, xm, + n, ym, static_cast(padding_idx)); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); } }; @@ -108,11 +111,7 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel { int r = xpu::embedding_grad(dev_ctx.x_context(), d_output_data, ids_data, d_table_data, xm, n, ym, padding_idx); - PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, - platform::errors::External( - "XPU API return wrong value[%d] , please check where " - "Baidu Kunlun Card is properly installed.", - r)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad"); } }; } // namespace operators -- GitLab