未验证 提交 03ef0bdc 编写于 作者: S Siming Dai 提交者: GitHub

[Geometric] Fix cuda configuration error for message_passing api (#45315)

上级 0e384ade
...@@ -54,7 +54,7 @@ struct MaxFunctor { ...@@ -54,7 +54,7 @@ struct MaxFunctor {
if (x > cap) { if (x > cap) {
return cap; return cap;
} }
return x; return x >= 0 ? x : 0;
} }
}; };
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// // Copyright 2022 The DGL team for some useful functions.
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
...@@ -23,6 +24,10 @@ ...@@ -23,6 +24,10 @@
namespace phi { namespace phi {
#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF
#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF
#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF
inline void CopyBCastOff(const BroadCastInfo& bcast_info, inline void CopyBCastOff(const BroadCastInfo& bcast_info,
thrust::device_vector<int64_t>& l_bcastoff, thrust::device_vector<int64_t>& l_bcastoff,
thrust::device_vector<int64_t>& r_bcastoff) { thrust::device_vector<int64_t>& r_bcastoff) {
...@@ -63,6 +68,37 @@ inline int FindNumThreads(int dim, int max_num_threads) { ...@@ -63,6 +68,37 @@ inline int FindNumThreads(int dim, int max_num_threads) {
return res; return res;
} }
inline int FindNumBlocks(char axis, int nblocks, int max_num_blocks = -1) {
int default_max_num_blocks = -1;
switch (axis) {
case 'x':
default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_X;
break;
case 'y':
default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_Y;
break;
case 'z':
default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_Z;
break;
default:
PADDLE_THROW(
phi::errors::InvalidArgument("%c axis is not recognized", axis));
}
if (max_num_blocks == -1) {
max_num_blocks = default_max_num_blocks;
}
PADDLE_ENFORCE_GT(
max_num_blocks,
0,
phi::errors::InvalidArgument("max_num_blocks should be larger than 0, "
"but received %d",
max_num_blocks));
if (nblocks < max_num_blocks) {
return nblocks;
}
return max_num_blocks;
}
template <typename T> template <typename T>
struct GraphSendUERecvSumCUDAFunctor { struct GraphSendUERecvSumCUDAFunctor {
DEVICE inline void operator()(T* output, T val) { DEVICE inline void operator()(T* output, T val) {
......
...@@ -52,7 +52,7 @@ void CalculateXEGradForMinMax(const Context& ctx, ...@@ -52,7 +52,7 @@ void CalculateXEGradForMinMax(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby); const dim3 grid(nbx, nby);
const dim3 block(ntx, nty); const dim3 block(ntx, nty);
...@@ -183,7 +183,7 @@ void CalculateXGrad(const Context& ctx, ...@@ -183,7 +183,7 @@ void CalculateXGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_(nbx, nby); const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty); const dim3 block_(ntx, nty);
funcs::MultiplyFunctor<T> mul_functor; funcs::MultiplyFunctor<T> mul_functor;
...@@ -306,7 +306,7 @@ void CalculateXGrad(const Context& ctx, ...@@ -306,7 +306,7 @@ void CalculateXGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_(nbx, nby); const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty); const dim3 block_(ntx, nty);
if (!reduce) { if (!reduce) {
...@@ -392,7 +392,7 @@ void CalculateEGrad(const Context& ctx, ...@@ -392,7 +392,7 @@ void CalculateEGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby); const dim3 grid(nbx, nby);
const dim3 block(ntx, nty); const dim3 block(ntx, nty);
if (reduce_op == "SUM") { if (reduce_op == "SUM") {
......
...@@ -81,6 +81,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -81,6 +81,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
if (index_size == 0) return; if (index_size == 0) return;
const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims()); const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims());
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
const T* e_data = e.data<T>(); const T* e_data = e.data<T>();
const IndexT* s_index = src_index.data<IndexT>(); const IndexT* s_index = src_index.data<IndexT>();
...@@ -95,7 +96,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -95,7 +96,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby); const dim3 grid(nbx, nby);
const dim3 block(ntx, nty); const dim3 block(ntx, nty);
int64_t input_size = x.dims()[0]; int64_t input_size = x.dims()[0];
......
...@@ -73,7 +73,7 @@ void CalculateGrad(const Context& ctx, ...@@ -73,7 +73,7 @@ void CalculateGrad(const Context& ctx,
const int ntx = FindNumThreads(slice_size, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(slice_size, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (slice_size + ntx - 1) / ntx; const int nbx = (slice_size + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_tmp(nbx, nby); const dim3 grid_tmp(nbx, nby);
const dim3 block_tmp(ntx, nty); const dim3 block_tmp(ntx, nty);
GraphSendUVGradCUDAKernel<T, IndexT> GraphSendUVGradCUDAKernel<T, IndexT>
...@@ -93,7 +93,7 @@ void CalculateGrad(const Context& ctx, ...@@ -93,7 +93,7 @@ void CalculateGrad(const Context& ctx,
FindNumThreads(bcast_info.out_len, ctx.GetMaxThreadsPerBlock()); FindNumThreads(bcast_info.out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (bcast_info.out_len + ntx - 1) / ntx; const int nbx = (bcast_info.out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_tmp(nbx, nby); const dim3 grid_tmp(nbx, nby);
const dim3 block_tmp(ntx, nty); const dim3 block_tmp(ntx, nty);
GraphSendUVGradCUDAKernel<T, IndexT> GraphSendUVGradCUDAKernel<T, IndexT>
...@@ -133,7 +133,7 @@ void CalculateGrad(const Context& ctx, ...@@ -133,7 +133,7 @@ void CalculateGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_(nbx, nby); const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty); const dim3 block_(ntx, nty);
funcs::MultiplyFunctor<T> mul_functor; funcs::MultiplyFunctor<T> mul_functor;
......
...@@ -101,7 +101,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx, ...@@ -101,7 +101,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock()); const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx; const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx; const int nbx = (out_len + ntx - 1) / ntx;
const int nby = (index_size + nty - 1) / nty; const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby); const dim3 grid(nbx, nby);
const dim3 block(ntx, nty); const dim3 block(ntx, nty);
if (message_op == "ADD") { if (message_op == "ADD") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册