Unverified commit 46e4fb2a authored by limingshu, committed by GitHub

Performance fix for broadcast kernel [Part3] (#46071)

* first commit

* refine code with template argument

* refine code with template argument

* add ternary broadcast test file

* add ternary broadcast test file

* fix according to ci

* fix op-benchmark ci error
Parent 066633a5
......@@ -254,26 +254,85 @@ int GetVecsize(const std::vector<const DenseTensor *> &ins,
return std::min(out_vec_size, in_vec_size);
}
#ifndef PADDLE_WITH_XPU_KP
template <typename T,
          int VecSize,
          int Arity,
          bool IsBoundary,
          bool is_all_broadcast>
struct BroadcastDataLoader {
  __device__ __forceinline__ void operator()(
      T args[Arity][VecSize],
      const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
      const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
      const phi::Array<int, Arity> &use_broadcast,
      const int block_offset,
      const int num,
      const uint32_t numel) {
#pragma unroll
    for (int i = 0; i < Arity; ++i) {
      kps::Init<T, VecSize>(args[i], static_cast<T>(1.0f));
      if (use_broadcast[i]) {
        kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
            args[i], ins[i], block_offset, configs[i], numel, VecSize);
      } else {
        kps::ReadData<T, VecSize, 1, IsBoundary>(
            args[i], ins[i] + block_offset, num, VecSize);
      }
    }
  }
};
template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, true> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
uint32_t index_bc[Arity][VecSize];
#pragma unroll
for (int j = 0; j < Arity; ++j) {
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
index_bc[j][k] = 0;
args[j][k] = static_cast<T>(1);
}
}
    uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
#pragma unroll
    for (int k = 0; k < VecSize; ++k) {
      uint32_t idx = thread_offset + k;
      if (IsBoundary) {
        if (idx == numel) break;
      }
#pragma unroll
      for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
        if (i == configs[0].kDims) break;
        auto fast_divmoder = configs[0].divmoders[i].Divmod(idx);
        idx = fast_divmoder.val[0];
#pragma unroll
        for (int j = 0; j < Arity; ++j) {
          index_bc[j][k] += fast_divmoder.val[1] * configs[j].strides[i];
        }
      }
    }
#pragma unroll
for (int j = 0; j < Arity; ++j) {
#pragma unroll
for (int k = 0; k < VecSize; ++k) {
args[j][k] = ins[j][index_bc[j][k]];
}
}
}
};
#endif
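// Editorial note (illustrative, not part of this patch): the specialization above
// decodes each flat output index once with the fast divmods of configs[0] and then
// turns the per-dimension remainders into a load offset for every input via
// remainder * configs[j].strides[i], instead of repeating the decode per input in
// kps::ReadDataBc. A minimal standalone sketch of that index arithmetic, using
// hypothetical names (dims, in_strides) and plain division in place of FastDivMod:
inline unsigned int BroadcastOffsetSketch(unsigned int out_idx,
                                          const unsigned int *dims,        // output extent per dim
                                          const unsigned int *in_strides,  // this input's stride per dim, 0 where broadcast
                                          int rank) {
  unsigned int offset = 0;
  for (int i = 0; i < rank; ++i) {
    unsigned int coord = out_idx % dims[i];  // coordinate along dimension i
    out_idx /= dims[i];
    offset += coord * in_strides[i];  // size-1 (broadcast) dims contribute 0
  }
  return offset;
}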
template <typename InT,
typename OutT,
......@@ -281,12 +340,13 @@ template <typename InT,
int Arity,
int NumOuts,
int VecSize,
          bool IsBoundary,
          bool IsAllBroadcast = false>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
    const uint32_t numel,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
......@@ -294,19 +354,23 @@ __device__ void VectorizedBroadcastKernelImpl(
Functor func) {
__simd__ InT args[Arity][VecSize];
__simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#ifdef PADDLE_WITH_XPU_KP
#pragma unroll
for (int i = 0; i < Arity; ++i) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
    if (use_broadcast[i]) {
      kps::ReadDataBc<InT, VecSize, 1, IsBoundary>(
          args[i], ins[i], block_offset, configs[i], numel, read_lens);
    } else {
      kps::ReadData<InT, VecSize, 1, IsBoundary>(
          args[i], ins[i] + block_offset, num, read_lens);
    }
}
#else
BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, IsAllBroadcast>()(
args, ins, configs, use_broadcast, block_offset, num, numel);
#endif
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
phi::funcs::ElementwisePrimitiveCaller<InT,
......@@ -321,12 +385,13 @@ __device__ void VectorizedBroadcastKernelImpl(
outs, result, block_offset, num, read_lens);
}
template <typename Functor,
          typename InT,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsAllBroadcast>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
......@@ -337,10 +402,9 @@ __global__ void VectorizedBroadcastKernel(
int tail_tid,
int read_lens,
Functor func) {
#ifdef PADDLE_WITH_XPU_KP
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
for (; block_offset < main_offset; block_offset += stride) {
VectorizedBroadcastKernelImpl<InT,
OutT,
......@@ -348,7 +412,8 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
                                  false,
                                  IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
......@@ -366,7 +431,8 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
                                  true,
                                  IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
......@@ -377,6 +443,7 @@ __global__ void VectorizedBroadcastKernel(
func);
}
#else
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
if (block_offset < main_offset) {
VectorizedBroadcastKernelImpl<InT,
OutT,
......@@ -384,7 +451,8 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
                                  false,
                                  IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
......@@ -400,7 +468,8 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
                                  true,
                                  IsAllBroadcast>(ins,
outs,
use_broadcast,
numel,
......@@ -425,6 +494,7 @@ void LaunchBroadcastKernel(
std::vector<DenseTensor *> *outs,
Functor func,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs) {
int broadcast_num = 0;
int numel = (*outs)[0]->numel();
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
......@@ -435,7 +505,12 @@ void LaunchBroadcastKernel(
}
for (int i = 0; i < Arity; ++i) {
    if (ins[i]->numel() != numel) {
      broadcast_num++;
      use_broadcast[i] = true;
    } else {
      use_broadcast[i] = false;
    }
ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
}
......@@ -446,6 +521,17 @@ void LaunchBroadcastKernel(
auto stream = ctx.x_context()->xpu_stream;
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
int tail_tid = numel % (read_lens * threads);
VectorizedBroadcastKernel<Functor, InT, OutT, Arity, NumOuts, VecSize, false>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
#else
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
......@@ -456,8 +542,15 @@ void LaunchBroadcastKernel(
int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) *
read_lens * gpu_config.GetBlockSize();
int tail_tid = numel % (read_lens * gpu_config.GetBlockSize());
  if (broadcast_num > (Arity >> 1)) {
    VectorizedBroadcastKernel<Functor,
                              InT,
                              OutT,
                              Arity,
                              NumOuts,
                              VecSize,
                              (Arity > 1)>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
......@@ -467,6 +560,25 @@ void LaunchBroadcastKernel(
tail_tid,
read_lens,
func);
} else {
VectorizedBroadcastKernel<Functor,
InT,
OutT,
Arity,
NumOuts,
VecSize,
false>
<<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
func);
}
#endif
}
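// Editorial note (illustrative, not part of this patch): on the GPU path the
// launcher now counts how many inputs actually require broadcasting and only
// selects the index-sharing loader when more than half of them do
// (broadcast_num > Arity >> 1); otherwise the generic per-input loader is kept.
// The main_offset / tail_tid split sends whole read_lens * blockDim chunks through
// the non-boundary kernel body and the remainder through the boundary-checked one.
// Worked example with hypothetical numbers (numel = 3000, read_lens = 4, blockDim = 256):
static_assert((3000 / (4 * 256)) * (4 * 256) == 2048, "main_offset for the example");
static_assert(3000 % (4 * 256) == 952, "tail_tid for the example");
// so elements [0, 2048) take the IsBoundary=false path and [2048, 3000) the checked path.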
#ifndef PADDLE_WITH_XPU_KP
......@@ -985,6 +1097,7 @@ void BroadcastKernelForDifferentVecSize(
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
// if (ins[i]->numel() != (*outs)[0]->numel()) {
if (ins[i]->numel()) {
configs[i] = kps::details::BroadcastConfig(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
......
......@@ -84,6 +84,10 @@ if(WITH_GPU)
test_math_function_gpu
SRCS test_math_function.cu
DEPS math_function)
nv_test(
test_broadcast_gpu
SRCS test_ternary_broadcast.cu
DEPS gtest)
endif()
if(WITH_ROCM)
hip_test(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <vector>
#include "glog/logging.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
template <typename T>
struct AddTernary_1 {
inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
};
template <typename T>
struct AddTernary_2 {
inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
};
template <typename T>
struct AddTernary_3 {
inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
};
template <typename T>
void InitValue(T* data, size_t numel, const int val) {
for (auto i = 0; i < numel; ++i) {
data[i] = static_cast<T>(val);
}
}
template <typename T, typename Func>
void TestCase(const phi::GPUContext& dev_ctx,
const phi::DDim& dim1,
const phi::DDim& dim2,
const phi::DDim& dim3,
const phi::DDim& dim_out,
const size_t times,
Func compute) {
phi::DataType dtype = paddle::experimental::CppTypeToDataType<T>::Type();
const auto alloc_cpu =
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
const auto alloc_gpu =
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
auto in1 = std::make_shared<phi::DenseTensor>(
alloc_cpu.get(),
phi::DenseTensorMeta(dtype, dim1, phi::DataLayout::NCHW));
auto in2 = std::make_shared<phi::DenseTensor>(
alloc_cpu.get(),
phi::DenseTensorMeta(dtype, dim2, phi::DataLayout::NCHW));
auto in3 = std::make_shared<phi::DenseTensor>(
alloc_cpu.get(),
phi::DenseTensorMeta(dtype, dim3, phi::DataLayout::NCHW));
InitValue(in1->data<T>(), in1->numel(), 1);
InitValue(in2->data<T>(), in2->numel(), 1);
InitValue(in3->data<T>(), in3->numel(), 1);
auto d_in1 = std::make_shared<phi::DenseTensor>(
alloc_gpu.get(),
phi::DenseTensorMeta(dtype, dim1, phi::DataLayout::NCHW));
auto d_in2 = std::make_shared<phi::DenseTensor>(
alloc_gpu.get(),
phi::DenseTensorMeta(dtype, dim2, phi::DataLayout::NCHW));
auto d_in3 = std::make_shared<phi::DenseTensor>(
alloc_gpu.get(),
phi::DenseTensorMeta(dtype, dim3, phi::DataLayout::NCHW));
auto d_out = std::make_shared<phi::DenseTensor>(
alloc_gpu.get(),
phi::DenseTensorMeta(dtype, dim_out, phi::DataLayout::NCHW));
phi::Copy(dev_ctx, *in1.get(), phi::GPUPlace(), false, d_in1.get());
phi::Copy(dev_ctx, *in2.get(), phi::GPUPlace(), false, d_in2.get());
phi::Copy(dev_ctx, *in3.get(), phi::GPUPlace(), false, d_in3.get());
std::vector<const phi::DenseTensor*> inputs{
d_in1.get(), d_in2.get(), d_in3.get()};
std::vector<phi::DenseTensor*> outputs{d_out.get()};
for (int i = 0; i < times; ++i) {
phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
dev_ctx, inputs, &outputs, -1, compute);
}
dev_ctx.Wait();
}
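// Editorial note (illustrative, not part of this patch): TestCase above only
// launches the broadcast kernel `times` times and waits on the stream; it does not
// verify the output. A minimal correctness check that could be appended at the end
// of TestCase (a sketch reusing its local names alloc_cpu, d_out, dim_out, dtype
// and dev_ctx) might look like:
//   auto out_cpu = std::make_shared<phi::DenseTensor>(
//       alloc_cpu.get(),
//       phi::DenseTensorMeta(dtype, dim_out, phi::DataLayout::NCHW));
//   phi::Copy(dev_ctx, *d_out.get(), phi::CPUPlace(), true, out_cpu.get());
//   for (int64_t i = 0; i < out_cpu->numel(); ++i) {
//     // every input is filled with 1, so a + b + c should be 3 everywhere
//     EXPECT_EQ(out_cpu->data<T>()[i], static_cast<T>(3));
//   }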
TEST(Broadcast, add) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto place = paddle::platform::CUDAPlace();
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* dev_ctx = static_cast<const phi::GPUContext*>(pool.GetByPlace(place));
size_t times = 10;
do {
auto dim1 = phi::make_ddim({1, 2048, 3584});
auto dim2 = phi::make_ddim({1, 2048, 1});
auto dim3 = phi::make_ddim({1, 1, 3584});
auto dim_out = phi::make_ddim({1, 2048, 3584});
TestCase<float>(
*dev_ctx, dim1, dim2, dim3, dim_out, times, AddTernary_1<float>());
TestCase<phi::dtype::float16>(*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_1<phi::dtype::float16>());
TestCase<phi::dtype::bfloat16>(*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_1<phi::dtype::bfloat16>());
} while (0);
do {
auto dim1 = phi::make_ddim({1, 256, 4, 256, 256});
auto dim2 = phi::make_ddim({1, 256, 1, 1, 256});
auto dim3 = phi::make_ddim({1, 1, 4, 256, 256});
auto dim_out = phi::make_ddim({1, 256, 4, 256, 256});
TestCase<float>(
*dev_ctx, dim1, dim2, dim3, dim_out, times, AddTernary_2<float>());
TestCase<phi::dtype::float16>(*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_2<phi::dtype::float16>());
TestCase<phi::dtype::bfloat16>(*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_2<phi::dtype::bfloat16>());
} while (0);
do {
auto dim1 = phi::make_ddim({1, 256, 256});
auto dim2 = phi::make_ddim({1, 1, 256});
auto dim3 = phi::make_ddim({1, 256, 1});
auto dim_out = phi::make_ddim({1, 256, 256});
TestCase<float>(
*dev_ctx, dim1, dim2, dim3, dim_out, times, AddTernary_3<float>());
TestCase<phi::dtype::float16>(*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_3<phi::dtype::float16>());
TestCase<phi::dtype::bfloat16>(*dev_ctx,
dim1,
dim2,
dim3,
dim_out,
times,
AddTernary_3<phi::dtype::bfloat16>());
} while (0);
#endif
}