Unverified commit 03ef0bdc, authored by Siming Dai and committed by GitHub

[Geometric] Fix cuda configuration error for message_passing api (#45315)

Parent 0e384ade
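The root cause of the configuration error: CUDA limits a launch grid to 65535 blocks in the y and z dimensions (the x dimension allows up to 2^31 - 1), but the message-passing kernels computed nby = (index_size + nty - 1) / nty straight from the number of edges, so a large enough graph produced an invalid launch configuration. The sketch below is illustrative only (ClampNumBlocks and the sample sizes are hypothetical, not part of this commit); it shows how the unclamped block count blows past the gridDim.y cap and how clamping to the per-axis maximum, which is what the new FindNumBlocks helper in this patch does, keeps the launch valid.

// Illustrative sketch only: ClampNumBlocks and the sizes are hypothetical.
#include <algorithm>
#include <cstdio>

constexpr int kMaxGridX = 0x7FFFFFFF;  // gridDim.x limit: 2^31 - 1 blocks
constexpr int kMaxGridY = 0xFFFF;      // gridDim.y limit: 65535 blocks
constexpr int kMaxGridZ = 0xFFFF;      // gridDim.z limit: 65535 blocks

inline int ClampNumBlocks(char axis, int nblocks) {
  const int cap =
      (axis == 'x') ? kMaxGridX : (axis == 'y') ? kMaxGridY : kMaxGridZ;
  return std::min(nblocks, cap);
}

int main() {
  const int index_size = 10000000;  // e.g. ten million edges
  const int nty = 32;               // threads per block along y
  const int raw_nby = (index_size + nty - 1) / nty;  // 312500 blocks
  const int nby = ClampNumBlocks('y', raw_nby);      // clamped to 65535
  // dim3 grid(nbx, raw_nby) would fail with cudaErrorInvalidConfiguration
  // ("invalid configuration argument"); dim3 grid(nbx, nby) launches fine.
  std::printf("requested %d blocks in y, launching %d\n", raw_nby, nby);
  return 0;
}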
......@@ -54,7 +54,7 @@ struct MaxFunctor {
if (x > cap) {
return cap;
}
-    return x;
+    return x >= 0 ? x : 0;
}
};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
+// Copyright 2022 The DGL team for some useful functions.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
......@@ -23,6 +24,10 @@
namespace phi {
+#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF
+#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF
+#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF
inline void CopyBCastOff(const BroadCastInfo& bcast_info,
thrust::device_vector<int64_t>& l_bcastoff,
thrust::device_vector<int64_t>& r_bcastoff) {
......@@ -63,6 +68,37 @@ inline int FindNumThreads(int dim, int max_num_threads) {
return res;
}
+inline int FindNumBlocks(char axis, int nblocks, int max_num_blocks = -1) {
+  int default_max_num_blocks = -1;
+  switch (axis) {
+    case 'x':
+      default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_X;
+      break;
+    case 'y':
+      default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_Y;
+      break;
+    case 'z':
+      default_max_num_blocks = CUDA_MAX_NUM_BLOCKS_Z;
+      break;
+    default:
+      PADDLE_THROW(
+          phi::errors::InvalidArgument("%c axis is not recognized", axis));
+  }
+  if (max_num_blocks == -1) {
+    max_num_blocks = default_max_num_blocks;
+  }
+  PADDLE_ENFORCE_GT(
+      max_num_blocks,
+      0,
+      phi::errors::InvalidArgument("max_num_blocks should be larger than 0, "
+                                   "but received %d",
+                                   max_num_blocks));
+  if (nblocks < max_num_blocks) {
+    return nblocks;
+  }
+  return max_num_blocks;
+}
template <typename T>
struct GraphSendUERecvSumCUDAFunctor {
DEVICE inline void operator()(T* output, T val) {
......
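Every launch-configuration site below swaps the raw (index_size + nty - 1) / nty for FindNumBlocks('y', ...): only the y axis has the tight 65535 cap, while nbx along x stays far below the 2^31 - 1 limit in practice. Clamping the block count is safe only if the kernels themselves stride over the y dimension so the rows beyond the clamped grid are still processed; the skeleton below is a hypothetical illustration of that assumption, not the patch's actual kernel.

// Hypothetical sketch, assuming a 2-D grid-stride loop: with nby clamped,
// blockIdx.y covers only part of index_size and the outer while-loop
// advances by gridDim.y * blockDim.y to reach the remaining rows.
template <typename T, typename IndexT>
__global__ void EdgeGatherKernel(const T* x, const IndexT* src_index, T* out,
                                 int64_t index_size, int64_t out_len) {
  int64_t ty = blockIdx.y * blockDim.y + threadIdx.y;
  const int64_t stride_y = static_cast<int64_t>(blockDim.y) * gridDim.y;
  while (ty < index_size) {
    const IndexT src = src_index[ty];
    int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_x = static_cast<int64_t>(blockDim.x) * gridDim.x;
    while (tx < out_len) {
      out[ty * out_len + tx] = x[src * out_len + tx];  // gather one feature
      tx += stride_x;
    }
    ty += stride_y;
  }
}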
......@@ -52,7 +52,7 @@ void CalculateXEGradForMinMax(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby);
const dim3 block(ntx, nty);
......@@ -183,7 +183,7 @@ void CalculateXGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty);
funcs::MultiplyFunctor<T> mul_functor;
......@@ -306,7 +306,7 @@ void CalculateXGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty);
if (!reduce) {
......@@ -392,7 +392,7 @@ void CalculateEGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby);
const dim3 block(ntx, nty);
if (reduce_op == "SUM") {
......
......@@ -81,6 +81,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
if (index_size == 0) return;
const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims());
const T* x_data = x.data<T>();
const T* e_data = e.data<T>();
const IndexT* s_index = src_index.data<IndexT>();
......@@ -95,7 +96,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby);
const dim3 block(ntx, nty);
int64_t input_size = x.dims()[0];
......
......@@ -73,7 +73,7 @@ void CalculateGrad(const Context& ctx,
const int ntx = FindNumThreads(slice_size, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (slice_size + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_tmp(nbx, nby);
const dim3 block_tmp(ntx, nty);
GraphSendUVGradCUDAKernel<T, IndexT>
......@@ -93,7 +93,7 @@ void CalculateGrad(const Context& ctx,
FindNumThreads(bcast_info.out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (bcast_info.out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_tmp(nbx, nby);
const dim3 block_tmp(ntx, nty);
GraphSendUVGradCUDAKernel<T, IndexT>
......@@ -133,7 +133,7 @@ void CalculateGrad(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid_(nbx, nby);
const dim3 block_(ntx, nty);
funcs::MultiplyFunctor<T> mul_functor;
......
......@@ -101,7 +101,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
const int nbx = (out_len + ntx - 1) / ntx;
- const int nby = (index_size + nty - 1) / nty;
+ const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
const dim3 grid(nbx, nby);
const dim3 block(ntx, nty);
if (message_op == "ADD") {
......