From cb7fd370ef0b57a87d9603f41bdeb3650608b334 Mon Sep 17 00:00:00 2001
From: QingshuChen <chenqingshu@baidu.com>
Date: Fri, 10 Mar 2023 10:56:35 +0800
Subject: [PATCH] support c_embedding_grad for kunlun (#51399)

---
 .../collective/c_embedding_op_xpu.cc          | 75 +++++++++++++++++++
 paddle/phi/backends/xpu/xpu2_op_list.cc       |  1 +
 2 files changed, 76 insertions(+)

diff --git a/paddle/fluid/operators/collective/c_embedding_op_xpu.cc b/paddle/fluid/operators/collective/c_embedding_op_xpu.cc
index c966ed3354a..e3f54ebfbeb 100644
--- a/paddle/fluid/operators/collective/c_embedding_op_xpu.cc
+++ b/paddle/fluid/operators/collective/c_embedding_op_xpu.cc
@@ -71,6 +71,78 @@ class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
   }
 };
 
+// Backward kernel of the sharded (model-parallel) embedding lookup on Kunlun
+// XPU: scatter-adds rows of dOut into dW for ids that fall into this shard's
+// range starting at start_index; rows never referenced stay zero.
+template <typename T>
+class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const int64_t start_idx = context.Attr<int64_t>("start_index");
+    auto ids_t = context.Input<phi::DenseTensor>("Ids");
+    auto d_output_t =
+        context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
+    auto table_t = context.Input<phi::DenseTensor>("W");
+    auto table_grad_t =
+        context.Output<phi::DenseTensor>(framework::GradVarName("W"));
+
+    T* table_grad_data =
+        table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
+
+    // Byte sizes are for logging only; device calls below take element counts.
+    size_t table_t_mem_size =
+        table_t->numel() * phi::SizeOf(table_grad_t->dtype());
+    size_t table_grad_t_mem_size =
+        table_grad_t->numel() *
+        framework::SizeOfType(
+            framework::TransToProtoVarType(table_grad_t->dtype()));
+
+    VLOG(10) << "table_dims:" << table_t->dims()
+             << ", table_t memory_size:" << table_t_mem_size
+             << ", table_grad_t memory_size:" << table_grad_t_mem_size
+             << ", start_index:" << start_idx;
+
+    auto& dev_ctx = context.template device_context<phi::XPUContext>();
+    // NOTE(review): xpu::constant takes an element count, not a byte size.
+    // Passing table_grad_t_mem_size here would zero sizeof(T) times too many
+    // elements and write past the end of the gradient buffer.
+    int r = xpu::constant(
+        dev_ctx.x_context(), table_grad_data, table_grad_t->numel(), (T)0);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+    const T* d_output_data = d_output_t->data<T>();
+
+    const int64_t height = table_t->dims()[0];
+    const int64_t width = table_t->dims()[1];
+
+    const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
+    if (index_type == framework::proto::VarType::INT32) {
+      r = xpu::embedding_grad(dev_ctx.x_context(),
+                              d_output_data,
+                              ids_t->data<int32_t>(),
+                              table_grad_data,
+                              height,
+                              width,
+                              ids_t->numel(),
+                              -1,
+                              static_cast<int32_t>(start_idx));
+    } else if (index_type == framework::proto::VarType::INT64) {
+      r = xpu::embedding_grad(dev_ctx.x_context(),
+                              d_output_data,
+                              ids_t->data<int64_t>(),
+                              table_grad_data,
+                              height,
+                              width,
+                              ids_t->numel(),
+                              -1,
+                              static_cast<int64_t>(start_idx));
+    } else {
+      PADDLE_THROW(platform::errors::Unavailable(
+          "XPU c_embedding ids only support int32 or int64."));
+    }
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
+  }
+};
+
 } // namespace operators
 } // namespace paddle
 
@@ -80,3 +152,6 @@ namespace plat = paddle::platform;
 REGISTER_OP_XPU_KERNEL(
     c_embedding,
     ops::CEmbeddingOpXPUKernel<float>);
+REGISTER_OP_XPU_KERNEL(
+    c_embedding_grad,
+    ops::CEmbeddingGradOpXPUKernel<float>);
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 77bff35611a..509aa8afa0d 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -97,6 +97,7 @@ XPUOpMap& get_kl2_ops() {
     {"c_concat",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"c_embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"c_identity",
      XPUKernelSet({phi::DataType::FLOAT16,
                    phi::DataType::FLOAT32,
-- 
GitLab