未验证 提交 df898f8b 编写于 作者: T taixiurong 提交者: GitHub

fix lookup_table_v2 error in kunlun2 (#38855)

上级 b1365d25
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h" #include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -44,11 +45,10 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> { ...@@ -44,11 +45,10 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<LoDTensor>("W"); auto *table_t = context.Input<LoDTensor>("W");
auto &dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
// size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
auto *table = table_t->data<T>(); auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->mutable_data<T>(context.GetPlace());
const int64_t *ids = ids_t->data<int64_t>(); const int64_t *ids = ids_t->data<int64_t>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -56,14 +56,17 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> { ...@@ -56,14 +56,17 @@ class LookupTableV2XPUKernel : public framework::OpKernel<T> {
platform::errors::OutOfRange( platform::errors::OutOfRange(
"Number of ids greater than int32_t::max , please check " "Number of ids greater than int32_t::max , please check "
"number of ids in LookupTableV2XPUKernel.")); "number of ids in LookupTableV2XPUKernel."));
int ids_numel_int32 = static_cast<int>(ids_numel);
int r = xpu::embedding<T>(dev_ctx.x_context(), ids_numel_int32, ids, D, int ym = static_cast<int>(ids_numel);
table, output, padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, size_t xm = table_t->dims()[0];
platform::errors::External( size_t n = table_t->dims()[1];
"XPU API return wrong value[%d] , please check where "
"Baidu Kunlun Card is properly installed.", int r =
r)); xpu::embedding<T, int64_t>(dev_ctx.x_context(), table, ids, output, xm,
n, ym, static_cast<int>(padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
} }
}; };
...@@ -108,11 +111,7 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> { ...@@ -108,11 +111,7 @@ class LookupTableV2GradXPUKernel : public framework::OpKernel<T> {
int r = xpu::embedding_grad<T, int64_t>(dev_ctx.x_context(), d_output_data, int r = xpu::embedding_grad<T, int64_t>(dev_ctx.x_context(), d_output_data,
ids_data, d_table_data, xm, n, ym, ids_data, d_table_data, xm, n, ym,
padding_idx); padding_idx);
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true, PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
platform::errors::External(
"XPU API return wrong value[%d] , please check where "
"Baidu Kunlun Card is properly installed.",
r));
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册