// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/full_kernel.h"

#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>
#include <typeinfo>

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/visit_type.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/phi/common/memory_utils.h"

namespace phi {

template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
33
                const IntArray& shape,
34
                const Scalar& val,
35
                DataType dtype,
36
                DenseTensor* out) {
37
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
38
  out->Resize(phi::make_ddim(shape.GetData()));
39 40 41 42 43 44 45
  int numel = out->numel();
  dev_ctx.template Alloc<T>(out);
  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
  if (numel > 0) {
    int r = xpu::constant(dev_ctx.x_context(),
                          out_data,
                          out->numel(),
46
                          static_cast<XPUInTDType>(val.to<T>()));
47 48
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  }
49 50 51 52
}

template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
53
                    const DenseTensor& x,
54
                    const Scalar& val,
55
                    DataType dtype,
56
                    DenseTensor* out) {
57
  dev_ctx.template Alloc<T>(out);
58
  auto value = val.to<double>();
59 60 61
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
  using CommonType = typename std::common_type<
      float,
62
      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
63 64 65 66 67 68 69 70 71 72 73
                                float,
                                T>::type>::type;

  auto common_type_value = static_cast<CommonType>(value);

  PADDLE_ENFORCE_EQ(
      (common_type_value >=
       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
          (common_type_value <=
           static_cast<CommonType>(std::numeric_limits<T>::max())),
      true,
74
      phi::errors::InvalidArgument(
75 76 77 78 79 80 81 82 83 84
          "The filled value is out of range for target type, "
          "current kernel type is %s, the range should between %f "
          "and %f, but now value is %f.",
          typeid(T).name(),
          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
          static_cast<CommonType>(std::numeric_limits<T>::max()),
          static_cast<float>(value)));

  PADDLE_ENFORCE_EQ(std::isnan(value),
                    false,
85
                    phi::errors::InvalidArgument("The filled value is NaN."));
86 87
  PADDLE_ENFORCE_EQ(std::isinf(value),
                    false,
88
                    phi::errors::InvalidArgument("The filled value is Inf."));
89 90

  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
91 92 93 94 95
  int r = xpu::constant(dev_ctx.x_context(),
                        out_data,
                        out->numel(),
                        static_cast<XPUInTDType>(value));
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
96 97
}

}  // namespace phi

// Registers phi::FullKernel as the XPU `full` kernel (any layout) for each
// element type listed below.
PD_REGISTER_KERNEL(full,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullKernel,
                   float,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {}

// Registers phi::FullLikeKernel as the XPU `full_like` kernel (any layout)
// for each element type listed below.
PD_REGISTER_KERNEL(full_like,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullLikeKernel,
                   float,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16) {
  // FullLikeKernel never reads the data of its input `x`, so input 0 is not
  // tied to a specific backend. Presumably this avoids a needless copy of `x`
  // to the XPU.
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}