Unverified commit 3c14b094, authored by Rayman, committed by GitHub

[Hackathon No.34] Optimize poisson op (#45160)

* [Hackathon No.34] Optimize poisson op

* [poisson] code style fix

* modify code style

* prevent the grid dim from becoming too big

* modify code style

* modify code style

* modify import

* modify import

* modify code style
Parent a012d426
@@ -229,6 +229,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context,
   return config;
 }
 
+template <typename Context>
+void LimitGridDim(const Context& ctx, dim3* grid_dim) {
+  auto max_grid_dim =
+      reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
+  grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
+  grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
+  grid_dim->z = grid_dim->z < max_grid_dim[2] ? grid_dim->z : max_grid_dim[2];
+}
 }  // namespace gpu
 }  // namespace backends
 }  // namespace phi
......
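
The first hunk adds a LimitGridDim helper to gpu_launch_config.h: it clamps each axis of a requested launch grid to the device limit reported by GetCUDAMaxGridDimSize(), so launching over a very large tensor can never request an illegal grid, and the kernel is expected to cover any remaining elements with a grid-stride loop. Below is a minimal standalone sketch of the same clamping idea; ClampGridDim and the max_grid array are hypothetical stand-ins for illustration, not the phi API:

#include <cuda_runtime.h>  // dim3
#include <algorithm>

// Hypothetical stand-in for phi's LimitGridDim: clamp each grid axis to a
// caller-supplied maximum (phi reads the real limits from the device).
inline void ClampGridDim(dim3* grid, const unsigned int max_grid[3]) {
  grid->x = std::min(grid->x, max_grid[0]);
  grid->y = std::min(grid->y, max_grid[1]);
  grid->z = std::min(grid->z, max_grid[2]);
}

// Usage sketch for a 1-D launch over n elements:
//   dim3 grid((n + block - 1) / block);
//   ClampGridDim(&grid, max_grid);  // grid.x may now be smaller than needed,
//                                   // so the kernel must use a grid-stride loop.
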
@@ -20,6 +20,7 @@ limitations under the License. */
 #endif
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/poisson_kernel.h"
@@ -27,48 +28,38 @@ limitations under the License. */
 namespace phi {
 
 template <typename T>
-struct PoissonCudaFunctor {
- public:
-  PoissonCudaFunctor(const T* in,
-                     T* out,
-                     unsigned int seed,
-                     unsigned int offset)
-      : in_(in), out_(out), seed_(seed), offset_(offset) {}
-
-  __device__ void operator()(int64_t idx) {
+__global__ void GetPoisson(
+    const T* in, T* out, const int N, unsigned int seed, unsigned int offset) {
+  CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
 #ifdef __NVCC__
     curandStatePhilox4_32_10_t state;
-    curand_init(seed_, idx, offset_, &state);
-    out_[idx] = static_cast<T>(curand_poisson(&state, in_[idx]));
+    curand_init(seed, idx, offset, &state);
+    out[idx] = static_cast<T>(curand_poisson(&state, in[idx]));
 #elif __HIPCC__
     hiprandStatePhilox4_32_10_t state;
-    hiprand_init(seed_, idx, offset_, &state);
-    out_[idx] = static_cast<T>(hiprand_poisson(&state, in_[idx]));
+    hiprand_init(seed, idx, offset, &state);
+    out[idx] = static_cast<T>(hiprand_poisson(&state, in[idx]));
 #endif
   }
-
- private:
-  const T* in_;
-  T* out_;
-  const unsigned int seed_;
-  const unsigned int offset_;
-};
+}
 
 template <typename T, typename Context>
 void PoissonKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
   const T* x_data = x.data<T>();
   T* out_data = ctx.template Alloc<T>(out);
-  auto size = x.numel();
+  const int size = x.numel();
+
+  const int kMaxBlockDim = 256;
+  int block_size = std::min(kMaxBlockDim, ctx.GetMaxThreadsPerBlock());
+  dim3 dim_block(block_size);
+  dim3 dim_grid((size + block_size - 1) / block_size);
+  phi::backends::gpu::LimitGridDim(ctx, &dim_grid);
 
   auto gen_cuda = ctx.GetGenerator();
   auto seed_offset = gen_cuda->IncrementOffset(20);
   uint64_t seed = seed_offset.first;
   uint64_t offset = seed_offset.second;
-
-  phi::funcs::ForRange<Context> for_range(ctx, size);
-  PoissonCudaFunctor<T> functor(x_data, out_data, seed, offset);
-  for_range(functor);
+  GetPoisson<T><<<dim_grid, dim_block>>>(x_data, out_data, size, seed, offset);
 }
 
 }  // namespace phi
......
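
The second hunk rewrites the kernel itself: the ForRange + PoissonCudaFunctor pair is replaced by a raw __global__ kernel launched with an explicitly sized, LimitGridDim-clamped grid. CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) is Paddle's grid-stride-loop macro, and each iteration seeds an independent Philox state from (seed, idx, offset), so every element draws from its own random subsequence even when the grid is capped. The following self-contained sketch shows the same pattern outside phi, assuming the macro expands to an ordinary grid-stride loop; the host side (main, the fixed lambda, the hard-coded grid cap) is illustrative only:

#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <cstdio>

// Grid-stride loop: the macro in the diff expands to roughly this pattern,
// so a grid capped by LimitGridDim still visits every index < n.
__global__ void PoissonSketch(const float* in, float* out, int64_t n,
                              unsigned long long seed,
                              unsigned long long offset) {
  for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
       idx += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    curandStatePhilox4_32_10_t state;
    // (seed, subsequence = idx, offset): each element gets its own stream.
    curand_init(seed, idx, offset, &state);
    out[idx] = static_cast<float>(curand_poisson(&state, in[idx]));
  }
}

int main() {
  const int64_t n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int64_t i = 0; i < n; ++i) in[i] = 4.0f;  // lambda = 4 everywhere

  const int block = 256;
  unsigned int grid = static_cast<unsigned int>((n + block - 1) / block);
  // Hypothetical stand-in for LimitGridDim: cap grid.x at some maximum.
  grid = grid < 65535u ? grid : 65535u;

  PoissonSketch<<<grid, block>>>(in, out, n, /*seed=*/42ULL, /*offset=*/0ULL);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
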