未验证 提交 df898f8b 编写于 作者: T taixiurong 提交者: GitHub

fix lookup_table_v2 error in kunlun2 (#38855)

上级 b1365d25
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
......@@ -44,11 +45,10 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<LoDTensor>("W");
auto &dev_ctx = context.template device_context<DeviceContext>();
// size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
const int64_t *ids = ids_t->data<int64_t>();
PADDLE_ENFORCE_EQ(
......@@ -56,14 +56,17 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
platform::errors::OutOfRange(
"Number of ids greater than int32_t::max , please check "
"number of ids in LookupTableV2XPUKernel."));
int ids_numel_int32 = static_cast<int>(ids_numel);
int r = xpu::embedding<T>(dev_ctx.x_context(), ids_numel_int32, ids, D,
table, output, padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External(
"XPU API return wrong value[%d] , please check where "
"Baidu Kunlun Card is properly installed.",
r));
int ym = static_cast<int>(ids_numel);
size_t xm = table_t->dims()[0];
size_t n = table_t->dims()[1];
int r =
xpu::embedding<T, int64_t>(dev_ctx.x_context(), table, ids, output, xm,
n, ym, static_cast<int>(padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
}
};
......@@ -108,11 +111,7 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
int r = xpu::embedding_grad<T, int64_t>(dev_ctx.x_context(), d_output_data,
ids_data, d_table_data, xm, n, ym,
padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
platform::errors::External(
"XPU API return wrong value[%d] , please check where "
"Baidu Kunlun Card is properly installed.",
r));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册