diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index 19a9b808dd6f6d986189928741c49200c1d2b86e..77b2fa59c339799f3ef930909a9b37a77368c242 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -179,8 +179,6 @@ void embedding_grad_impl(const Tensor& x,
   VLOG(6) << "embedding_grad API kernel key: [" << kernel_key.backend() << ", "
           << kernel_key.layout() << ", " << kernel_data_type << "]";
 
-  auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
-
   if (phi::DenseTensor::classof(weight.impl().get())) {
     std::string kernel_name =
         sparse ? "embedding_sparse_grad" : "embedding_grad";
@@ -191,6 +189,9 @@ void embedding_grad_impl(const Tensor& x,
     const auto& kernel = kernel_result.kernel;
     VLOG(6) << kernel_name << " API kernel: " << kernel;
 
+    auto* dev_ctx = GetDeviceContextByBackend(
+        kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
+
     auto input_x = PrepareData(x, kernel.InputAt(0), {});
     auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
     auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
@@ -243,6 +244,9 @@ void embedding_grad_impl(const Tensor& x,
     const auto& kernel = kernel_result.kernel;
     VLOG(6) << kernel_name << " API kernel: " << kernel;
 
+    auto* dev_ctx = GetDeviceContextByBackend(
+        kernel_result.has_fallback_cpu ? Backend::CPU : kernel_key.backend());
+
     auto input_x = PrepareData(x, kernel.InputAt(0), {});
     auto input_weight = TensorToSelectedRows(weight);
     auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index dc542a9964f538c59a5d3bae3969b87dbf06b615..98beaa776336f8c08117a283eba776ca1d69af48 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -476,6 +476,7 @@
   args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1, bool sparse=false)
   output : Tensor(weight_grad)
   invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
+  no_need_buffer : weight
 
 - backward_op : expand_as_grad
   forward : expand_as (Tensor x, Tensor y, int[] target_shape) -> Tensor(out)