From ea83f898d8db0983a5705eb444c1a12aa7fe012c Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Fri, 25 Nov 2022 10:34:32 +0800
Subject: [PATCH] fix embedding_bug (#48318)

---
 paddle/phi/api/lib/api_custom_impl.cc    | 8 ++++++--
 paddle/phi/api/yaml/legacy_backward.yaml | 1 +
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index 19a9b808dd6..77b2fa59c33 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -179,8 +179,6 @@ void embedding_grad_impl(const Tensor& x,
   VLOG(6) << "embedding_grad API kernel key: [" << kernel_key.backend() << ", "
           << kernel_key.layout() << ", " << kernel_data_type << "]";
 
-  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-
   if (phi::DenseTensor::classof(weight.impl().get())) {
     std::string kernel_name =
         sparse ? "embedding_sparse_grad" : "embedding_grad";
@@ -191,6 +189,9 @@ void embedding_grad_impl(const Tensor& x,
     const auto& kernel = kernel_result.kernel;
     VLOG(6) << kernel_name << " API kernel: " << kernel;
 
+    auto* dev_ctx = GetDeviceContextByBackend(
+        kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
+
     auto input_x = PrepareData(x, kernel.InputAt(0), {});
     auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
     auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
@@ -243,6 +244,9 @@ void embedding_grad_impl(const Tensor& x,
     const auto& kernel = kernel_result.kernel;
     VLOG(6) << kernel_name << " API kernel: " << kernel;
 
+    auto* dev_ctx = GetDeviceContextByBackend(
+        kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
+
     auto input_x = PrepareData(x, kernel.InputAt(0), {});
     auto input_weight = TensorToSelectedRows(weight);
     auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index dc542a9964f..98beaa77633 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -476,6 +476,7 @@
   args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1, bool sparse=false)
   output : Tensor(weight_grad)
   invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
+  no_need_buffer : weight
 
 - backward_op : expand_as_grad
   forward : expand_as (Tensor x, Tensor y, int[] target_shape) -> Tensor(out)
-- 
GitLab