// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/linspace_kernel.h"

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

// Writes `size` evenly spaced values between `start` and `stop` into
// out[0, size), with spacing `step`.
//
// Launch: 1-D grid of 1-D blocks, no shared memory. Each thread strides over
// the output by the total number of launched threads, so any grid size covers
// every element. Expects size >= 2 (the host wrapper handles size == 1
// separately).
//
// Elements in the first half are offset forward from `start`. The remaining
// elements are offset backward from `stop`. Thus out[0] and out[size - 1] are
// each computed with a zero offset from their endpoint.
template <typename T>
__global__ void LinspaceKernelInner(
    T start, T stop, double step, int64_t size, T* out) {
  // Widen before multiplying: blockIdx.x * blockDim.x (and the stride below)
  // would otherwise be evaluated in 32-bit unsigned arithmetic and wrap once
  // the output holds more than 2^32 elements.
  int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  const int64_t half = size / 2;

  for (; index < size; index += stride) {
    if (index < half) {
      out[index] = static_cast<T>(start + step * index);
    } else {
      out[index] = static_cast<T>(stop - step * (size - index - 1));
    }
  }
}

// num == 1 case: the output is the single value `start`.
// Intended to be launched with one thread.
template <typename T>
__global__ void LinspaceSpecialKernel(T start, T* out) {
  out[0] = static_cast<T>(start);
}

// Returns the first element of `x` read as type T. If `x` does not live on
// the CPU, it is first copied into a temporary host tensor (the `true`
// argument to Copy requests a blocking copy, so the data is readable on
// return).
// NOTE(review): assumes x.dtype() corresponds to T (GetValueOfExpectedType
// dispatches on dtype) and that x holds at least one element; confirm the
// latter is validated by the caller/op definition.
template <typename T, typename Context>
T GetValue(const Context& ctx, const DenseTensor& x) {
  T value = static_cast<T>(0);
  if (x.place() != CPUPlace()) {
    DenseTensor cpu_x;
    Copy(ctx, x, CPUPlace(), true, &cpu_x);
    value = cpu_x.data<T>()[0];
  } else {
    value = x.data<T>()[0];
  }
  return value;
}

// Reads the first element of `x` using its runtime dtype, then casts it to T.
// Throws Unimplemented for dtypes not listed below.
template <typename T, typename Context>
T GetValueOfExpectedType(const Context& ctx, const DenseTensor& x) {
  switch (x.dtype()) {
    case DataType::FLOAT32:
      return static_cast<T>(GetValue<float, Context>(ctx, x));
    case DataType::FLOAT64:
      return static_cast<T>(GetValue<double, Context>(ctx, x));
    case DataType::INT32:
      return static_cast<T>(GetValue<int32_t, Context>(ctx, x));
    case DataType::INT64:
      return static_cast<T>(GetValue<int64_t, Context>(ctx, x));
    case DataType::FLOAT16:
      return static_cast<T>(GetValue<phi::dtype::float16, Context>(ctx, x));
    case DataType::BFLOAT16:
      return static_cast<T>(GetValue<phi::dtype::bfloat16, Context>(ctx, x));
    case DataType::BOOL:
      return static_cast<T>(GetValue<bool, Context>(ctx, x));
    case DataType::INT16:
      return static_cast<T>(GetValue<int16_t, Context>(ctx, x));
    case DataType::UINT8:
      return static_cast<T>(GetValue<uint8_t, Context>(ctx, x));
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "Data type (%s) is not supported when casting data type.",
          x.dtype()));
  }
}

// GPU linspace: fills `out` (resized to [num]) with `num` evenly spaced values
// from `start` to `stop`, inclusive.
//
// `start`, `stop` and `number` are scalar tensors that may live on any
// backend. Their first elements are read on the host (copied from device if
// needed) and cast to T, T and int64_t respectively. The kernels are enqueued
// on ctx.stream().
// `dtype` is not read here; the output element type is the template
// parameter T.
// Raises InvalidArgument if num <= 0.
template <typename T, typename Context>
void LinspaceKernel(const Context& ctx,
                    const DenseTensor& start,
                    const DenseTensor& stop,
                    const DenseTensor& number,
                    DataType dtype,
                    DenseTensor* out) {
  T start_value = GetValueOfExpectedType<T, Context>(ctx, start);
  T stop_value = GetValueOfExpectedType<T, Context>(ctx, stop);
  int64_t num = GetValueOfExpectedType<int64_t, Context>(ctx, number);

  PADDLE_ENFORCE_GT(
      num,
      0,
      phi::errors::InvalidArgument("The num of linspace op should be larger "
                                   "than 0, but received num is %d",
                                   num));

  out->Resize(phi::make_ddim({num}));
  T* out_data = ctx.template Alloc<T>(out);

  auto stream = ctx.stream();
  if (num != 1) {
    constexpr int kBlockSize = 512;
    // gridDim.x is capped at 2^31 - 1. The kernel strides over the whole
    // output, so clamping here only reduces parallelism, never coverage. It
    // also avoids silently narrowing the 64-bit block count into an int.
    constexpr int64_t kMaxGridSize = 2147483647;
    const int64_t needed_blocks = (num + kBlockSize - 1) / kBlockSize;
    const int grid = static_cast<int>(
        needed_blocks < kMaxGridSize ? needed_blocks : kMaxGridSize);

    // Subtract in double: evaluating `stop_value - start_value` in T can
    // overflow (undefined behavior for signed integers, inf for float).
    const double step = (static_cast<double>(stop_value) -
                         static_cast<double>(start_value)) /
                        static_cast<double>(num - 1);
    LinspaceKernelInner<T><<<grid, kBlockSize, 0, stream>>>(
        start_value, stop_value, step, num, out_data);
  } else {
    LinspaceSpecialKernel<T><<<1, 1, 0, stream>>>(start_value, out_data);
  }
}

}  // namespace phi

PD_REGISTER_KERNEL(linspace,
                   GPU,
                   ALL_LAYOUT,
                   phi::LinspaceKernel,
                   float,
                   int32_t,
                   int64_t,
                   double) {
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}