diff --git a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu index 8457936eb65bc23aa1e824f8821a7cad1d823f74..6632d3f8b2ec9bbf9ab3c9ac3842ea878ac98290 100644 --- a/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu @@ -54,7 +54,7 @@ struct MaxFunctor { if (x > cap) { return cap; } - return x; + return x >= 0 ? x : 0; } }; diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h index 5ae9393dba0d79adfa6f8d4041ffa44e83c407dd..1bc841a6d8ba4fb8fa2f07fd9150138e47ef038d 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h @@ -1,5 +1,6 @@ // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// +// Copyright 2022 The DGL team for some useful functions. + // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -23,6 +24,10 @@ namespace phi { +#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF +#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF +#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF + inline void CopyBCastOff(const BroadCastInfo& bcast_info, thrust::device_vector& l_bcastoff, thrust::device_vector& r_bcastoff) { @@ -63,6 +68,37 @@ inline int FindNumThreads(int dim, int max_num_threads) { return res; } +inline int FindNumBlocks(char axis, int nblocks, int max_num_blocks = -1) { + int default_max_num_blocks = -1; + switch (axis) { + case 'x': + default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_X; + break; + case 'y': + default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_Y; + break; + case 'z': + default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_Z; + break; + default: + PADDLE_THROW( + phi::errors::InvalidArgument("%c axis is not recognized", axis)); + } + if (max_num_blocks == -1) { + max_num_blocks = default_max_num_blocks; + } + PADDLE_ENFORCE_GT( + max_num_blocks, + 0, + phi::errors::InvalidArgument("max_num_blocks should be larger than 0, " + "but received %d", + max_num_blocks)); + if (nblocks < max_num_blocks) { + return nblocks; + } + return max_num_blocks; +} + template struct GraphSendUERecvSumCUDAFunctor { DEVICE inline void operator()(T* output, T val) { diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu index c5d5fb7196fb238c78df099debe6feb6299f8f58..a1d522cc3d4d1da9f787ae3926c77bd12dfd86d9 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu @@ -52,7 +52,7 @@ void CalculateXEGradForMinMax(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid(nbx, nby); const dim3 block(ntx, nty); @@ -183,7 +183,7 @@ void CalculateXGrad(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid_(nbx, nby); const dim3 block_(ntx, nty); funcs::MultiplyFunctor mul_functor; @@ -306,7 +306,7 @@ void CalculateXGrad(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid_(nbx, nby); const dim3 block_(ntx, nty); if (!reduce) { @@ -392,7 +392,7 @@ void CalculateEGrad(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid(nbx, nby); const dim3 block(ntx, nty); if (reduce_op == "SUM") { diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu index 7351c562dff9d50db56b94da351d997743cb018f..8a5897316ca9cafbce2730526206b64861221b2d 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu @@ -81,6 +81,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, if (index_size == 0) return; const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims()); + const T* x_data = x.data(); const T* e_data = e.data(); const IndexT* s_index = src_index.data(); @@ -95,7 +96,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid(nbx, nby); const dim3 block(ntx, nty); int64_t input_size = x.dims()[0]; diff --git a/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu index 5b8d7b28dcc2979ffa1dd64a14e0fc74ee560b9d..d845e9cc4372aaf96db9c2fdeff1d28a0916139b 100644 --- a/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu @@ -73,7 +73,7 @@ void CalculateGrad(const Context& ctx, const int ntx = FindNumThreads(slice_size, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (slice_size + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid_tmp(nbx, nby); const dim3 block_tmp(ntx, nty); GraphSendUVGradCUDAKernel @@ -93,7 +93,7 @@ void CalculateGrad(const Context& ctx, FindNumThreads(bcast_info.out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (bcast_info.out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid_tmp(nbx, nby); const dim3 block_tmp(ntx, nty); GraphSendUVGradCUDAKernel @@ -133,7 +133,7 @@ void CalculateGrad(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid_(nbx, nby); const dim3 block_(ntx, nty); funcs::MultiplyFunctor mul_functor; diff --git a/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu index f1e4581773f5419cdf2748edfab8157ee4d04090..32b8b014d0c224b04d1f49bd6808d6c255a7dc5b 100644 --- a/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_uv_kernel.cu @@ -101,7 +101,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx, const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nbx = (out_len + ntx - 1) / ntx; - const int nby = (index_size + nty - 1) / nty; + const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty); const dim3 grid(nbx, nby); const dim3 block(ntx, nty); if (message_op == "ADD") {