From c4d5a77fec998ea21870d6479a0584daccf4aa0e Mon Sep 17 00:00:00 2001
From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com>
Date: Wed, 13 Apr 2022 10:21:21 +0800
Subject: [PATCH] concat and relu support FP16 on XPU, test=kunlun (#41631)

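This change registers float16 variants of the concat/concat_grad and
relu/relu_grad XPU (KL2) kernels and adds them to the KL2 op list. relu is
moved from the single-type REGISTER_ACTIVATION_XPU_KERNEL macro to an explicit
REGISTER_OP_XPU_KERNEL call so that both float and float16 can be registered,
the same way tanh is handled. A minimal usage sketch (not part of this commit),
assuming Paddle is built with XPU (KL2) support and an XPU device is available:

    import paddle
    paddle.set_device("xpu")

    x = paddle.rand([2, 3]).astype("float16")
    y = paddle.rand([2, 3]).astype("float16")
    out = paddle.concat([x, y], axis=0)   # dispatches to the FP16 concat XPU kernel
    out = paddle.nn.functional.relu(out)  # dispatches to the FP16 relu XPU kernel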
---
 paddle/fluid/operators/activation_op_xpu.cc   |  8 ++++-
 paddle/fluid/operators/concat_op_xpu.cc       | 31 +++++++++++++------
 .../fluid/platform/device/xpu/xpu2_op_list.h  | 12 ++++---
 3 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc
index 4c2d3fc162f..e950f952c24 100644
--- a/paddle/fluid/operators/activation_op_xpu.cc
+++ b/paddle/fluid/operators/activation_op_xpu.cc
@@ -490,7 +490,6 @@ REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor,
                                XPULeakyReluGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, XPUReciprocalFunctor,
                                XPUReciprocalGradFunctor)
-REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
                                XPUSigmoidGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
@@ -500,6 +499,13 @@ REGISTER_ACTIVATION_XPU_KERNEL(softplus, XPUSoftPlusFunctor,
 REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
 REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)
 
+REGISTER_OP_XPU_KERNEL(
+    relu, ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
+    ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
+REGISTER_OP_XPU_KERNEL(
+    relu_grad, ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
+    ops::XPUActivationGradKernel<
+        ops::XPUReluGradFunctor<paddle::platform::float16>>);
 REGISTER_OP_XPU_KERNEL(
     tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
     ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc
index e4b0b0ee2e3..ba35098bbac 100644
--- a/paddle/fluid/operators/concat_op_xpu.cc
+++ b/paddle/fluid/operators/concat_op_xpu.cc
@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
 
 template <typename DeviceContext, typename T>
 class ConcatXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::LoDTensor>("X");
@@ -79,10 +81,10 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
     std::vector<std::vector<int>> xdims_list;
-    std::vector<const T*> ptrs;
+    std::vector<const XPUType*> ptrs;
     for (unsigned int i = 0; i < ins.size(); ++i) {
       if (ins[i] && ins[i]->numel() > 0) {
-        ptrs.push_back(ins[i]->data<T>());
+        ptrs.push_back(reinterpret_cast<const XPUType*>(ins[i]->data<T>()));
         int size = ins[i]->dims().size();
         std::vector<int> tmp_dims(size);
         for (int j = 0; j < size; ++j) {
@@ -96,8 +98,9 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
                                                 "No tensor need concat"));
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
 
-    int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(),
-                           xdims_list, axis);
+    int r = xpu::concat<XPUType>(dev_ctx.x_context(), ptrs,
+                                 reinterpret_cast<XPUType*>(out->data<T>()),
+                                 xdims_list, axis);
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(
                           "XPU concat kernel return wrong value[%d %s]", r,
@@ -107,6 +110,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
 
 template <typename DeviceContext, typename T>
 class ConcatGradXPUKernel : public framework::OpKernel<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
     auto* out_grad =
@@ -134,12 +139,12 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     axis = ComputeAxis(static_cast<int64_t>(axis),
                        static_cast<int64_t>(ins[0]->dims().size()));
     // get output tensor that the name is not kEmptyVarName
-    std::vector<T*> ptrs(outs.size());
+    std::vector<XPUType*> ptrs(outs.size());
     for (size_t j = 0; j < outs.size(); ++j) {
       if (out_var_names[j] != framework::kEmptyVarName &&
           outs[j]->numel() != 0UL) {
         outs[j]->mutable_data<T>(ctx.GetPlace());
-        ptrs[j] = outs[j]->data<T>();
+        ptrs[j] = reinterpret_cast<XPUType*>(outs[j]->data<T>());
       } else {
         ptrs[j] = nullptr;
       }
@@ -173,8 +178,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
     xdims_list[axis] = total_length;
 
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
-                          xdims_list, split_list, axis);
+    int r = xpu::split<XPUType>(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(out_grad->data<T>()), ptrs, xdims_list,
+        split_list, axis);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External(
@@ -189,9 +196,13 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
 
 namespace ops = paddle::operators;
 REGISTER_OP_XPU_KERNEL(
-    concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext,
+                         paddle::platform::float16>);
 REGISTER_OP_XPU_KERNEL(
     concat_grad,
-    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
+    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
+    ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext,
+                             paddle::platform::float16>);
 
 #endif
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 3a047b8fce7..9915b4d8d34 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -56,8 +56,10 @@ XPUOpMap& get_kl2_ops() {
                              pOpKernelType(vartype::INT64, XPUPlace()),
                              pOpKernelType(vartype::INT32, XPUPlace())})},
       {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-      {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-      {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                    pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                               pOpKernelType(vartype::FP16, XPUPlace())})},
       {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
                                     pOpKernelType(vartype::FP16, XPUPlace())})},
       {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
@@ -288,8 +290,10 @@ XPUOpMap& get_kl2_ops() {
       {"reduce_sum_grad",
        XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
       {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-      {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
-      {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+      {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                                  pOpKernelType(vartype::FP16, XPUPlace())})},
+      {"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
+                             pOpKernelType(vartype::FP16, XPUPlace())})},
       {"reshape2_grad",
        XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
                      pOpKernelType(vartype::INT64, XPUPlace()),
-- 
GitLab