full_kernel.cc 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/full_kernel.h"
16

17
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18 19 20 21 22 23
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
24
#include "paddle/phi/core/visit_type.h"
25

26 27 28
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"

29
namespace phi {
30 31

template <typename InType, typename OutType>
32
void TensorSetConstantXPU(phi::DenseTensor* tensor,
33
                          InType value,
34
                          phi::Place place) {
35 36 37 38 39 40 41
  auto* begin = tensor->mutable_data<OutType>(place);
  int64_t numel = tensor->numel();
  std::unique_ptr<OutType[]> data_cpu(new OutType[numel]);
  std::fill(
      data_cpu.get(), data_cpu.get() + numel, static_cast<OutType>(value));
  paddle::memory::Copy(place,
                       begin,
42
                       phi::CPUPlace(),
43 44 45 46 47 48
                       static_cast<void*>(data_cpu.get()),
                       numel * sizeof(OutType));
}

template <typename T, typename Context, typename VType>
void FullValueXPU(const Context& dev_ctx, DenseTensor* tensor, VType val) {
49
  dev_ctx.template Alloc<T>(tensor);
50 51 52 53 54 55 56 57 58

  PD_VISIT_ALL_TYPES(tensor->dtype(), "FullValueXPU", ([&] {
                       TensorSetConstantXPU<VType, data_t>(
                           tensor, val, dev_ctx.GetPlace());
                     }));
}

template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
59
                const IntArray& shape,
60
                const Scalar& val,
61
                DataType dtype,
62
                DenseTensor* out) {
63
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
64
  out->Resize(phi::make_ddim(shape.GetData()));
65 66 67 68 69 70 71 72 73 74 75
  int numel = out->numel();
  dev_ctx.template Alloc<T>(out);
  auto value = val.to<double>();
  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
  if (numel > 0) {
    int r = xpu::constant(dev_ctx.x_context(),
                          out_data,
                          out->numel(),
                          static_cast<XPUInTDType>(value));
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
  }
76 77 78 79
}

template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
80
                    const DenseTensor& x,
81
                    const Scalar& val,
82
                    DataType dtype,
83
                    DenseTensor* out) {
84
  dev_ctx.template Alloc<T>(out);
85
  auto value = val.to<double>();
86 87 88
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
  using CommonType = typename std::common_type<
      float,
89
      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
90 91 92 93 94 95 96 97 98 99 100
                                float,
                                T>::type>::type;

  auto common_type_value = static_cast<CommonType>(value);

  PADDLE_ENFORCE_EQ(
      (common_type_value >=
       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
          (common_type_value <=
           static_cast<CommonType>(std::numeric_limits<T>::max())),
      true,
101
      phi::errors::InvalidArgument(
102 103 104 105 106 107 108 109 110 111
          "The filled value is out of range for target type, "
          "current kernel type is %s, the range should between %f "
          "and %f, but now value is %f.",
          typeid(T).name(),
          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
          static_cast<CommonType>(std::numeric_limits<T>::max()),
          static_cast<float>(value)));

  PADDLE_ENFORCE_EQ(std::isnan(value),
                    false,
112
                    phi::errors::InvalidArgument("The filled value is NaN."));
113 114
  PADDLE_ENFORCE_EQ(std::isinf(value),
                    false,
115
                    phi::errors::InvalidArgument("The filled value is Inf."));
116 117

  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
118 119 120 121 122
  int r = xpu::constant(dev_ctx.x_context(),
                        out_data,
                        out->numel(),
                        static_cast<XPUInTDType>(value));
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
123 124
}

125
}  // namespace phi
126

127
// Kernel registration for `full`: phi::FullKernel on the XPU backend with
// ALL_LAYOUT, instantiated for the element types listed below. The trailing
// braces are the (empty) per-kernel customisation body.
PD_REGISTER_KERNEL(full,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullKernel,
                   float, uint8_t, int16_t, int, int64_t, bool,
                   phi::dtype::float16) {}
138

139
// Kernel registration for `full_like`: phi::FullLikeKernel on the XPU backend
// with ALL_LAYOUT, instantiated for the element types listed below.
PD_REGISTER_KERNEL(full_like,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullLikeKernel,
                   float, uint8_t, int16_t, int, int64_t, bool,
                   phi::dtype::float16) {
  // Input 0 is registered with Backend::ALL_BACKEND. NOTE(review): presumably
  // so the reference tensor may live on any backend, since FullLikeKernel in
  // this file never reads its data -- confirm against the registry docs.
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}