cast_kernel.cc 3.7 KB
Newer Older
C
chentianyu03 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/cast_kernel.h"
C
chentianyu03 已提交
16

17
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18 19 20 21
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
C
chentianyu03 已提交
22

23 24
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
C
chentianyu03 已提交
25

26
namespace phi {
C
chentianyu03 已提交
27 28 29 30 31 32 33

template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
                const DenseTensor& x,
                DataType out_dtype,
                DenseTensor* out) {
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
34
  using float16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
C
chentianyu03 已提交
35 36 37 38 39 40

  auto* in_data = x.data<T>();
  auto numel = x.numel();

  int r = -1;
  switch (out_dtype) {
41
    case phi::DataType::FLOAT32:
42
      r = xpu::cast<XPUInTDType, float>(
C
chentianyu03 已提交
43 44
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
45
          dev_ctx.template Alloc<float>(out),
C
chentianyu03 已提交
46 47
          numel);
      break;
48
    case phi::DataType::FLOAT16:
49
      r = xpu::cast<XPUInTDType, float16>(
C
chentianyu03 已提交
50 51 52
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          reinterpret_cast<float16*>(
53
              dev_ctx.template Alloc<phi::dtype::float16>(out)),
C
chentianyu03 已提交
54 55
          numel);
      break;
56
    case phi::DataType::INT64:
57
      r = xpu::cast<XPUInTDType, int64_t>(
C
chentianyu03 已提交
58 59
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
60
          dev_ctx.template Alloc<int64_t>(out),
C
chentianyu03 已提交
61 62
          numel);
      break;
63
    case phi::DataType::INT32:
64
      r = xpu::cast<XPUInTDType, int32_t>(
C
chentianyu03 已提交
65 66
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
67
          dev_ctx.template Alloc<int>(out),
C
chentianyu03 已提交
68 69
          numel);
      break;
70
    case phi::DataType::BOOL:
71
      r = xpu::cast<XPUInTDType, bool>(
C
chentianyu03 已提交
72 73
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
74 75 76 77
          dev_ctx.template Alloc<bool>(out),
          numel);
      break;
    case phi::DataType::UINT8:
78
      r = xpu::cast<XPUInTDType, uint8_t>(
79 80 81
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          dev_ctx.template Alloc<uint8_t>(out),
C
chentianyu03 已提交
82 83
          numel);
      break;
84
    case phi::DataType::FLOAT64:
85
      r = xpu::cast<XPUInTDType, double>(
86 87 88 89 90
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          dev_ctx.template Alloc<double>(out),
          numel);
      break;
C
chentianyu03 已提交
91
    default:
92
      PADDLE_THROW(phi::errors::Unavailable(
C
chentianyu03 已提交
93 94 95
          "Not supported cast %d -> %d", x.dtype(), out_dtype));
  }

96
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
C
chentianyu03 已提交
97
}
98
}  // namespace phi
C
chentianyu03 已提交
99

100
// Registers phi::CastKernel as the XPU `cast` kernel for each listed input
// element type. The output element type does not follow the input type: it
// is chosen at runtime from the `out_dtype` argument. It is therefore left
// UNDEFINED here rather than fixed at registration time.
PD_REGISTER_KERNEL(cast,
                   XPU,
                   ALL_LAYOUT,
                   phi::CastKernel,
                   int32_t,
                   float,
                   phi::dtype::float16,
                   int64_t,
                   bool,
                   uint8_t,
                   double) {
  kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}