diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h
index 147f716c126ec54ad88236102facfcc38e286107..59c8c9f3b8f0ed0569f40ab1476f629bd847d0c2 100644
--- a/paddle/phi/kernels/funcs/gather.cu.h
+++ b/paddle/phi/kernels/funcs/gather.cu.h
@@ -21,7 +21,6 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/core/utils/dim.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -44,7 +43,7 @@
 
 template <typename T, typename IndexT = int>
 __global__ void GatherNdCUDAKernel(const T* input,
-                                   const int64_t* input_dims,
+                                   const Dim<DDim::kMaxRank> input_dims,
                                    const IndexT* indices,
                                    T* output,
                                    size_t remain_size,
@@ -149,19 +148,11 @@ void GPUGatherNd(const phi::GPUContext& ctx,
     slice_size *= input_dims[i];
   }
   // source dim
-  std::vector<int64_t> v_input_dims(input_dims_size);
+  Dim<DDim::kMaxRank> g_input_dims;
   for (int i = 0; i < input_dims_size; ++i) {
-    v_input_dims[i] = input_dims[i];
+    g_input_dims[i] = input_dims[i];
   }
 
-  phi::DenseTensor input_dims_tensor;
-  input_dims_tensor.Resize({input_dims_size});
-  auto* g_input_dims = ctx.Alloc<int64_t>(&input_dims_tensor);
-  int64_t bytes = input_dims_size * sizeof(int64_t);
-
-  paddle::memory::Copy(
-      gplace, g_input_dims, cplace, v_input_dims.data(), bytes, ctx.stream());
-
   int block = 512;
   int64_t n = slice_size * remain_numel;
   int64_t grid = (n + block - 1) / block;
diff --git a/paddle/phi/kernels/funcs/scatter.cu.h b/paddle/phi/kernels/funcs/scatter.cu.h
index 4d33c28e77f6bd28eff18964ee323233d521ea9f..254dd45edb596243a8867dee52708d5de6776bea 100644
--- a/paddle/phi/kernels/funcs/scatter.cu.h
+++ b/paddle/phi/kernels/funcs/scatter.cu.h
@@ -77,7 +77,7 @@ template <typename T, typename IndexT = int>
 __global__ void ScatterNdCUDAKernel(const T* update,
                                     const IndexT* indices,
                                     T* output,
-                                    const int64_t* output_dims,
+                                    const Dim<DDim::kMaxRank> output_dims,
                                     size_t remain_size,
                                     size_t slice_size,
                                     size_t end_size) {
@@ -222,23 +222,12 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
     slice_size *= output_dims[i];
   }
   const size_t slice_bytes = slice_size * sizeof(T);
 
-  // put output_dims int CUDA
-  // gplace and cplace
-  const auto gplace = ctx.GetPlace();
-  auto cplace = phi::CPUPlace();
-  std::vector<int64_t> v_output_dims(output_dims_size);
+  Dim<DDim::kMaxRank> g_output_dims;
   for (int i = 0; i < output_dims_size; ++i) {
-    v_output_dims[i] = output_dims[i];
+    g_output_dims[i] = output_dims[i];
   }
 
-  phi::DenseTensor out_dims_tensor;
-  out_dims_tensor.Resize({output_dims_size});
-  auto* g_output_dims = ctx.Alloc<int64_t>(&out_dims_tensor);
-  int64_t bytes = output_dims_size * sizeof(int64_t);
-  paddle::memory::Copy(
-      gplace, g_output_dims, cplace, v_output_dims.data(), bytes, ctx.stream());
-
   int block = 512;
   int64_t n = slice_size * remain_numel;
   int64_t grid = (n + block - 1) / block;
diff --git a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu
index a393eecd51242193fa3b2192ff8e8f1111d350b6..c63063bc57846bd726c5f974b755f7ceb3f29918 100644
--- a/paddle/phi/kernels/gpu/index_select_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_select_grad_kernel.cu
@@ -109,7 +109,6 @@ void IndexSelectGradKernel(const Context& ctx,
         stride,
         size,
         delta);
-    phi::backends::gpu::GpuStreamSync(stream);
   } else {
     const int* index_data = index.data<int>();
     index_select_grad_cuda_kernel<T, int><<<
@@ -124,7 +123,6 @@
         stride,
         size,
         delta);
-    phi::backends::gpu::GpuStreamSync(stream);
   }
 }
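Note on the gather/scatter hunks above: instead of staging the dims array in a `DenseTensor` and issuing a host-to-device `paddle::memory::Copy` before every launch, the dims now travel in a fixed-size `Dim<DDim::kMaxRank>` passed to the kernel by value, so the per-launch device allocation and copy disappear. Below is a minimal standalone sketch of that pattern; `DimsByValue`, `SumDims`, and the local `kMaxRank` are illustrative names, not Paddle APIs, and `kMaxRank = 9` assumes the current value of `phi::DDim::kMaxRank`.

```cpp
#include <cstdint>
#include <cstdio>

constexpr int kMaxRank = 9;  // assumed to mirror phi::DDim::kMaxRank

struct DimsByValue {
  int64_t d[kMaxRank];
  __host__ __device__ int64_t operator[](int i) const { return d[i]; }
};

// The dims arrive by value with the launch itself: no cudaMalloc, no
// host-to-device copy, nothing to free afterwards.
__global__ void SumDims(DimsByValue dims, int rank, int64_t* out) {
  int64_t s = 0;
  for (int i = 0; i < rank; ++i) s += dims[i];
  *out = s;
}

int main() {
  DimsByValue dims{{2, 3, 4}};  // unused slots stay zero-initialized
  int64_t* out;
  cudaMallocManaged(&out, sizeof(int64_t));
  SumDims<<<1, 1>>>(dims, 3, out);
  cudaDeviceSynchronize();
  printf("%lld\n", static_cast<long long>(*out));  // prints 9
  cudaFree(out);
  return 0;
}
```

Kernel arguments are marshalled with the launch command itself, so a struct of a few dozen bytes rides along essentially for free; the only constraint is that the payload stay within the kernel parameter-space limit (4 KB on most toolchains).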
diff --git a/paddle/phi/kernels/gpu/index_select_kernel.cu b/paddle/phi/kernels/gpu/index_select_kernel.cu
index f774522318acb8f44798030870886dd1dc7accc1..e82976d46e68b517327eea3486151166277bd981 100644
--- a/paddle/phi/kernels/gpu/index_select_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_select_kernel.cu
@@ -82,7 +82,6 @@ void IndexSelectKernel(const Context& ctx,
           PADDLE_CUDA_NUM_THREADS,
           0,
           stream>>>(in_data, out_data, index_data, numel, stride, size, delta);
-    phi::backends::gpu::GpuStreamSync(stream);
   } else {
     const int* index_data = index.data<int>();
     index_select_cuda_kernel<
@@ -92,7 +91,6 @@
               0,
               stream>>>(
         in_data, out_data, index_data, numel, stride, size, delta);
-    phi::backends::gpu::GpuStreamSync(stream);
   }
 }
diff --git a/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu b/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu
index eadd91773c00810e3f4187d079926028733a4945..2ae8911fde510f36d192a6020bef4d37fe25611b 100644
--- a/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu
+++ b/paddle/phi/kernels/gpu/scatter_nd_add_kernel.cu
@@ -26,7 +26,7 @@ void ScatterNdAddKernel(const Context &ctx,
                         const DenseTensor &index,
                         const DenseTensor &updates,
                         DenseTensor *out) {
-  Copy(ctx, x, ctx.GetPlace(), true, out);
+  Copy(ctx, x, ctx.GetPlace(), false, out);
   const auto &index_type = index.dtype();
   bool index_type_match = index_type == phi::DataType::INT32 ||
                           index_type == phi::DataType::INT64;
diff --git a/paddle/phi/kernels/gpu/where_index_kernel.cu b/paddle/phi/kernels/gpu/where_index_kernel.cu
index 616679057ffce29b8d911d56d5cf428801138589..3ff73ce8b3babce6ca46ce09575f765785837ee0 100644
--- a/paddle/phi/kernels/gpu/where_index_kernel.cu
+++ b/paddle/phi/kernels/gpu/where_index_kernel.cu
@@ -29,33 +29,32 @@ namespace cub = hipcub;
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
-template <typename T1, typename T2, typename OutT>
+template <typename MaskT, typename IndexT, typename OutT>
 struct IndexFunctor {
-  T2 stride[phi::DDim::kMaxRank];
-  int dims;
+  IndexT strides[phi::DDim::kMaxRank];
+  int rank;
+
   explicit IndexFunctor(const phi::DDim &in_dims) {
-    dims = in_dims.size();
-    std::vector<T2> strides_in_tmp;
-    strides_in_tmp.resize(dims, 1);
-    // get strides according to in_dims
-    for (T2 i = 1; i < dims; i++) {
-      strides_in_tmp[i] = strides_in_tmp[i - 1] * in_dims[dims - i];
+    rank = in_dims.size();
+    // Get strides according to in_dims
+    strides[0] = 1;
+    for (IndexT i = 1; i < rank; i++) {
+      strides[i] = strides[i - 1] * in_dims[rank - i];
     }
-    memcpy(stride, strides_in_tmp.data(), dims * sizeof(T2));
   }
 
   HOSTDEVICE inline void operator()(OutT *out,
-                                    const T1 *mask,
-                                    const T2 *index,
+                                    const MaskT *mask,
+                                    const IndexT *index,
                                     const int num) {
     int store_fix = 0;
     for (int idx = 0; idx < num; idx++) {
       if (mask[idx]) {
-        T2 data_index = index[idx];
+        IndexT data_index = index[idx];
         // get index
-        for (int rank_id = dims - 1; rank_id >= 0; --rank_id) {
-          out[store_fix] = static_cast<OutT>(data_index / stride[rank_id]);
-          data_index = data_index % stride[rank_id];
+        for (int rank_id = rank - 1; rank_id >= 0; --rank_id) {
+          out[store_fix] = static_cast<OutT>(data_index / strides[rank_id]);
+          data_index = data_index % strides[rank_id];
           store_fix++;
         }
       }
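Two of the changes above trade blocking host work for stream-ordered execution: the `GpuStreamSync` calls after the `index_select` launches are gone, and `ScatterNdAddKernel` now requests a non-blocking copy (`Copy(..., false, out)`). This is safe because operations queued on one CUDA stream run in issue order, so a later kernel on the same stream always sees an earlier kernel's writes; syncing right after a launch only stalls the CPU. A runnable sketch of that guarantee, with illustrative kernels rather than Paddle code:

```cpp
#include <cstdio>

__global__ void Produce(int* buf) { buf[threadIdx.x] = threadIdx.x; }
__global__ void Consume(const int* buf, int* out) {
  out[threadIdx.x] = buf[threadIdx.x] * 2;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  int *buf, *out;
  cudaMalloc(&buf, 32 * sizeof(int));
  cudaMalloc(&out, 32 * sizeof(int));
  // Same-stream work runs in issue order, so Consume sees Produce's
  // writes without any host-side sync in between.
  Produce<<<1, 32, 0, stream>>>(buf);
  Consume<<<1, 32, 0, stream>>>(buf, out);
  int host[32];
  cudaMemcpyAsync(host, out, sizeof(host), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);  // sync only where the host reads the result
  printf("%d\n", host[31]);  // prints 62
  cudaFree(buf);
  cudaFree(out);
  cudaStreamDestroy(stream);
  return 0;
}
```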
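The `IndexFunctor` rewrite in `where_index` computes strides straight into a `phi::DDim::kMaxRank`-sized member array, dropping the temporary `std::vector` and the `memcpy`. Note the layout: `strides[0] = 1` is the stride of the innermost (last) dimension, so the decode loop walks `rank_id` from `rank - 1` down to `0` and emits coordinates outermost first. A host-only sketch of the same divide-and-modulo decode (illustrative code, not Paddle's):

```cpp
#include <cstdint>
#include <cstdio>

constexpr int kMaxRank = 9;  // assumed to mirror phi::DDim::kMaxRank

int main() {
  const int64_t dims[] = {2, 3, 4};  // row-major shape
  const int rank = 3;

  // Strides stored innermost-first, exactly as in the functor:
  // strides = {1, 4, 12}.
  int64_t strides[kMaxRank];
  strides[0] = 1;
  for (int i = 1; i < rank; ++i) strides[i] = strides[i - 1] * dims[rank - i];

  // Decode a flat index back into coordinates: 17 = 1*12 + 1*4 + 1*1.
  int64_t flat = 17;
  for (int rank_id = rank - 1; rank_id >= 0; --rank_id) {
    printf("%lld ", static_cast<long long>(flat / strides[rank_id]));
    flat %= strides[rank_id];
  }
  printf("\n");  // prints "1 1 1"
  return 0;
}
```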