// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <vector>

#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/array.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/nonzero_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/split_kernel.h"

#if defined(__NVCC__) || defined(__HIPCC__)
#ifdef __NVCC__
#include <cuda.h>
#include <cuda_runtime.h>
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#endif
#endif

namespace phi {
namespace funcs {

// Reshape `tensor` to an intermediate shape that is broadcast-compatible with
// `res_dim` (aligned to the broadcast shape `bd_dim` when index == 0, or
// placed at dimension `index` otherwise), then expand it to `res_dim`.
template <typename T, typename Context>
phi::DenseTensor GetReshapeAndExpandTensor(const Context& dev_ctx,
                                           const phi::DenseTensor& tensor,
                                           const phi::DDim& res_dim,
                                           const phi::DDim& bd_dim,
                                           int index) {
  std::vector<int64_t> before_dims = phi::vectorize(tensor.dims());
  std::vector<int64_t> mid_dims(res_dim.size(), 1);

  if (index == 0) {
    for (size_t i = 0; i < before_dims.size(); ++i) {
      mid_dims[bd_dim.size() - i - 1] = before_dims[before_dims.size() - i - 1];
    }
  } else {
    mid_dims[index] = before_dims[0];
  }

  phi::DenseTensor mid_tensor(tensor.dtype());
  mid_tensor.Resize(phi::make_ddim(mid_dims));
  ReshapeInferKernel<Context>(dev_ctx, tensor, IntArray(mid_dims), &mid_tensor);

  phi::DenseTensor res_tensor(tensor.dtype());
  res_tensor.Resize(res_dim);
  ExpandKernel<T, Context>(
      dev_ctx, mid_tensor, IntArray(phi::vectorize(res_dim)), &res_tensor);
  return res_tensor;
}

// If `indices_v` consists of a single bool mask, convert it into one int64
// coordinate tensor per dimension of the mask (NonZero + split); int32/int64
// indices are passed through unchanged. Mixing bool and integer indices is
// rejected.
template <typename T, typename Context>
std::vector<const phi::DenseTensor*> DealWithBoolIndices(
    const Context& dev_ctx,
    const std::vector<const phi::DenseTensor*>& indices_v,
    std::vector<phi::DenseTensor>* tmp_indices_v) {
  std::vector<const phi::DenseTensor*> res(indices_v.begin(), indices_v.end());
  bool contains_bool_tensor = false;
  for (size_t i = 0; i < indices_v.size(); ++i) {
    if (indices_v[i]->dtype() == phi::DataType::BOOL) {
      contains_bool_tensor = true;
    } else if ((indices_v[i]->dtype() == phi::DataType::INT64) ||
               (indices_v[i]->dtype() == phi::DataType::INT32)) {
      PADDLE_ENFORCE_EQ(
          contains_bool_tensor,
          false,
          phi::errors::InvalidArgument(
              "indices contains a bool tensor and an int32/int64 tensor at "
              "the same time"));
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "data type of tensor in indices must be int32, int64 or bool"));
    }
  }

  if (contains_bool_tensor) {
    if (indices_v.size() != 1) {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "the size of indices must be 1 when it contains a bool tensor"));
    }
    int rank = indices_v[0]->dims().size();
    PADDLE_ENFORCE_GE(
        rank,
        1UL,
        phi::errors::InvalidArgument("the only bool tensor in indices should "
                                     "have at least one dimension"));

    phi::DenseTensor nonzero_indices(phi::DataType::INT64);
    nonzero_indices.Resize(phi::make_ddim({-1, rank}));
    NonZeroKernel<bool, Context>(dev_ctx, *indices_v[0], &nonzero_indices);

    std::vector<phi::DenseTensor*> integer_indices(rank, nullptr);
    for (int i = 0; i < rank; ++i) {
      tmp_indices_v->emplace_back(
          DenseTensor(phi::DataType::INT64)
              .Resize(phi::make_ddim({nonzero_indices.dims()[0]})));
    }
    for (int i = 0; i < rank; ++i) {
      integer_indices[i] = &((*tmp_indices_v)[i]);
    }
    SplitWithNumKernel<int64_t, Context>(
        dev_ctx, nonzero_indices, rank, 1, integer_indices);

    std::vector<const phi::DenseTensor*> res_tmp(integer_indices.size(),
                                                 nullptr);
    for (int i = 0; i < rank; ++i) {
      res_tmp[i] = &((*tmp_indices_v)[i]);
    }
    res.swap(res_tmp);
  }
  return res;
}

// Compute the right-aligned broadcast shape of all index tensors, enforcing
// standard broadcast semantics along every axis.
static phi::DDim BroadCastTensorsDims(
    const std::vector<const phi::DenseTensor*>& tensors) {
  int target_rank = 0;
  for (const auto& tensor : tensors) {
    target_rank = std::max(target_rank, tensor->dims().size());
  }

  PADDLE_ENFORCE_GT(target_rank,
                    0,
                    errors::InvalidArgument("BroadCastTensorsDims requires at "
                                            "least one input tensor to have "
                                            "rank greater than zero"));

  std::vector<int64_t> target_dims(target_rank, 0);
  for (int index = 0; index < target_rank; index++) {
    int target_dim_size = 1;
    for (const auto& tensor : tensors) {
      auto input_ddim = tensor->dims();
      int axis = static_cast<int>(input_ddim.size()) - index - 1;
      int dim_size = 1;
      if (axis >= 0) {
        dim_size = input_ddim[axis];
      }
      if (target_dim_size != 1 && dim_size != 1 &&
          target_dim_size != dim_size) {
        PADDLE_THROW(errors::InvalidArgument(
            "BroadCastTensorsDims inputs do not satisfy bcast semantics, "
            "please check axis = %d in reverse order",
            index));
      }
      target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
    }
    target_dims[target_rank - index - 1] = target_dim_size;
  }
  return phi::make_ddim(target_dims);
}

// Copy the host-side array of per-tensor data pointers to device memory and
// return the resulting device pointer array.
template <typename T, typename Context>
T** GetDevicePointerArray(
    const Context& ctx, const std::vector<const phi::DenseTensor*>& indices_v) {
  std::vector<const T*> h_indices_v(indices_v.size());
  for (size_t i = 0; i < indices_v.size(); ++i) {
    h_indices_v[i] = indices_v[i]->data<T>();
  }
  auto d_indices_data = phi::memory_utils::Alloc(
      ctx.GetPlace(),
      h_indices_v.size() * sizeof(T*),
      phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
  phi::memory_utils::Copy(ctx.GetPlace(),
                          d_indices_data->ptr(),
                          phi::CPUPlace(),
                          reinterpret_cast<void*>(h_indices_v.data()),
                          h_indices_v.size() * sizeof(T*),
                          ctx.stream());
  return reinterpret_cast<T**>(d_indices_data->ptr());
}

// Normalize index tensors for advanced indexing: cast int32 indices to int64,
// append the trailing range tensors when fewer indices than x's dimensions
// are given, and expand every index tensor to the common broadcast shape.
template <typename T, typename Context>
void DealWithIndices(const Context& dev_ctx,
                     const DenseTensor& x,
                     const std::vector<const phi::DenseTensor*>& int_indices_v,
                     std::vector<const phi::DenseTensor*>* res_indices_v,
                     std::vector<DenseTensor>* tmp_res_indices_v,
                     const std::vector<DenseTensor>& range_tensor_v,
                     const phi::DDim& bd_dim,
                     std::vector<int64_t>* res_dim_v) {
  size_t total_dims = x.dims().size();
  if (int_indices_v.size() < total_dims) {
    std::vector<int64_t> tmp_x_dims = phi::vectorize(x.dims());
    int len_bd_dim = bd_dim.size();
    res_dim_v->insert(res_dim_v->end(),
                      tmp_x_dims.begin() + int_indices_v.size(),
                      tmp_x_dims.end());

    std::vector<DenseTensor> reshaped_indices_v;
    for (size_t i = 0; i < int_indices_v.size(); ++i) {
      if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
        reshaped_indices_v.emplace_back(phi::Cast<int, Context>(
            dev_ctx, *int_indices_v[i], phi::DataType::INT64));
      } else {
        reshaped_indices_v.emplace_back(*int_indices_v[i]);
      }
    }
    reshaped_indices_v.insert(
        reshaped_indices_v.end(), range_tensor_v.begin(), range_tensor_v.end());

    phi::DDim res_dim = phi::make_ddim(*res_dim_v);

    for (size_t i = 0; i < reshaped_indices_v.size(); ++i) {
      tmp_res_indices_v->emplace_back(
          GetReshapeAndExpandTensor<int64_t, Context>(
              dev_ctx,
              reshaped_indices_v[i],
              res_dim,
              bd_dim,
              ((i < int_indices_v.size())
                   ? 0
                   : i - int_indices_v.size() + len_bd_dim)));
    }
    for (size_t i = 0; i < res_indices_v->size(); ++i) {
      (*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
    }
  } else {
    std::vector<DenseTensor> int_indices_v_tmp;

    for (size_t i = 0; i < int_indices_v.size(); ++i) {
      if (int_indices_v[i]->dtype() == phi::DataType::INT32) {
        int_indices_v_tmp.emplace_back(phi::Cast<int, Context>(
            dev_ctx, *int_indices_v[i], phi::DataType::INT64));
      } else {
        int_indices_v_tmp.emplace_back(*int_indices_v[i]);
      }
    }

    for (size_t i = 0; i < int_indices_v.size(); ++i) {
      if (bd_dim != int_indices_v[i]->dims()) {
        tmp_res_indices_v->emplace_back(
            DenseTensor(phi::DataType::INT64).Resize(bd_dim));
        ExpandKernel<int64_t, Context>(
            dev_ctx,
            int_indices_v_tmp[i],
            IntArray(phi::vectorize(bd_dim)),
            &(*tmp_res_indices_v)[i]);
      } else {
        tmp_res_indices_v->emplace_back(int_indices_v_tmp[i]);
      }
    }

    for (size_t i = 0; i < res_indices_v->size(); ++i) {
      (*res_indices_v)[i] = &(*tmp_res_indices_v)[i];
    }
  }
}

// Compare the shape of x[indices] (`after_dims`) with the shape of the value
// tensor (`before_dims`), walking both right-to-left: axes where the shapes
// match contribute their sizes to `dims_without_1`, while axes of `after_dims`
// that the value reaches only via broadcasting are recorded in `compress_dims`.
static void CalCompressedDimsWith1AndWithout1(
    std::vector<int64_t>* after_dims,
    std::vector<int64_t>* before_dims,
    std::vector<int64_t>* compress_dims,
    std::vector<int64_t>* dims_without_1) {
  int i = static_cast<int>(after_dims->size()) - 1;
  int j = static_cast<int>(before_dims->size()) - 1;
  if (i < j) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "shape of value cannot be broadcast to shape of x[indices]"));
  }

  while ((i >= 0) && (j >= 0)) {
    if ((*after_dims)[i] == (*before_dims)[j]) {
      dims_without_1->push_back((*before_dims)[j]);
      i--;
      j--;
      continue;
    } else if ((*before_dims)[j] == 1) {
      compress_dims->push_back(i);
      i--;
      j--;
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "shape of value cannot be broadcast to shape of x[indices]"));
    }
  }
  while (i >= 0) {
    compress_dims->push_back(i);
    i--;
  }
}

#if defined(__NVCC__) || defined(__HIPCC__)
// Fill `out` with 0, 1, ..., N - 1 on the GPU.
template <typename T>
__global__ void range_cuda_kernel(int64_t N, T* out) {
  int64_t idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx >= N) {
    return;
  }
  out[idx] = idx;
}

// Build a 1-D tensor of length N holding 0..N-1 on the GPU.
template <typename T, typename Context>
phi::DenseTensor GetRangeCudaTensor(const Context& dev_ctx,
                                    int64_t N,
                                    phi::DataType dtype) {
  phi::DenseTensor res(dtype);
  res.Resize(phi::make_ddim({N}));
  DenseTensor* p_res = &res;
  T* out = dev_ctx.template Alloc<T>(p_res);
  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, N);
  range_cuda_kernel<T><<<config.block_per_grid,
                         config.thread_per_block,
                         0,
                         dev_ctx.stream()>>>(N, out);
  return res;
}
#endif

// CPU counterpart: fill `out` with 0, 1, ..., N - 1.
template <typename T>
void range_kernel(int64_t N, T* out) {
  for (int64_t idx = 0; idx < N; ++idx) {
    out[idx] = idx;
  }
}

// Build a 1-D tensor of length N holding 0..N-1 on the CPU.
template <typename T, typename Context>
phi::DenseTensor GetRangeTensor(const Context& dev_ctx,
                                int64_t N,
                                phi::DataType dtype) {
  phi::DenseTensor res(dtype);
  res.Resize(phi::make_ddim({N}));
  DenseTensor* p_res = &res;
  T* out = dev_ctx.template Alloc<T>(p_res);
  range_kernel<T>(N, out);
  return res;
}

}  // namespace funcs
}  // namespace phi
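
// A minimal usage sketch for DealWithBoolIndices (illustrative only; the names
// `dev_ctx` and `mask` and the phi::CPUContext / int64_t template arguments
// are assumptions made for this example):
//
//   std::vector<const phi::DenseTensor*> indices = {&mask};
//   std::vector<phi::DenseTensor> tmp_indices;  // owns the converted tensors
//   std::vector<const phi::DenseTensor*> int_indices =
//       phi::funcs::DealWithBoolIndices<int64_t, phi::CPUContext>(
//           dev_ctx, indices, &tmp_indices);
//   // For a bool `mask`, int_indices holds one int64 index tensor per
//   // dimension of `mask`, containing the coordinates of its true elements;
//   // int32/int64 indices would be returned unchanged.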