未验证 提交 cb7fd370 编写于 作者: Q QingshuChen 提交者: GitHub

support c_embedding_grad for kunlun (#51399)

上级 e3826e0a
......@@ -71,6 +71,70 @@ class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const int64_t start_idx = context.Attr<int64_t>("start_index");
auto ids_t = context.Input<phi::DenseTensor>("Ids");
auto d_output_t =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto table_t = context.Input<phi::DenseTensor>("W");
auto table_grad_t =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
T* table_grad_data =
table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
size_t table_t_mem_size =
table_t->numel() * phi::SizeOf(table_grad_t->dtype());
size_t table_grad_t_mem_size =
table_grad_t->numel() *
framework::SizeOfType(
framework::TransToProtoVarType(table_grad_t->dtype()));
VLOG(10) << "table_dims:" << table_t->dims()
<< ", table_t memory_size:" << table_t_mem_size
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
<< ", start_index:" << start_idx;
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::constant(
dev_ctx.x_context(), table_grad_data, table_grad_t_mem_size, (T)0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
const T* d_output_data = d_output_t->data<T>();
const int64_t height = table_t->dims()[0];
const int64_t width = table_t->dims()[1];
const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
r = xpu::embedding_grad(dev_ctx.x_context(),
d_output_data,
ids_t->data<int32_t>(),
table_grad_data,
height,
width,
ids_t->numel(),
-1,
static_cast<int32_t>(start_idx));
} else if (index_type == framework::proto::VarType::INT64) {
r = xpu::embedding_grad(dev_ctx.x_context(),
d_output_data,
ids_t->data<int64_t>(),
table_grad_data,
height,
width,
ids_t->numel(),
-1,
static_cast<int64_t>(start_idx));
} else {
PADDLE_THROW(platform::errors::Unavailable(
"XPU c_embedding ids only support int32 or int64."));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -80,3 +144,6 @@ namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
c_embedding,
ops::CEmbeddingOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
c_embedding_grad,
ops::CEmbeddingGradOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
......@@ -97,6 +97,7 @@ XPUOpMap& get_kl2_ops() {
{"c_concat",
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
{"c_embedding", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_embedding_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_identity",
XPUKernelSet({phi::DataType::FLOAT16,
phi::DataType::FLOAT32,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册