full_kernel.cc 5.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/phi/kernels/full_kernel.h"
16

17 18 19 20 21 22 23
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/kernel_registry.h"
24

25 26 27
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"

28
namespace phi {
29 30

template <typename InType, typename OutType>
31
void TensorSetConstantXPU(phi::DenseTensor* tensor,
32
                          InType value,
33
                          phi::Place place) {
34 35 36 37 38 39 40
  auto* begin = tensor->mutable_data<OutType>(place);
  int64_t numel = tensor->numel();
  std::unique_ptr<OutType[]> data_cpu(new OutType[numel]);
  std::fill(
      data_cpu.get(), data_cpu.get() + numel, static_cast<OutType>(value));
  paddle::memory::Copy(place,
                       begin,
41
                       phi::CPUPlace(),
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
                       static_cast<void*>(data_cpu.get()),
                       numel * sizeof(OutType));
}

template <typename T, typename Context, typename VType>
void FullValueXPU(const Context& dev_ctx, DenseTensor* tensor, VType val) {
  tensor->mutable_data<T>(dev_ctx.GetPlace());

  PD_VISIT_ALL_TYPES(tensor->dtype(), "FullValueXPU", ([&] {
                       TensorSetConstantXPU<VType, data_t>(
                           tensor, val, dev_ctx.GetPlace());
                     }));
}

template <typename T, typename Context>
void FullKernel(const Context& dev_ctx,
                const ScalarArray& shape,
                const Scalar& val,
60
                DataType dtype,
61
                DenseTensor* out) {
62
  out->ResizeAndAllocate(phi::make_ddim(shape.GetData()));
63 64 65 66 67
  FullValueXPU<T>(dev_ctx, out, val.to<T>());
}

template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
68
                    const DenseTensor& x,
69
                    const Scalar& val,
70
                    DataType dtype,
71 72 73 74 75
                    DenseTensor* out) {
  auto value = val.to<float>();
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
  using CommonType = typename std::common_type<
      float,
76
      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
77 78 79 80 81 82 83 84 85 86 87
                                float,
                                T>::type>::type;

  auto common_type_value = static_cast<CommonType>(value);

  PADDLE_ENFORCE_EQ(
      (common_type_value >=
       static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
          (common_type_value <=
           static_cast<CommonType>(std::numeric_limits<T>::max())),
      true,
88
      phi::errors::InvalidArgument(
89 90 91 92 93 94 95 96 97 98
          "The filled value is out of range for target type, "
          "current kernel type is %s, the range should between %f "
          "and %f, but now value is %f.",
          typeid(T).name(),
          static_cast<CommonType>(std::numeric_limits<T>::lowest()),
          static_cast<CommonType>(std::numeric_limits<T>::max()),
          static_cast<float>(value)));

  PADDLE_ENFORCE_EQ(std::isnan(value),
                    false,
99
                    phi::errors::InvalidArgument("The filled value is NaN."));
100 101
  PADDLE_ENFORCE_EQ(std::isinf(value),
                    false,
102
                    phi::errors::InvalidArgument("The filled value is Inf."));
103 104 105 106 107 108 109 110 111

  auto out_data = reinterpret_cast<XPUInTDType*>(out->data<T>());
  int ret = xpu::constant(dev_ctx.x_context(),
                          out_data,
                          out->numel(),
                          static_cast<XPUInTDType>(value));
  PADDLE_ENFORCE_EQ(
      ret,
      XPU_SUCCESS,
112 113 114
      phi::errors::External("XPU CONSTANT API return wrong value[%d %s].",
                            ret,
                            XPUAPIErrorMsg[ret]));
115 116
}

117
}  // namespace phi
118 119 120 121

// Registers the XPU `full` kernel (phi::FullKernel) for every element type
// listed below.
PT_REGISTER_KERNEL(full,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullKernel,
                   float,
                   double,
                   uint8_t,
                   int16_t,
                   int,
                   int64_t,
                   bool,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}

// Registers the XPU `full_like` kernel (phi::FullLikeKernel).
// NOTE(review): this type list is narrower than `full`'s, presumably limited
// to the types XPUTypeTrait / xpu::constant support — confirm before
// extending it.
PT_REGISTER_KERNEL(full_like,
                   XPU,
                   ALL_LAYOUT,
                   phi::FullLikeKernel,
                   float,
                   int,
                   int64_t,
                   phi::dtype::float16) {}