full_kernel.cc 5.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/full_kernel.h"
16

17
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18 19 20 21
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
22
#include "paddle/phi/common/memory_utils.h"
23 24
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
25
#include "paddle/phi/core/visit_type.h"
26

27
namespace phi {
28 29 30

template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
31
                const IntArray& shape,
32
                const Scalar& val,
33
                DataType dtype,
34
                DenseTensor* out) {
35
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
36
  out->Resize(phi::make_ddim(shape.GetData()));
37 38 39 40 41 42 43
  int numel = out->numel();
  dev_ctx.template Alloc<T>(out);
  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
  if (numel > 0) {
    int r = xpu::constant(dev_ctx.x_context(),
                          out_data,
                          out->numel(),
44
                          static_cast<XPUInTDType>(val.to<T>()));
45 46
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  }
47 48 49 50
}

template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
51
                    const DenseTensor& x,
52
                    const Scalar& val,
53
                    DataType dtype,
54
                    DenseTensor* out) {
55
  dev_ctx.template Alloc<T>(out);
56
  auto value = val.to<double>();
57 58 59
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
  using CommonType = typename std::common_type<
      float,
60
      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
61 62 63 64 65 66 67 68 69 70 71
                                float,
                                T>::type>::type;

  auto common_type_value = static_cast<CommonType>(value);

  PADDLE_ENFORCE_EQ(
      (common_type_value >=
       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
          (common_type_value <=
           static_cast<CommonType>(std::numeric_limits<T>::max())),
      true,
72
      phi::errors::InvalidArgument(
73 74 75 76 77 78 79 80 81 82
          "The filled value is out of range for target type, "
          "current kernel type is %s, the range should between %f "
          "and %f, but now value is %f.",
          typeid(T).name(),
          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
          static_cast<CommonType>(std::numeric_limits<T>::max()),
          static_cast<float>(value)));

  PADDLE_ENFORCE_EQ(std::isnan(value),
                    false,
83
                    phi::errors::InvalidArgument("The filled value is NaN."));
84 85
  PADDLE_ENFORCE_EQ(std::isinf(value),
                    false,
86
                    phi::errors::InvalidArgument("The filled value is Inf."));
87 88

  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
89 90 91 92 93 94 95
  if (out->numel() > 0) {
    int r = xpu::constant(dev_ctx.x_context(),
                          out_data,
                          out->numel(),
                          static_cast<XPUInTDType>(value));
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  }
96 97
}

98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
template <typename T, typename Context>
void FullBatchSizeLikeKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const std::vector<int>& shape,
                             const Scalar& val,
                             DataType dtype,
                             int x_batch_size_dim,
                             int out_batch_size_dim,
                             DenseTensor* out) {
  if (x.lod().size() && x_batch_size_dim == 0) {
    // set the correct batch size for the LoDTensor.
    auto odims = out->dims();
    odims[out_batch_size_dim] = static_cast<int>(x.lod().back().size()) - 1;
    FullKernel<T, Context>(dev_ctx, phi::vectorize(odims), val, dtype, out);
  }
  FullLikeKernel<T, Context>(dev_ctx, x, val, dtype, out);
}
115
}  // namespace phi
116

117
// Registers phi::FullKernel under the kernel name `full` for the XPU backend
// and ALL_LAYOUT, instantiated for each element type listed below. The
// trailing body is empty: no extra per-kernel configuration is applied.
PD_REGISTER_KERNEL(full,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullKernel,
                   float,
                   int8_t,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}
129

130
// Registers phi::FullLikeKernel under the kernel name `full_like` for the XPU
// backend and ALL_LAYOUT, instantiated for each element type listed below.
PD_REGISTER_KERNEL(full_like,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullLikeKernel,
                   float,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {
  // Input 0 (`x`) is marked as acceptable from any backend. FullLikeKernel
  // in this file never reads `x`, so presumably its placement does not
  // matter — confirm against the kernel dispatch rules.
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
143 144 145 146 147 148 149 150 151 152 153 154

// Registers phi::FullBatchSizeLikeKernel under the kernel name
// `full_batch_size_like` for the XPU backend and ALL_LAYOUT, instantiated for
// each element type listed below.
PD_REGISTER_KERNEL(full_batch_size_like,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullBatchSizeLikeKernel,
                   float,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {
  // Input 0 (`x`) is marked as acceptable from any backend. The kernel in
  // this file only queries x.lod(); NOTE(review): presumably that metadata is
  // available regardless of where x's data lives — confirm.
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}