diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index fef881ffd249d398664e507dc4a390bff2c20c77..d0ba2083b750c25297262e14a6ad13585b10213b 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -29,6 +29,8 @@ XPUOpMap& get_kl2_ops() { {"adadelta", XPUKernelSet({phi::DataType::FLOAT32})}, {"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"adam_dense_param_sparse_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"arg_max", XPUKernelSet({phi::DataType::FLOAT32})}, {"argsort_grad", XPUKernelSet({phi::DataType::INT32, @@ -152,7 +154,10 @@ XPUOpMap& get_kl2_ops() { {"elementwise_mul_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_mul", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"elementwise_pow", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"elementwise_sub_grad", @@ -176,6 +181,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::FLOAT16, phi::DataType::FLOAT32, phi::DataType::FLOAT64})}, + {"embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"embedding_sparse_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"equal", XPUKernelSet({phi::DataType::INT64, @@ -356,6 +362,11 @@ XPUOpMap& get_kl2_ops() { {"mul", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"mul_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"multiply", + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::INT32, + phi::DataType::INT64})}, {"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"not_equal", @@ -549,8 +560,14 @@ XPUOpMap& get_kl2_ops() { {"temporal_shift_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"tril_triu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"tril", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"triu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, {"tril_triu_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"tril_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, + {"triu_grad", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})}, {"tile", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64, diff --git a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc index cd3b920feffa85b476a5e1eb64680be4b8c3e428..19e2e380aa1e4ffadeb1dc46bab093537b9b2fcc 100644 --- a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/memory/memcpy.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" namespace phi { @@ -80,15 +81,18 @@ void EmbeddingSparseGradKernel(const Context& ctx, int64_t padding_idx, SelectedRows* weight_grad) { DDim table_dim = weight.dims(); + auto xpu_place = ctx.GetPlace(); xpu::ctx_guard RAII_GUARD(ctx.x_context()); - std::vector ids(input.numel()); + std::vector ids; + DenseTensor ids_cpu; + ids_cpu.Resize(input.dims()); + ctx.template HostAlloc( + &ids_cpu, input.dtype(), input.numel() * sizeof(int64_t)); if (input.dtype() == phi::DataType::INT64) { - paddle::memory::Copy(CPUPlace(), - ids.data(), - input.place(), - input.data(), - sizeof(int64_t) * input.numel()); + phi::Copy(ctx, input, CPUPlace(), false, &ids_cpu); + + ids = CopyIdsToVector(ids_cpu); } else if (input.dtype() == phi::DataType::INT32) { int64_t* id_t = RAII_GUARD.alloc_l3_or_gm(input.numel()); @@ -96,10 +100,11 @@ void EmbeddingSparseGradKernel(const Context& ctx, ctx.x_context(), input.data(), id_t, input.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); paddle::memory::Copy(CPUPlace(), - ids.data(), + ids_cpu.data(), input.place(), id_t, sizeof(int64_t) * input.numel()); + ids = CopyIdsToVector(ids_cpu); } else { PADDLE_THROW(phi::errors::Unimplemented( "emebdding input only support int32 and int64")); @@ -115,7 +120,7 @@ void EmbeddingSparseGradKernel(const Context& ctx, auto* d_table_value = d_table->mutable_value(); d_table_value->Resize({ids_num, table_dim[1]}); - ctx.template Alloc(d_table_value); + ctx.template HostAlloc(d_table_value); d_table->set_height(table_dim[0]); @@ -134,9 +139,12 @@ void EmbeddingSparseGradKernel(const Context& ctx, "output@Grad's shape = [%s].", d_table_value->dims(), d_output_dims_2d)); - int r = xpu::copy( - ctx.x_context(), d_output_data, d_table_data, d_output->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); + + paddle::memory::Copy(CPUPlace(), + d_table_data, + xpu_place, + d_output_data, + d_output->numel() * sizeof(T)); } } // namespace phi diff --git a/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc b/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc index 1f5d95c50f8e2a234f774014eee304471dcc763c..946daf5d7104a57c5b10141a393c1e2277c99770 100644 --- a/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/log_softmax_grad_kernel.cc @@ -37,42 +37,16 @@ void LogSoftmaxGradKernel(const Context& dev_ctx, return; } - if (out.numel() != 0) { - auto out_shape = phi::vectorize(out.dims()); - dev_ctx.template Alloc(x_grad); - xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - T* tmp_ptr = RAII_GUARD.alloc_l3_or_gm(out_grad.numel()); - T* tmp2_ptr = RAII_GUARD.alloc_l3_or_gm(out_grad.numel()); - PADDLE_ENFORCE_NE( - tmp_ptr, nullptr, phi::errors::External("no enough memory in xpu")); - PADDLE_ENFORCE_NE( - tmp2_ptr, nullptr, phi::errors::External("no enough memory in xpu")); - - int r = xpu::exp(dev_ctx.x_context(), - reinterpret_cast(out.data()), - reinterpret_cast(tmp_ptr), - out_grad.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp"); - r = xpu::reciprocal(dev_ctx.x_context(), - reinterpret_cast(tmp_ptr), - reinterpret_cast(tmp2_ptr), - out_grad.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal"); - r = xpu::mul(dev_ctx.x_context(), - reinterpret_cast(tmp2_ptr), - reinterpret_cast(out_grad.data()), - reinterpret_cast(tmp2_ptr), - out_grad.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); - r = xpu::softmax_grad( - dev_ctx.x_context(), - reinterpret_cast(tmp_ptr), - reinterpret_cast(tmp2_ptr), - reinterpret_cast(x_grad->data()), - out_shape, - axis); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_grad"); - } + auto out_shape = phi::vectorize(out.dims()); + dev_ctx.template Alloc(x_grad); + int r = xpu::log_softmax_grad( + dev_ctx.x_context(), + reinterpret_cast(out.data()), + reinterpret_cast(out_grad.data()), + reinterpret_cast(x_grad->data()), + out_shape, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_softmax_grad"); } } // namespace phi diff --git a/paddle/phi/kernels/xpu/log_softmax_kernel.cc b/paddle/phi/kernels/xpu/log_softmax_kernel.cc index a4feac7b2330716770ec97ec7147e27e2b12a22b..059157db63a3149743aed6359dfd28ee70e50818 100644 --- a/paddle/phi/kernels/xpu/log_softmax_kernel.cc +++ b/paddle/phi/kernels/xpu/log_softmax_kernel.cc @@ -39,17 +39,13 @@ void LogSoftmaxKernel(const Context& dev_ctx, auto x_shape = phi::vectorize(x.dims()); dev_ctx.template Alloc(out); if (axis < 0) axis += rank; - int r = xpu::softmax(dev_ctx.x_context(), + int r = + xpu::log_softmax(dev_ctx.x_context(), reinterpret_cast(x.data()), reinterpret_cast(out->data()), x_shape, axis); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); - r = xpu::log(dev_ctx.x_context(), - reinterpret_cast(out->data()), - reinterpret_cast(out->data()), - out->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "log"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "log_softmax"); } }