diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 2395966ad0dbe36dd358b884be0ad4509143f7ea..59c3df0fce5e503691818cc0781fc73d12b99848 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -254,26 +254,85 @@ int GetVecsize(const std::vector &ins, return std::min(out_vec_size, in_vec_size); } -template -__device__ __forceinline__ void LoadData( - T *dst, - const _ptr_ T *src, - uint32_t block_offset, - const kps::details::BroadcastConfig &config, - int numel, - int num, - int need_broadcast, - int read_lens) { - // numel : whole num of output - // num: how many data will be deal with in this time - if (need_broadcast) { - kps::ReadDataBc( - dst, src, block_offset, config, numel, read_lens); - } else { - kps::ReadData( - dst, src + block_offset, num, read_lens); +#ifndef PADDLE_WITH_XPU_KP +template +struct BroadcastDataLoader { + __device__ __forceinline__ void operator()( + T args[Arity][VecSize], + const phi::Array &ins, + const phi::Array &configs, + const phi::Array &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel) { +#pragma unroll + for (int i = 0; i < Arity; ++i) { + kps::Init(args[i], static_cast(1.0f)); + if (use_broadcast[i]) { + kps::ReadDataBc( + args[i], ins[i], block_offset, configs[i], numel, VecSize); + } else { + kps::ReadData( + args[i], ins[i] + block_offset, num, VecSize); + } + } } -} +}; + +template +struct BroadcastDataLoader { + __device__ __forceinline__ void operator()( + T args[Arity][VecSize], + const phi::Array &ins, + const phi::Array &configs, + const phi::Array &use_broadcast, + const int block_offset, + const int num, + const uint32_t numel) { + uint32_t index_bc[Arity][VecSize]; +#pragma unroll + for (int j = 0; j < Arity; ++j) { +#pragma unroll + for (int k = 0; k < VecSize; ++k) { + index_bc[j][k] = 0; + args[j][k] = static_cast(1); + } + } + + uint32_t thread_offset = block_offset + threadIdx.x * VecSize; +#pragma unroll + for (int k = 0; k < VecSize; ++k) { + uint32_t idx = thread_offset + k; + if (IsBoundary) { + if (idx == numel) break; + } + +#pragma unroll + for (int i = 0; i < phi::DDim::kMaxRank; ++i) { + if (i == configs[0].kDims) break; + auto fast_divmoder = configs[0].divmoders[i].Divmod(idx); + idx = fast_divmoder.val[0]; +#pragma unroll + for (int j = 0; j < Arity; ++j) { + index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i]; + } + } + } + +#pragma unroll + for (int j = 0; j < Arity; ++j) { +#pragma unroll + for (int k = 0; k < VecSize; ++k) { + args[j][k] = ins[j][index_bc[j][k]]; + } + } + } +}; +#endif template + bool IsBoundary, + bool IsAllBroadcast = false> __device__ void VectorizedBroadcastKernelImpl( const phi::Array &ins, phi::Array<_ptr_ OutT *, NumOuts> outs, const phi::Array &use_broadcast, - uint32_t numel, + const uint32_t numel, const phi::Array &configs, int num, int block_offset, @@ -294,19 +354,23 @@ __device__ void VectorizedBroadcastKernelImpl( Functor func) { __simd__ InT args[Arity][VecSize]; __simd__ ConditionalT result[VecSize]; - +#ifdef PADDLE_WITH_XPU_KP #pragma unroll for (int i = 0; i < Arity; ++i) { kps::Init(args[i], static_cast(1.0f), read_lens); - LoadData(args[i], - ins[i], - block_offset, - configs[i], - numel, - num, - use_broadcast[i], - read_lens); + if (use_broadcast[i]) { + kps::ReadDataBc( + args[i], ins[i], block_offset, configs[i], numel, read_lens); + } else { + kps::ReadData( + args[i], ins[i] + block_offset, num, read_lens); + } } +#else + BroadcastDataLoader()( + args, ins, configs, use_broadcast, block_offset, num, numel); +#endif + constexpr bool kCallElementwiseAny = paddle::platform::FunctionTraits::has_pointer_args; phi::funcs::ElementwisePrimitiveCaller + int VecSize, + bool IsAllBroadcast> __global__ void VectorizedBroadcastKernel( phi::Array ins, phi::Array<_ptr_ OutT *, NumOuts> outs, @@ -337,10 +402,9 @@ __global__ void VectorizedBroadcastKernel( int tail_tid, int read_lens, Functor func) { +#ifdef PADDLE_WITH_XPU_KP int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens; int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens; - -#ifdef PADDLE_WITH_XPU_KP for (; block_offset < main_offset; block_offset += stride) { VectorizedBroadcastKernelImpl(ins, - outs, - use_broadcast, - numel, - configs, - BLOCK_NUM_X * read_lens, - block_offset, - read_lens, - func); + false, + IsAllBroadcast>(ins, + outs, + use_broadcast, + numel, + configs, + BLOCK_NUM_X * read_lens, + block_offset, + read_lens, + func); } int num = numel - block_offset; if (num > 0) { @@ -366,17 +431,19 @@ __global__ void VectorizedBroadcastKernel( Arity, NumOuts, VecSize, - true>(ins, - outs, - use_broadcast, - numel, - configs, - num, - block_offset, - read_lens, - func); + true, + IsAllBroadcast>(ins, + outs, + use_broadcast, + numel, + configs, + num, + block_offset, + read_lens, + func); } #else + int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; if (block_offset < main_offset) { VectorizedBroadcastKernelImpl(ins, - outs, - use_broadcast, - numel, - configs, - BLOCK_NUM_X * VecSize, - block_offset, - read_lens, - func); + false, + IsAllBroadcast>(ins, + outs, + use_broadcast, + numel, + configs, + BLOCK_NUM_X * VecSize, + block_offset, + read_lens, + func); } else { VectorizedBroadcastKernelImpl(ins, - outs, - use_broadcast, - numel, - configs, - tail_tid, - block_offset, - read_lens, - func); + true, + IsAllBroadcast>(ins, + outs, + use_broadcast, + numel, + configs, + tail_tid, + block_offset, + read_lens, + func); } #endif } @@ -425,6 +494,7 @@ void LaunchBroadcastKernel( std::vector *outs, Functor func, const phi::Array &configs) { + int broadcast_num = 0; int numel = (*outs)[0]->numel(); phi::Array use_broadcast; phi::Array ins_data; @@ -435,7 +505,12 @@ void LaunchBroadcastKernel( } for (int i = 0; i < Arity; ++i) { - use_broadcast[i] = (ins[i]->numel() != numel); + if (ins[i]->numel() != numel) { + broadcast_num++; + use_broadcast[i] = true; + } else { + use_broadcast[i] = false; + } ins_data[i] = (const _ptr_ InT *)(ins[i]->data()); } @@ -446,6 +521,17 @@ void LaunchBroadcastKernel( auto stream = ctx.x_context()->xpu_stream; int main_offset = (numel / (read_lens * threads)) * read_lens * threads; int tail_tid = numel % (read_lens * threads); + + VectorizedBroadcastKernel + <<>>(ins_data, + outs_data, + use_broadcast, + numel, + configs, + main_offset, + tail_tid, + read_lens, + func); #else auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize); @@ -456,17 +542,43 @@ void LaunchBroadcastKernel( int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) * read_lens * gpu_config.GetBlockSize(); int tail_tid = numel % (read_lens * gpu_config.GetBlockSize()); + + if (broadcast_num > (Arity >> 1)) { + VectorizedBroadcastKernel 1)> + <<>>(ins_data, + outs_data, + use_broadcast, + numel, + configs, + main_offset, + tail_tid, + read_lens, + func); + } else { + VectorizedBroadcastKernel + <<>>(ins_data, + outs_data, + use_broadcast, + numel, + configs, + main_offset, + tail_tid, + read_lens, + func); + } #endif - VectorizedBroadcastKernel - <<>>(ins_data, - outs_data, - use_broadcast, - numel, - configs, - main_offset, - tail_tid, - read_lens, - func); } #ifndef PADDLE_WITH_XPU_KP @@ -985,6 +1097,7 @@ void BroadcastKernelForDifferentVecSize( // get the broadcast config, // if data shape is[m, n], then you should set data_dim = {n, m} // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + // if (ins[i]->numel() != (*outs)[0]->numel()) { if (ins[i]->numel()) { configs[i] = kps::details::BroadcastConfig( merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); diff --git a/paddle/phi/tests/kernels/CMakeLists.txt b/paddle/phi/tests/kernels/CMakeLists.txt index d1c9d25483fec117ec614f90e46a7d1c40539cee..09349ef782bba0fa352f43e33b4c66a87ce1d290 100644 --- a/paddle/phi/tests/kernels/CMakeLists.txt +++ b/paddle/phi/tests/kernels/CMakeLists.txt @@ -84,6 +84,10 @@ if(WITH_GPU) test_math_function_gpu SRCS test_math_function.cu DEPS math_function) + nv_test( + test_broadcast_gpu + SRCS test_ternary_broadcast.cu + DEPS gtest) endif() if(WITH_ROCM) hip_test( diff --git a/paddle/phi/tests/kernels/test_ternary_broadcast.cu b/paddle/phi/tests/kernels/test_ternary_broadcast.cu new file mode 100644 index 0000000000000000000000000000000000000000..5a6338d889b450cc77d9a271e7fedb8c6f3f7776 --- /dev/null +++ b/paddle/phi/tests/kernels/test_ternary_broadcast.cu @@ -0,0 +1,176 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "glog/logging.h" +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" + +template +struct AddTernary_1 { + inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; } +}; + +template +struct AddTernary_2 { + inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; } +}; + +template +struct AddTernary_3 { + inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; } +}; + +template +void InitValue(T* data, size_t numel, const int val) { + for (auto i = 0; i < numel; ++i) { + data[i] = static_cast(val); + } +} + +template +void TestCase(const phi::GPUContext& dev_ctx, + const phi::DDim& dim1, + const phi::DDim& dim2, + const phi::DDim& dim3, + const phi::DDim& dim_out, + const size_t times, + Func compute) { + phi::DataType dtype = paddle::experimental::CppTypeToDataType::Type(); + const auto alloc_cpu = + std::make_unique( + paddle::platform::CPUPlace()); + const auto alloc_gpu = + std::make_unique( + paddle::platform::CUDAPlace()); + + auto in1 = std::make_shared( + alloc_cpu.get(), + phi::DenseTensorMeta(dtype, dim1, phi::DataLayout::NCHW)); + auto in2 = std::make_shared( + alloc_cpu.get(), + phi::DenseTensorMeta(dtype, dim2, phi::DataLayout::NCHW)); + auto in3 = std::make_shared( + alloc_cpu.get(), + phi::DenseTensorMeta(dtype, dim3, phi::DataLayout::NCHW)); + InitValue(in1->data(), in1->numel(), 1); + InitValue(in2->data(), in2->numel(), 1); + InitValue(in3->data(), in3->numel(), 1); + + auto d_in1 = std::make_shared( + alloc_gpu.get(), + phi::DenseTensorMeta(dtype, dim1, phi::DataLayout::NCHW)); + auto d_in2 = std::make_shared( + alloc_gpu.get(), + phi::DenseTensorMeta(dtype, dim2, phi::DataLayout::NCHW)); + auto d_in3 = std::make_shared( + alloc_gpu.get(), + phi::DenseTensorMeta(dtype, dim3, phi::DataLayout::NCHW)); + auto d_out = std::make_shared( + alloc_gpu.get(), + phi::DenseTensorMeta(dtype, dim_out, phi::DataLayout::NCHW)); + phi::Copy(dev_ctx, *in1.get(), phi::GPUPlace(), false, d_in1.get()); + phi::Copy(dev_ctx, *in2.get(), phi::GPUPlace(), false, d_in2.get()); + phi::Copy(dev_ctx, *in3.get(), phi::GPUPlace(), false, d_in3.get()); + + std::vector inputs{ + d_in1.get(), d_in2.get(), d_in3.get()}; + std::vector outputs{d_out.get()}; + for (int i = 0; i < times; ++i) { + phi::funcs::BroadcastKernel( + dev_ctx, inputs, &outputs, -1, compute); + } + dev_ctx.Wait(); +} + +TEST(Broadcast, add) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + auto place = paddle::platform::CUDAPlace(); + phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance(); + auto* dev_ctx = static_cast(pool.GetByPlace(place)); + size_t times = 10; + + do { + auto dim1 = phi::make_ddim({1, 2048, 3584}); + auto dim2 = phi::make_ddim({1, 2048, 1}); + auto dim3 = phi::make_ddim({1, 1, 3584}); + auto dim_out = phi::make_ddim({1, 2048, 3584}); + TestCase( + *dev_ctx, dim1, dim2, dim3, dim_out, times, AddTernary_1()); + TestCase(*dev_ctx, + dim1, + dim2, + dim3, + dim_out, + times, + AddTernary_1()); + TestCase(*dev_ctx, + dim1, + dim2, + dim3, + dim_out, + times, + AddTernary_1()); + } while (0); + + do { + auto dim1 = phi::make_ddim({1, 256, 4, 256, 256}); + auto dim2 = phi::make_ddim({1, 256, 1, 1, 256}); + auto dim3 = phi::make_ddim({1, 1, 4, 256, 256}); + auto dim_out = phi::make_ddim({1, 256, 4, 256, 256}); + TestCase( + *dev_ctx, dim1, dim2, dim3, dim_out, times, AddTernary_2()); + TestCase(*dev_ctx, + dim1, + dim2, + dim3, + dim_out, + times, + AddTernary_2()); + TestCase(*dev_ctx, + dim1, + dim2, + dim3, + dim_out, + times, + AddTernary_2()); + } while (0); + + do { + auto dim1 = phi::make_ddim({1, 256, 256}); + auto dim2 = phi::make_ddim({1, 1, 256}); + auto dim3 = phi::make_ddim({1, 256, 1}); + auto dim_out = phi::make_ddim({1, 256, 256}); + TestCase( + *dev_ctx, dim1, dim2, dim3, dim_out, times, AddTernary_3()); + TestCase(*dev_ctx, + dim1, + dim2, + dim3, + dim_out, + times, + AddTernary_3()); + TestCase(*dev_ctx, + dim1, + dim2, + dim3, + dim_out, + times, + AddTernary_3()); + } while (0); +#endif +}