Unverified commit 649948a6 authored by zhangyikun02, committed by GitHub

softmax_with_cross_entropy support fp16 on xpu, test=kunlun (#40869)

Parent 3b381aac
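For context: the change works by routing every device call through XPUType, the device-side type that XPUTypeTrait maps the framework dtype T to; tensor buffers are reinterpret_cast to XPUType before being handed to the xdnn primitives. A minimal standalone sketch of that mapping (not code from this commit; cpu_float16 and xpu_half are hypothetical stand-ins for paddle::platform::float16 and the XPU runtime's half type):

    #include <cstdint>
    #include <type_traits>

    // Hypothetical stand-ins for paddle::platform::float16 and the XPU
    // device-side half type; only the type mapping matters here.
    struct cpu_float16 { std::uint16_t bits; };
    struct xpu_half { std::uint16_t bits; };

    // Sketch of the XPUTypeTrait idea: map the framework dtype T to the type
    // the device API expects, so one kernel body serves both fp32 and fp16.
    template <typename T> struct XPUTypeTraitSketch { using Type = T; };
    template <> struct XPUTypeTraitSketch<cpu_float16> { using Type = xpu_half; };

    int main() {
      static_assert(std::is_same<XPUTypeTraitSketch<float>::Type, float>::value,
                    "fp32 maps to itself");
      static_assert(
          std::is_same<XPUTypeTraitSketch<cpu_float16>::Type, xpu_half>::value,
          "framework fp16 maps to the device half type");
      return 0;
    }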
@@ -28,6 +28,8 @@ namespace operators {
 template <typename T>
 class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     PADDLE_ENFORCE_EQ(
@@ -48,6 +50,10 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     std::vector<int> logits_dims = phi::vectorize<int>(logits->dims());
     const bool soft_label = context.Attr<bool>("soft_label");
+    auto logits_data = reinterpret_cast<const XPUType*>(logits->data<T>());
+    auto softmax_data = reinterpret_cast<XPUType*>(softmax->data<T>());
+    auto loss_data = reinterpret_cast<XPUType*>(loss->data<T>());
+
     // softmax
     auto& dev_ctx =
         context.template device_context<platform::XPUDeviceContext>();
@@ -55,32 +61,41 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) ==
             phi::backends::xpu::XPUVersion::XPU2 &&
         soft_label) {
-      r = xpu::soft_softmax_with_cross_entropy(
-          dev_ctx.x_context(), logits->data<float>(), labels->data<T>(),
-          softmax->data<T>(), loss->data<T>(), n, d);
+      auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
+      r = xpu::soft_softmax_with_cross_entropy<XPUType>(
+          dev_ctx.x_context(), logits_data, labels_data, softmax_data,
+          loss_data, n, d);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy");
       return;
     }
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
     int len = logits->numel();
-    T* clip_logits_data = RAII_GUARD.alloc_l3_or_gm<T>(len);
-    PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits_data);
+    T* clip_logits = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits);
+    XPUType* clip_logits_data = reinterpret_cast<XPUType*>(clip_logits);
+
+    float max_val = 1e20;
+    float min_val = -1e20;
+    if (std::is_same<T, platform::float16>::value) {
+      max_val = 65504;
+      min_val = -65504;
+    }
-    r = xpu::clip_v2(dev_ctx.x_context(), logits->data<float>(),
-                     clip_logits_data, len, static_cast<float>(-1e20),
-                     static_cast<float>(1e20));
+    r = xpu::clip_v2<XPUType>(
+        dev_ctx.x_context(), logits_data, clip_logits_data, len,
+        static_cast<XPUType>(min_val), static_cast<XPUType>(max_val));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
-    r = xpu::softmax(dev_ctx.x_context(), clip_logits_data,
-                     softmax->data<float>(), logits_dims, axis);
+    r = xpu::softmax<XPUType>(dev_ctx.x_context(), clip_logits_data,
+                              softmax_data, logits_dims, axis);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
     // cross_entropy
     if (soft_label) {
-      r = xpu::soft_cross_entropy<float>(
-          dev_ctx.x_context(), softmax->data<float>(), labels->data<float>(),
-          loss->data<float>(), n, d);
+      auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
+      r = xpu::soft_cross_entropy<XPUType>(dev_ctx.x_context(), softmax_data,
+                                           labels_data, loss_data, n, d);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
     } else {
       auto ignore_index = context.Attr<int>("ignore_index");
@@ -92,10 +107,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
                      labels_int32.data<int32_t>(), labels->numel());
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
-      r = xpu::hard_cross_entropy<float, int32_t>(
-          dev_ctx.x_context(), softmax->data<float>(),
-          labels_int32.data<int32_t>(), loss->data<float>(), nullptr, n, d,
-          ignore_index);
+      r = xpu::hard_cross_entropy<XPUType, int32_t>(
+          dev_ctx.x_context(), softmax_data, labels_int32.data<int32_t>(),
+          loss_data, nullptr, n, d, ignore_index);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy");
     }
   }
@@ -167,8 +181,9 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_XPU_KERNEL(softmax_with_cross_entropy,
-                       ops::SoftmaxWithCrossEntropyXPUKernel<float>);
+REGISTER_OP_XPU_KERNEL(
+    softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyXPUKernel<float>,
+    ops::SoftmaxWithCrossEntropyXPUKernel<paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     softmax_with_cross_entropy_grad,
     ops::SoftmaxWithCrossEntropyGradXPUKernel<float>,
......
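The fp16-specific numerics above are in the clip step: the old bounds of ±1e20 are not representable in half precision, so for T == float16 the kernel clamps logits to ±65504, the largest finite fp16 value, before the softmax. A minimal sketch of that range selection (not code from this commit; my_float16 is a hypothetical stand-in for platform::float16):

    #include <cstdio>
    #include <type_traits>

    // Hypothetical stand-in for paddle::platform::float16.
    struct my_float16 {};

    // Pick the clip range the same way the kernel does: wide bounds for
    // fp32, and the largest finite fp16 value (65504) when T is half.
    template <typename T>
    void clip_range(float* min_val, float* max_val) {
      *max_val = 1e20f;
      *min_val = -1e20f;
      if (std::is_same<T, my_float16>::value) {
        *max_val = 65504.f;
        *min_val = -65504.f;
      }
    }

    int main() {
      float lo = 0.f, hi = 0.f;
      clip_range<my_float16>(&lo, &hi);
      std::printf("fp16 clip range: [%g, %g]\n", lo, hi);
      clip_range<float>(&lo, &hi);
      std::printf("fp32 clip range: [%g, %g]\n", lo, hi);
      return 0;
    }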
@@ -321,7 +321,8 @@ XPUOpMap& get_kl2_ops() {
           XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                         pOpKernelType(vartype::FP16, XPUPlace())})},
       {"softmax_with_cross_entropy",
-       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                     pOpKernelType(vartype::FP16, XPUPlace())})},
       {"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"softplus_grad",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
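The second hunk lists FP16 for this op in the KL2 (XPU2) op map, the set the framework consults to decide whether a dtype is supported on the device. A rough sketch of that lookup (illustrative names, not Paddle's actual types):

    #include <iostream>
    #include <map>
    #include <set>
    #include <string>

    // Illustrative registry: each op name maps to the set of dtypes it
    // supports on the device; fp16 dispatch requires the (op, FP16) entry.
    enum class DType { FP32, FP16 };
    using OpMap = std::map<std::string, std::set<DType>>;

    int main() {
      OpMap kl2_ops;
      kl2_ops["softmax_with_cross_entropy"] = {DType::FP32, DType::FP16};
      const bool fp16_ok =
          kl2_ops["softmax_with_cross_entropy"].count(DType::FP16) > 0;
      std::cout << "fp16 supported: " << std::boolalpha << fp16_ok << "\n";
      return 0;
    }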