Unverified commit ea83f898, authored by wanghuancoder, committed by GitHub

fix embedding_bug (#48318)

Parent: 00b3b4bd
@@ -179,8 +179,6 @@ void embedding_grad_impl(const Tensor& x,
   VLOG(6) << "embedding_grad API kernel key: [" << kernel_key.backend() << ", "
           << kernel_key.layout() << ", " << kernel_data_type << "]";
-  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-
   if (phi::DenseTensor::classof(weight.impl().get())) {
     std::string kernel_name =
         sparse ? "embedding_sparse_grad" : "embedding_grad";
@@ -191,6 +189,9 @@ void embedding_grad_impl(const Tensor& x,
     const auto& kernel = kernel_result.kernel;
     VLOG(6) << kernel_name << " API kernel: " << kernel;
+    auto* dev_ctx = GetDeviceContextByBackend(
+        kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
+
     auto input_x = PrepareData(x, kernel.InputAt(0), {});
     auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
     auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
@@ -243,6 +244,9 @@ void embedding_grad_impl(const Tensor& x,
     const auto& kernel = kernel_result.kernel;
     VLOG(6) << kernel_name << " API kernel: " << kernel;
+    auto* dev_ctx = GetDeviceContextByBackend(
+        kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
+
     auto input_x = PrepareData(x, kernel.InputAt(0), {});
     auto input_weight = TensorToSelectedRows(weight);
     auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
......
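
For illustration only, a minimal, self-contained C++ sketch of the pattern the hunks above introduce: the device context is resolved after kernel selection, so a CPU fallback (kernel_result.has_fallback_cpu) picks the CPU context instead of the originally requested backend. Backend, KernelResult, and GetDeviceContextByBackend below are simplified stand-ins for illustration, not the real phi declarations.

#include <iostream>

// Simplified stand-ins for the phi types referenced in the diff above.
enum class Backend { CPU, GPU };

struct KernelResult {
  bool has_fallback_cpu = false;  // set when kernel selection fell back to CPU
};

// Stand-in for the real GetDeviceContextByBackend, which returns a device
// context for the chosen backend.
const char* GetDeviceContextByBackend(Backend backend) {
  return backend == Backend::CPU ? "CPUContext" : "GPUContext";
}

int main() {
  Backend requested = Backend::GPU;       // backend from the kernel key
  KernelResult kernel_result;
  kernel_result.has_fallback_cpu = true;  // e.g. no GPU kernel registered

  // The pattern introduced by the fix: resolve the device context only after
  // kernel selection, honoring a CPU fallback rather than the requested backend.
  const char* dev_ctx = GetDeviceContextByBackend(
      kernel_result.has_fallback_cpu ? Backend::CPU : requested);

  std::cout << dev_ctx << std::endl;  // prints "CPUContext"
}

With these stand-ins the program prints CPUContext, mirroring how the fixed code prepares the gradient kernel's inputs on the backend that was actually chosen.
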
@@ -476,6 +476,7 @@
   args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1, bool sparse=false)
   output : Tensor(weight_grad)
   invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
+  no_need_buffer : weight
 
 - backward_op : expand_as_grad
   forward : expand_as (Tensor x, Tensor y, int[] target_shape) -> Tensor(out)
......
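
For context on the YAML change: no_need_buffer : weight is understood to mark weight as an input whose data buffer is not required by the backward kernel, only its metadata (shape and dtype), so the framework can avoid keeping the forward-pass buffer alive for the gradient computation.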