linspace_kernel.cu 4.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/phi/kernels/linspace_kernel.h"

17
#include "paddle/phi/backends/gpu/gpu_context.h"
W
Wang Xin 已提交
18
#include "paddle/phi/backends/gpu/gpu_primitives.h"
19
#include "paddle/phi/core/kernel_registry.h"
20
#include "paddle/phi/core/tensor_utils.h"
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

// Fills out[0 .. size) with evenly spaced values between `start` and `stop`.
//
// Elements with index < size / 2 are computed forward from `start`
// (start + step * index); the remaining elements are computed backward from
// `stop` (stop - step * (size - index - 1)).
//
// Each thread begins at its global 1-D index and advances by the total number
// of launched threads until it passes `size`, so the kernel covers the whole
// output for any 1-D launch configuration.
//
// The global index and the stride are widened to int64_t *before* the
// multiplication: blockIdx / blockDim / gridDim components are 32-bit
// unsigned, so multiplying them directly can wrap around for very large
// launches (a stride that wraps to 0 would make the loop spin forever).
template <typename T>
__global__ void LinspaceKernelInner(
    T start, T stop, double step, int64_t size, T* out) {
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t half = size / 2;  // loop-invariant split point
  int64_t index = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (; index < size; index += stride) {
    if (index < half) {
      out[index] = static_cast<T>(start + step * index);
    } else {
      out[index] = static_cast<T>(stop - step * (size - index - 1));
    }
  }
}

// Single-element case: stores `start` into the first slot of `out`.
// There is no thread-index guard, so every launched thread performs the same
// store.
template <typename T>
__global__ void LinspaceSpecialKernel(T start, T* out) {
  *out = start;
}

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
// Returns the first element of `x`, read as type T.
//
// If `x` is not placed on the CPU, it is first copied into a temporary CPU
// tensor (the `true` argument to Copy is presumably the "blocking" flag, so
// the data should be ready when it is read -- confirm against Copy's
// declaration). Otherwise the element is read from `x` directly.
template <typename T, typename Context>
T GetValue(const Context& ctx, const DenseTensor& x) {
  if (x.place() != CPUPlace()) {
    DenseTensor host_copy;
    Copy(ctx, x, CPUPlace(), true, &host_copy);
    return host_copy.data<T>()[0];
  }
  return x.data<T>()[0];
}

// Reads the first element of `x` using the C++ type that matches x.dtype()
// and converts it to T.
//
// Supported source dtypes: FLOAT32, FLOAT64, INT32, INT64, FLOAT16, BFLOAT16,
// BOOL, INT16 and UINT8. Any other dtype raises an Unimplemented error.
template <typename T, typename Context>
T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) {
// One switch arm per supported dtype: fetch the scalar as its stored C++ type,
// then cast it to the requested type T.
#define LINSPACE_CAST_CASE(dtype_enum, cpp_type) \
  case dtype_enum:                               \
    return static_cast<T>(GetValue<cpp_type, Context>(ctx, x))

  switch (x.dtype()) {
    LINSPACE_CAST_CASE(DataType::FLOAT32, float);
    LINSPACE_CAST_CASE(DataType::FLOAT64, double);
    LINSPACE_CAST_CASE(DataType::INT32, int32_t);
    LINSPACE_CAST_CASE(DataType::INT64, int64_t);
    LINSPACE_CAST_CASE(DataType::FLOAT16, phi::dtype::float16);
    LINSPACE_CAST_CASE(DataType::BFLOAT16, phi::dtype::bfloat16);
    LINSPACE_CAST_CASE(DataType::BOOL, bool);
    LINSPACE_CAST_CASE(DataType::INT16, int16_t);
    LINSPACE_CAST_CASE(DataType::UINT8, uint8_t);
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          x.dtype()));
  }

#undef LINSPACE_CAST_CASE
}

85 86 87 88 89 90 91
// Fills `out` with `number` evenly spaced values running from `start` to
// `stop`.
//
// ctx    : device context; supplies the allocator and the stream the kernels
//          are launched on.
// start  : tensor whose first element is the first value of the sequence.
// stop   : tensor whose first element is the last value of the sequence.
// number : tensor whose first element is the element count; must be > 0,
//          otherwise an InvalidArgument error is raised.
// dtype  : not read in this function; the element type is the template
//          parameter T.
// out    : resized to a 1-D tensor of `number` elements and filled on the
//          device.
//
// The three scalar inputs are fetched to the host through
// GetValueOfExpectedType before the launch.
template <typename T, typename Context>
void LinspaceKernel(const Context& ctx,
                    const DenseTensor& start,
                    const DenseTensor& stop,
                    const DenseTensor& number,
                    DataType dtype,
                    DenseTensor* out) {
  T start_value = GetValueOfExpectedType<T, Context>(ctx, start);
  T stop_value = GetValueOfExpectedType<T, Context>(ctx, stop);
  int64_t num = GetValueOfExpectedType<int64_t, Context>(ctx, number);

  PADDLE_ENFORCE_GT(
      num,
      0,
      phi::errors::InvalidArgument("The num of linspace op should be larger "
                                   "than 0, but received num is %d",
                                   num));

  out->Resize(phi::make_ddim({num}));
  T* out_data = ctx.template Alloc<T>(out);

  auto stream = ctx.stream();
  if (num != 1) {
    constexpr int kBlockSize = 512;
    // Upper bound on the grid size chosen so that grid * block stays below
    // 2^31. LinspaceKernelInner loops with a stride of the total thread
    // count, so a clamped grid still writes every element.
    constexpr int64_t kMaxGridSize =
        ((static_cast<int64_t>(1) << 31) - 1) / kBlockSize;

    // Ceil-div in 64-bit (num >= 2 here), then clamp before narrowing to int
    // so a very large `num` cannot wrap the grid size.
    const int64_t blocks_needed = (num - 1) / kBlockSize + 1;
    const int grid = static_cast<int>(
        blocks_needed < kMaxGridSize ? blocks_needed : kMaxGridSize);

    // Spacing between consecutive elements; num - 1 intervals span the range.
    double step = (static_cast<double>(stop_value - start_value)) / (num - 1);
    LinspaceKernelInner<T><<<grid, kBlockSize, 0, stream>>>(
        start_value, stop_value, step, num, out_data);
  } else {
    // A single element: the output is just `start`.
    LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(start_value, out_data);
  }
}

}  // namespace phi

// Registers phi::LinspaceKernel as the GPU `linspace` kernel for float,
// int32_t, int64_t and double, for all layouts.
PD_REGISTER_KERNEL(linspace,
                   GPU,
                   ALL_LAYOUT,
                   phi::LinspaceKernel,
                   float,
                   int32_t,
                   int64_t,
                   double) {
  // Inputs 0, 1 and 2 (presumably start, stop and number, following the
  // kernel's parameter order) are accepted from any backend.
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}