// fused_gemm_epilogue_op.cu
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// CUDA forward kernel of the fused_gemm_epilogue op.
//
// Issues one cublasLtMatmul over inputs "X" and "Y" whose epilogue adds "Bias"
// and, depending on the "activation" attribute ("none", "relu" or "gelu"),
// applies the activation. When the optional "ReserveSpace" output is present
// the *_AUX_* epilogue variant is selected and, for a non-"none" activation,
// ReserveSpace is handed to cuBLASLt as the epilogue auxiliary buffer.
//
// NOTE(review): Y is passed as cuBLASLt operand A and X as operand B, and Out
// is described as an N x M matrix with leading dimension N. This looks like
// the usual way of driving a column-major library with row-major tensors --
// confirm against the op definition.
template <typename DeviceContext, typename T>
class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
 public:
  // Runs the fused GEMM + bias (+ activation) on the stream of the CUDA device
  // context taken from `ctx`.
  //
  // Raises (through PADDLE_ENFORCE_*): InvalidArgument for an unknown
  // "activation" value, and a GPU error if any cuBLASLt call fails.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* y = ctx.Input<Tensor>("Y");
    const Tensor* bias = ctx.Input<Tensor>("Bias");

    Tensor* out = ctx.Output<Tensor>("Out");
    Tensor* reserve_space = ctx.Output<Tensor>("ReserveSpace");

    bool trans_x = ctx.Attr<bool>("trans_x");
    bool trans_y = ctx.Attr<bool>("trans_y");

    std::string activation = ctx.Attr<std::string>("activation");
    VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y
             << " , activation = " << activation;

    // The auxiliary epilogue output is produced only when the ReserveSpace
    // output tensor was provided.
    bool enable_auxiliary = reserve_space != nullptr;

    out->mutable_data<T>(ctx.GetPlace());
    auto* out_data = out->data<T>();

    // Collapse X to a 2-D view; the split axis depends on trans_x.
    auto x_mat_dims =
        phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1);
    int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0];
    int64_t K = trans_y ? y->dims()[1] : y->dims()[0];
    int64_t N = trans_y ? y->dims()[0] : y->dims()[1];

    // float and float16 both compute and scale in fp32; double uses fp64
    // throughout.
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
      scale_type = CUDA_R_32F;
    }
    if (std::is_same<T, double>::value) {
      mat_type = CUDA_R_64F;
      scale_type = CUDA_R_64F;
      compute_type = CUBLAS_COMPUTE_64F;
    }

    cublasLtMatmulDesc_t operation_desc = NULL;
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
        &operation_desc, compute_type, scale_type));
    // X is operand B and Y is operand A in the cublasLtMatmul call below, so
    // trans_x drives TRANSB and trans_y drives TRANSA.
    cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
    cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx,
            sizeof(transx)));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy,
            sizeof(transy)));

    cublasLtEpilogue_t epiloque_func =
        get_epilogue_type_(activation, enable_auxiliary);
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func,
            sizeof(epiloque_func)));
    const T* bias_data = bias->data<T>();
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescSetAttribute(
            operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_data,
            sizeof(bias_data)));

    if (enable_auxiliary && activation != "none") {
      size_t reserve_space_size = 0;
      if (activation == "relu") {
        // The ReLU auxiliary output is sized at one bit per output element;
        // round up so a trailing partial byte is not dropped.
        reserve_space_size = (phi::product(out->dims()) + 7) / 8;
      } else {
        // Otherwise one element of type T per output element.
        reserve_space_size = phi::product(out->dims()) * sizeof(T);
      }
      reserve_space->mutable_data(ctx.GetPlace(), out->type(),
                                  reserve_space_size);
      void* aux_data = reinterpret_cast<void*>(reserve_space->data<T>());

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
              &aux_data, sizeof(aux_data)));
      // The auxiliary buffer uses the same leading dimension as Out.
      int64_t aux_ld = N;
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld,
              sizeof(aux_ld)));
    }

    // Layout arguments are (rows, cols, leading dimension).
    cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
    if (trans_x) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &x_desc, mat_type, M, K, M));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &x_desc, mat_type, K, M, K));
    }
    if (trans_y) {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &y_desc, mat_type, K, N, K));
    } else {
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &y_desc, mat_type, N, K, N));
    }
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
        &out_desc, mat_type, N, M, N));

    cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
    // NOTE(review): this requests a 4 GiB workspace on every call, which is
    // very large for a per-op scratch buffer -- confirm it is intended.
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
    // A null algo pointer leaves the algorithm choice to cuBLASLt.
    const cublasLtMatmulAlgo_t* algo = nullptr;
    cudaStream_t stream = dev_ctx.stream();
    memory::allocation::AllocationPtr workspace =
        memory::Alloc(dev_ctx, workspace_size);

    // alpha/beta must be stored in the scale type chosen above: double for
    // T == double, float otherwise.
    double alpha64 = 1.0, beta64 = 0.0;
    float alpha32 = 1.0f, beta32 = 0.0f;
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
    } else {
      alpha = &alpha32;
      beta = &beta32;
    }

    // Out is passed as both C and D; beta is 0.
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
        lt_handle, operation_desc, alpha, y->data<T>(), y_desc, x->data<T>(),
        x_desc, beta, out_data, out_desc, out_data, out_desc, algo,
        workspace->ptr(), workspace_size, stream));

    // Release the per-call cuBLASLt descriptors; they were previously leaked
    // on every invocation.
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
  }

 private:
  // Maps the "activation" attribute to a cuBLASLt epilogue. With
  // `enable_auxiliary` the *_AUX_BIAS variants are returned for "relu" and
  // "gelu"; "none" always maps to the plain BIAS epilogue.
  // Raises InvalidArgument for any other activation string.
  static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation,
                                               bool enable_auxiliary) {
    if (activation == "relu") {
      return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS
                              : CUBLASLT_EPILOGUE_RELU_BIAS;
    } else if (activation == "gelu") {
      return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS
                              : CUBLASLT_EPILOGUE_GELU_BIAS;
    } else if (activation == "none") {
      return CUBLASLT_EPILOGUE_BIAS;
    } else {
      PADDLE_ENFORCE_EQ(
          true, false,
          platform::errors::InvalidArgument(
              "The activation attribute of fused_gemm_epilogue op should be"
              " one of {\"none\", \"relu\", \"gelu\"}. "
              "But received activation=%s.",
              activation));
    }
    // Not reached: the check above always fails. Kept so that every path of
    // this non-void function returns a value.
    return CUBLASLT_EPILOGUE_BIAS;
  }
};

// CUDA backward kernel of the fused_gemm_epilogue op.
//
// Reads "DOut", "X", "Y" and (when an activation gradient is requested)
// "ReserveSpace", and writes whichever of the outputs "DX", "DY" and "DBias"
// are present. Each of DX and DY is produced by its own cublasLtMatmul call:
//   * the DX matmul applies the epilogue selected by "activation_grad"
//     ("none", "relu_grad" or "gelu_grad"), with ReserveSpace as the epilogue
//     auxiliary input;
//   * the DY matmul uses a bias-gradient epilogue that fills DBias when the
//     DBias output is present.
//
// NOTE(review): DBias is only written inside the DY branch, so it is not
// computed when DY is absent -- confirm this is acceptable to callers.
template <typename DeviceContext, typename T>
class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
 public:
  // Runs the backward matmuls on the stream of the CUDA device context taken
  // from `ctx`.
  //
  // Raises (through PADDLE_ENFORCE_*): InvalidArgument for an unknown
  // "activation_grad" value or a missing ReserveSpace input when one is
  // needed, and a GPU error if any cuBLASLt call fails.
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    const Tensor* dout = ctx.Input<Tensor>("DOut");
    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* y = ctx.Input<Tensor>("Y");
    const Tensor* reserve_space = ctx.Input<Tensor>("ReserveSpace");

    Tensor* dx = ctx.Output<Tensor>("DX");
    Tensor* dy = ctx.Output<Tensor>("DY");
    Tensor* dbias = ctx.Output<Tensor>("DBias");

    std::string activation_grad = ctx.Attr<std::string>("activation_grad");

    bool transpose_x = ctx.Attr<bool>("trans_x");
    bool transpose_y = ctx.Attr<bool>("trans_y");

    VLOG(10) << "trans_x = " << transpose_x << " , trans_y = " << transpose_y
             << " , activation_grad = " << activation_grad;

    // Collapse X to a 2-D view; the split axis depends on trans_x.
    auto x_mat_dims =
        phi::flatten_to_2d(x->dims(), transpose_x ? 1 : x->dims().size() - 1);

    int64_t M = transpose_x ? x_mat_dims[1] : x_mat_dims[0];
    int64_t K = transpose_y ? y->dims()[1] : y->dims()[0];
    int64_t N = transpose_y ? y->dims()[0] : y->dims()[1];

    // float and float16 both compute and scale in fp32; double uses fp64
    // throughout.
    cudaDataType_t mat_type = CUDA_R_32F;
    cudaDataType_t scale_type = CUDA_R_32F;
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    if (std::is_same<T, paddle::platform::float16>::value) {
      mat_type = CUDA_R_16F;
      scale_type = CUDA_R_32F;
    }
    if (std::is_same<T, double>::value) {
      mat_type = CUDA_R_64F;
      scale_type = CUDA_R_64F;
      compute_type = CUBLAS_COMPUTE_64F;
    }

    cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
    // NOTE(review): a 4 GiB workspace is requested separately for the DX and
    // the DY matmul -- confirm this size is intended.
    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
    // A null algo pointer leaves the algorithm choice to cuBLASLt.
    const cublasLtMatmulAlgo_t* algo = nullptr;
    cudaStream_t stream = dev_ctx.stream();

    // alpha/beta must be stored in the scale type chosen above: double for
    // T == double, float otherwise.
    double alpha64 = 1.0, beta64 = 0.0;
    float alpha32 = 1.0f, beta32 = 0.0f;
    void *alpha = nullptr, *beta = nullptr;
    if (std::is_same<T, double>::value) {
      alpha = &alpha64;
      beta = &beta64;
    } else {
      alpha = &alpha32;
      beta = &beta32;
    }

    // The two DOut layouts are created lazily and shared by the DX and DY
    // branches; whichever were created are destroyed at the end of Compute.
    cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr;

    if (dx) {
      cublasOperation_t trans_dout = transpose_x ? CUBLAS_OP_T : CUBLAS_OP_N;
      cublasOperation_t trans_y =
          (transpose_x ^ transpose_y) ? CUBLAS_OP_N : CUBLAS_OP_T;

      cublasLtMatrixLayout_t dout_desc_for_dx = nullptr;
      cublasLtMatrixLayout_t y_desc = nullptr, dx_desc = nullptr;
      if (trans_dout == CUBLAS_OP_T) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(&dout_trans_desc,
                                                          mat_type, M, N, M));
        dout_desc_for_dx = dout_trans_desc;
      } else {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(&dout_desc, mat_type,
                                                          N, M, N));
        dout_desc_for_dx = dout_desc;
      }

      if (transpose_y) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, K,
                                                          N, K));
      } else {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, N,
                                                          K, N));
      }

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &dx_desc, mat_type, K, M, K));

      cublasLtMatmulDesc_t dx_operation_desc = NULL;
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
          &dx_operation_desc, compute_type, scale_type));

      if (transpose_x) {
        // Intended product (per the original note): dx = B * dout
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
                sizeof(trans_dout)));
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_y,
                sizeof(trans_y)));
      } else {
        // Intended product (per the original note): dx = dout * B
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
                sizeof(trans_dout)));
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y,
                sizeof(trans_y)));
      }

      cublasLtEpilogue_t epiloque_func_for_dx =
          get_epilogue_type_(activation_grad);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
              &epiloque_func_for_dx, sizeof(epiloque_func_for_dx)));

      if (activation_grad != "none") {
        // The activation-gradient epilogues read the auxiliary buffer, so
        // fail with a clear error instead of dereferencing a null tensor.
        PADDLE_ENFORCE_NOT_NULL(
            reserve_space,
            platform::errors::InvalidArgument(
                "Input(ReserveSpace) of fused_gemm_epilogue_grad op must be "
                "provided when activation_grad is not \"none\"."));
        auto* aux_data = reserve_space->data<T>();
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                &aux_data, sizeof(aux_data)));
        int64_t aux_ld = transpose_x ? M : K;
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                &aux_ld, sizeof(aux_ld)));
      }

      memory::allocation::AllocationPtr dx_workspace =
          memory::Alloc(dev_ctx, workspace_size);

      dx->mutable_data<T>(ctx.GetPlace());
      auto* dx_data = dx->data<T>();
      // Use the layout selected above. Passing `dout_desc` unconditionally
      // handed cuBLASLt a null layout whenever transpose_x was set.
      // NOTE(review): Y is operand A and DOut is operand B for both values of
      // transpose_x; only the transpose flags differ. Verify the transposed
      // cases against a reference implementation.
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
          lt_handle, dx_operation_desc, alpha, y->data<T>(), y_desc,
          dout->data<T>(), dout_desc_for_dx, beta, dx_data, dx_desc, dx_data,
          dx_desc, algo, dx_workspace->ptr(), workspace_size, stream));

      // Release the descriptors that are private to the DX matmul.
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatrixLayoutDestroy(dx_desc));
    }

    if (dy) {
      cublasOperation_t trans_dout = transpose_y ? CUBLAS_OP_T : CUBLAS_OP_N;
      cublasOperation_t trans_x =
          (transpose_x ^ transpose_y) ? CUBLAS_OP_N : CUBLAS_OP_T;

      // Reuse a DOut layout created by the DX branch when one exists.
      cublasLtMatrixLayout_t dout_desc_for_dy = nullptr;
      if (trans_dout == CUBLAS_OP_T) {
        if (dout_trans_desc == nullptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutCreate(&dout_trans_desc,
                                                            mat_type, M, N, M));
        }
        dout_desc_for_dy = dout_trans_desc;
      } else {
        if (dout_desc == nullptr) {
          PADDLE_ENFORCE_GPU_SUCCESS(
              platform::dynload::cublasLtMatrixLayoutCreate(&dout_desc,
                                                            mat_type, N, M, N));
        }
        dout_desc_for_dy = dout_desc;
      }

      cublasLtMatmulDesc_t dy_operation_desc = NULL;
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
          &dy_operation_desc, compute_type, scale_type));

      if (transpose_y) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout,
                sizeof(trans_dout)));
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_x,
                sizeof(trans_x)));
      } else {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout,
                sizeof(trans_dout)));
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x,
                sizeof(trans_x)));
      }

      // Without a DBias output a plain matmul is issued; otherwise the
      // bias-gradient epilogue variant is chosen from transpose_y.
      cublasLtEpilogue_t epiloque_func_for_dy =
          dbias == nullptr ? CUBLASLT_EPILOGUE_DEFAULT
                           : (transpose_y ? CUBLASLT_EPILOGUE_BGRADB
                                          : CUBLASLT_EPILOGUE_BGRADA);

      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescSetAttribute(
              dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
              &epiloque_func_for_dy, sizeof(epiloque_func_for_dy)));

      if (dbias) {
        dbias->mutable_data<T>(ctx.GetPlace());
        auto* dbias_data = dbias->data<T>();
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatmulDescSetAttribute(
                dy_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                &dbias_data, sizeof(dbias_data)));
      }

      cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL;
      if (transpose_x) {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, M,
                                                          K, M));
      } else {
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, K,
                                                          M, K));
      }

      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
          &dy_desc, mat_type, N, K, N));

      memory::allocation::AllocationPtr dy_workspace =
          memory::Alloc(dev_ctx, workspace_size);

      dy->mutable_data<T>(ctx.GetPlace());
      auto* dy_data = dy->data<T>();
      // Use the layout selected above. Passing `dout_desc` unconditionally
      // handed cuBLASLt a null layout whenever transpose_y was set.
      // NOTE(review): DOut is operand A and X is operand B for both values of
      // transpose_y; verify the transposed cases against a reference.
      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
          lt_handle, dy_operation_desc, alpha, dout->data<T>(),
          dout_desc_for_dy, x->data<T>(), x_desc, beta, dy_data, dy_desc,
          dy_data, dy_desc, algo, dy_workspace->ptr(), workspace_size, stream));

      // Release the descriptors that are private to the DY matmul.
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatrixLayoutDestroy(dy_desc));
    }

    // Release whichever shared DOut layouts were created above.
    if (dout_desc != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatrixLayoutDestroy(dout_desc));
    }
    if (dout_trans_desc != nullptr) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cublasLtMatrixLayoutDestroy(dout_trans_desc));
    }
  }

 private:
  // Maps the "activation_grad" attribute to the cuBLASLt epilogue used by the
  // DX matmul: "relu_grad" -> DRELU, "gelu_grad" -> DGELU, "none" -> DEFAULT.
  // Raises InvalidArgument for any other value.
  static cublasLtEpilogue_t get_epilogue_type_(
      const std::string& activation_grad) {
    if (activation_grad == "relu_grad") {
      return CUBLASLT_EPILOGUE_DRELU;
    } else if (activation_grad == "gelu_grad") {
      return CUBLASLT_EPILOGUE_DGELU;
    } else if (activation_grad == "none") {
      return CUBLASLT_EPILOGUE_DEFAULT;
    } else {
      PADDLE_ENFORCE_EQ(
          true, false,
          platform::errors::InvalidArgument(
              "The activation_grad attribute of fused_gemm_epilogue_grad op "
              "should be one of {\"none\", \"relu_grad\", \"gelu_grad\"}. "
              "But received activation_grad=%s.",
              activation_grad));
    }
    // Not reached: the check above always fails. Kept so that every path of
    // this non-void function returns a value.
    return CUBLASLT_EPILOGUE_DEFAULT;
  }
};

}  // namespace operators
}  // namespace paddle

// Register the forward (fused_gemm_epilogue) and backward
// (fused_gemm_epilogue_grad) CUDA kernels for float, double and float16.
// The registrations are compiled only when CUDA_VERSION >= 11060.
// NOTE(review): presumably the guard exists because the cuBLASLt epilogues
// used by the kernels above need that toolkit version -- confirm. The kernel
// classes themselves are defined outside this guard.
#if CUDA_VERSION >= 11060
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    fused_gemm_epilogue,
    ops::FusedGemmEpilogueKernel<paddle::platform::CUDADeviceContext, float>,
    ops::FusedGemmEpilogueKernel<paddle::platform::CUDADeviceContext, double>,
    ops::FusedGemmEpilogueKernel<paddle::platform::CUDADeviceContext,
                                 paddle::platform::float16>);

REGISTER_OP_CUDA_KERNEL(
    fused_gemm_epilogue_grad,
    ops::FusedGemmEpilogueGradKernel<paddle::platform::CUDADeviceContext,
                                     float>,
    ops::FusedGemmEpilogueGradKernel<paddle::platform::CUDADeviceContext,
                                     double>,
    ops::FusedGemmEpilogueGradKernel<paddle::platform::CUDADeviceContext,
                                     paddle::platform::float16>);
#endif