未验证 提交 649948a6 编写于 作者: Z zhangyikun02 提交者: GitHub

softmax_with_cross_entropy supports fp16 on XPU, test=kunlun (#40869)

上级 3b381aac
...@@ -28,6 +28,8 @@ namespace operators { ...@@ -28,6 +28,8 @@ namespace operators {
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> { class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -48,6 +50,10 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> { ...@@ -48,6 +50,10 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
std::vector<int> logits_dims = phi::vectorize<int>(logits->dims()); std::vector<int> logits_dims = phi::vectorize<int>(logits->dims());
const bool soft_label = context.Attr<bool>("soft_label"); const bool soft_label = context.Attr<bool>("soft_label");
auto logits_data = reinterpret_cast<const XPUType*>(logits->data<T>());
auto softmax_data = reinterpret_cast<XPUType*>(softmax->data<T>());
auto loss_data = reinterpret_cast<XPUType*>(loss->data<T>());
// softmax // softmax
auto& dev_ctx = auto& dev_ctx =
context.template device_context<platform::XPUDeviceContext>(); context.template device_context<platform::XPUDeviceContext>();
...@@ -55,32 +61,41 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> { ...@@ -55,32 +61,41 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) == if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) ==
phi::backends::xpu::XPUVersion::XPU2 && phi::backends::xpu::XPUVersion::XPU2 &&
soft_label) { soft_label) {
r = xpu::soft_softmax_with_cross_entropy( auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
dev_ctx.x_context(), logits->data<float>(), labels->data<T>(), r = xpu::soft_softmax_with_cross_entropy<XPUType>(
softmax->data<T>(), loss->data<T>(), n, d); dev_ctx.x_context(), logits_data, labels_data, softmax_data,
loss_data, n, d);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy");
return; return;
} }
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int len = logits->numel(); int len = logits->numel();
T* clip_logits_data = RAII_GUARD.alloc_l3_or_gm<T>(len); T* clip_logits = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits_data); PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits);
XPUType* clip_logits_data = reinterpret_cast<XPUType*>(clip_logits);
float max_val = 1e20;
float min_val = -1e20;
if (std::is_same<T, platform::float16>::value) {
max_val = 65504;
min_val = -65504;
}
r = xpu::clip_v2(dev_ctx.x_context(), logits->data<float>(), r = xpu::clip_v2<XPUType>(
clip_logits_data, len, static_cast<float>(-1e20), dev_ctx.x_context(), logits_data, clip_logits_data, len,
static_cast<float>(1e20)); static_cast<XPUType>(min_val), static_cast<XPUType>(max_val));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
r = xpu::softmax(dev_ctx.x_context(), clip_logits_data, r = xpu::softmax<XPUType>(dev_ctx.x_context(), clip_logits_data,
softmax->data<float>(), logits_dims, axis); softmax_data, logits_dims, axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
// cross_entropy // cross_entropy
if (soft_label) { if (soft_label) {
r = xpu::soft_cross_entropy<float>( auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
dev_ctx.x_context(), softmax->data<float>(), labels->data<float>(), r = xpu::soft_cross_entropy<XPUType>(dev_ctx.x_context(), softmax_data,
loss->data<float>(), n, d); labels_data, loss_data, n, d);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
} else { } else {
auto ignore_index = context.Attr<int>("ignore_index"); auto ignore_index = context.Attr<int>("ignore_index");
...@@ -92,10 +107,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> { ...@@ -92,10 +107,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
labels_int32.data<int32_t>(), labels->numel()); labels_int32.data<int32_t>(), labels->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
r = xpu::hard_cross_entropy<float, int32_t>( r = xpu::hard_cross_entropy<XPUType, int32_t>(
dev_ctx.x_context(), softmax->data<float>(), dev_ctx.x_context(), softmax_data, labels_int32.data<int32_t>(),
labels_int32.data<int32_t>(), loss->data<float>(), nullptr, n, d, loss_data, nullptr, n, d, ignore_index);
ignore_index);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy");
} }
} }
...@@ -167,8 +181,9 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> { ...@@ -167,8 +181,9 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_XPU_KERNEL(
ops::SoftmaxWithCrossEntropyXPUKernel<float>); softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyXPUKernel<float>,
ops::SoftmaxWithCrossEntropyXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
softmax_with_cross_entropy_grad, softmax_with_cross_entropy_grad,
ops::SoftmaxWithCrossEntropyGradXPUKernel<float>, ops::SoftmaxWithCrossEntropyGradXPUKernel<float>,
......
...@@ -321,7 +321,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -321,7 +321,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"softmax_with_cross_entropy", {"softmax_with_cross_entropy",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"softplus", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"softplus_grad", {"softplus_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册