未验证 提交 057ba778 编写于 作者: L limingshu 提交者: GitHub

H2D data transfer optimization for split kernel (#49086)

* profile reduce kernel for fp16 and reduceHigherdim

* use reinterpret_cast

* fix for CI on ROCm

* add Macro for ROCm

* ROCm CI config

* ROCm CI config

* unit test repair

* pull

* add common_funcs.h

* reduceType

* Update reduce_function.h

* not higher

* rename

* implement of matmul using cublasLt instead of cublas

* cublasLt bugfix

* Update matmul_kernel_impl.h

* Update matmul_kernel_impl_via_blasLt.h

* for-loop-algo

* PR comments changes

* add macro

* ci unused variable isCublasLt

* ci unused variable isCublasLt macro

* split matmul to autotune

* rewrite the split kernel with segmented_array

* rewrite the split kernel with segmented_array

* rewrite the split kernel with segmented_array

* add some method for cuda_graph

* fix bugs for rocm

* change for ci-error

* i dont know why ci-model-benchmark gives a shit error, so i recover codes with original one to see if original codes work.

* add some changes for passing mode_benchmark and coverage ci

* fix ci error

* fix ci-rocm error

* add some changes for header

---------
Co-authored-by: Nzhangbopd <1299246947@qq.com>
Co-authored-by: NBo Zhang <105368690+zhangbopd@users.noreply.github.com>
上级 dc1b6511
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -45,6 +45,12 @@ static inline void GetBlockDims(const phi::GPUContext& context, ...@@ -45,6 +45,12 @@ static inline void GetBlockDims(const phi::GPUContext& context,
*grid_dims = dim3(grid_cols, grid_rows, 1); *grid_dims = dim3(grid_cols, grid_rows, 1);
} }
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x)
#endif
template <typename T, int Size> template <typename T, int Size>
struct PointerWrapper { struct PointerWrapper {
public: public:
...@@ -55,12 +61,29 @@ struct PointerWrapper { ...@@ -55,12 +61,29 @@ struct PointerWrapper {
PointerWrapper(const phi::GPUContext& ctx, PointerWrapper(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins, const std::vector<phi::DenseTensor>& ins,
const T** pre_alloced_host_ptr) { const T** pre_alloced_host_ptr) {
SetInputAddr(ins);
}
protected:
void SetInputAddr(const std::vector<phi::DenseTensor>& ins) {
for (auto i = 0; i < ins.size(); ++i) { for (auto i = 0; i < ins.size(); ++i) {
ins_addr[i] = ins[i].data(); ins_addr[i] = ins[i].data();
} }
} }
}; };
template <typename T, int Size>
struct PADDLE_ALIGN(256) AlignedPointerWrapper
: public PointerWrapper<T, Size> {
public:
AlignedPointerWrapper() {}
AlignedPointerWrapper(const phi::GPUContext& ctx,
const std::vector<phi::DenseTensor>& ins,
const T** pre_alloced_host_ptr) {
this->SetInputAddr(ins);
}
};
template <typename T> template <typename T>
struct PointerToPointer { struct PointerToPointer {
public: public:
...@@ -93,7 +116,7 @@ struct PointerToPointer { ...@@ -93,7 +116,7 @@ struct PointerToPointer {
}; };
template <typename T, typename IndexT, int Size> template <typename T, typename IndexT, int Size>
struct PointerAndColWrapper { struct PADDLE_ALIGN(256) PointerAndColWrapper {
public: public:
IndexT col_length[Size]; IndexT col_length[Size];
PointerAndColWrapper(const phi::GPUContext& ctx, PointerAndColWrapper(const phi::GPUContext& ctx,
...@@ -151,6 +174,8 @@ struct PointerToPointerAndCol { ...@@ -151,6 +174,8 @@ struct PointerToPointerAndCol {
PointerToPointer<T> ins_ptr_wrapper; PointerToPointer<T> ins_ptr_wrapper;
}; };
#undef PADDLE_ALIGN
template <int MovSize> template <int MovSize>
struct alignas(MovSize) Packed { struct alignas(MovSize) Packed {
__device__ Packed() { __device__ Packed() {
...@@ -360,7 +385,7 @@ void DispatchConcatWithSameShapeKernelLimitNum( ...@@ -360,7 +385,7 @@ void DispatchConcatWithSameShapeKernelLimitNum(
#define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...) \ #define IMPL_CONCAT_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \ case size_: { \
PointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \ AlignedPointerWrapper<T, size_> ptr_array(ctx, ins, inputs_data); \
__VA_ARGS__; \ __VA_ARGS__; \
} break; } break;
...@@ -519,108 +544,6 @@ void DispatchConcatKernel(const phi::GPUContext& ctx, ...@@ -519,108 +544,6 @@ void DispatchConcatKernel(const phi::GPUContext& ctx,
} }
} }
template <typename T>
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t* out_cols,
int out_cols_size,
T** outputs_data) {
int64_t curr_segment = 0;
int64_t curr_offset = out_cols[0];
CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
int64_t curr_col_offset = out_cols[curr_segment + 1];
while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
curr_col_offset = out_cols[curr_segment + 1];
}
int64_t local_col = tid_x - curr_offset;
int64_t segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs_data[curr_segment];
if (output_ptr != nullptr) {
int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input_data[tid_y * in_col + tid_x];
}
}
}
template <typename T>
__device__ void SplitKernelDetail(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T** outputs_data) {
CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
int64_t split = tid_x / fixed_out_col;
int64_t in_offset = tid_x - split * fixed_out_col;
T* output_ptr = outputs_data[split];
if (output_ptr != nullptr) {
int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * fixed_out_col + in_offset] =
input_data[tid_y * in_col + tid_x];
}
}
}
template <typename T>
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T** outputs_data) {
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
template <typename T>
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1) {
T* outputs_data[2];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
template <typename T>
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1,
T* outputs_addr2) {
T* outputs_data[3];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
outputs_data[2] = outputs_addr2;
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
template <typename T>
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1,
T* outputs_addr2,
T* outputs_addr3) {
T* outputs_data[4];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
outputs_data[2] = outputs_addr2;
outputs_data[3] = outputs_addr3;
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
/* /*
* All tensors' dimension should be the same and the values of * All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension. * each dimension must be the same, except the axis dimension.
...@@ -708,37 +631,152 @@ struct ConcatFunctor<phi::GPUContext, T> { ...@@ -708,37 +631,152 @@ struct ConcatFunctor<phi::GPUContext, T> {
} }
}; };
template <typename T> template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
class SplitFunctor<phi::GPUContext, T> { struct PointerAndColArray
: public funcs::PointerArraySetter<phi::GPUContext, T, Size> {
public: public:
void operator()(const phi::GPUContext& context, funcs::ValueArray<IndexT, Size> val_array;
const phi::DenseTensor& input,
const std::vector<const phi::DenseTensor*>& ref_inputs, PointerAndColArray() {}
int axis, PointerAndColArray(const phi::GPUContext& ctx,
std::vector<phi::DenseTensor*>* outputs) { const int out_col_num,
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3 IndexT* out_cols,
// tensors of shape [0,1,4] std::vector<DenseTensor*>* t,
if (input.numel() == 0) { T** pre_alloc_host_buf = nullptr)
return; : funcs::PointerArraySetter<phi::GPUContext, T, Size>(
ctx,
t,
/*need_alloc=*/false,
/*use_cuda_graph=*/true,
pre_alloc_host_buf) {
IndexT* dev_ptr = nullptr;
if (Size == SegmentedArraySize::kVariableLength) {
size_t num_bytes = out_col_num * sizeof(IndexT);
dev_ptr = reinterpret_cast<IndexT*>(this->AllocAndCopy(
ctx, reinterpret_cast<void*>(out_cols), num_bytes, true));
val_array.Set(dev_ptr, out_col_num);
} else {
val_array.Set(out_cols, out_col_num);
}
}
};
template <typename T, typename IndexT, typename DataArrayT>
__global__ void SplitTensorWithSameShape(const T* input_data,
const IndexT out_row,
const IndexT cumulative_col,
const IndexT fixed_out_col,
DataArrayT data_array) {
CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
IndexT split = tid_x / fixed_out_col;
IndexT in_offset = tid_x - split * fixed_out_col;
T* output_ptr = data_array.data[split];
if (output_ptr != nullptr) {
IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < out_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * fixed_out_col + in_offset] =
input_data[tid_y * cumulative_col + tid_x];
}
}
}
template <typename T, typename IndexT, typename DataArrayT, typename ValArrayT>
__global__ void SplitTensorWithDifferentShape(const T* input_data,
const IndexT out_row,
const IndexT cumulative_col,
DataArrayT data_array,
ValArrayT col_array) {
IndexT curr_segment = 0;
IndexT curr_offset = col_array.data[0];
CUDA_KERNEL_LOOP_TYPE(tid_x, cumulative_col, IndexT) {
IndexT curr_col_offset = col_array.data[curr_segment + 1];
while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
curr_col_offset = col_array.data[curr_segment + 1];
}
IndexT local_col = tid_x - curr_offset;
IndexT segment_width = curr_col_offset - curr_offset;
T* output_ptr = data_array.data[curr_segment];
if (output_ptr != nullptr) {
IndexT tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < out_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input_data[tid_y * cumulative_col + tid_x];
}
} }
}
template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
void SplitFunctionDispatchWithSameShape(const phi::GPUContext& ctx,
const IndexT out_col,
const IndexT out_row,
const IndexT cumulative_col,
const T* input_data,
std::vector<phi::DenseTensor*>* outs,
T** pre_alloc_host_buf) {
dim3 grid_dims;
dim3 block_dims;
GetBlockDims(ctx, out_row, cumulative_col, &block_dims, &grid_dims);
funcs::PointerArraySetter<phi::GPUContext, T, Size> setter(
ctx,
outs,
/*need_alloc=*/false,
/*use_cuda_graph=*/true,
pre_alloc_host_buf);
SplitTensorWithSameShape<T, IndexT, decltype(setter.array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
input_data, out_row, cumulative_col, out_col, setter.array);
}
template <typename T, typename IndexT, funcs::SegmentedArraySize Size>
void SplitFunctionDispatchWithDifferentShape(
const phi::GPUContext& ctx,
const int out_col_num,
const IndexT out_row,
const IndexT cumulative_col,
const T* input_data,
std::vector<phi::DenseTensor*>* outs,
IndexT* output_cols,
T** pre_alloc_host_buf) {
dim3 grid_dims;
dim3 block_dims;
GetBlockDims(ctx, out_row, cumulative_col, &block_dims, &grid_dims);
PointerAndColArray<T, IndexT, Size> setter(
ctx, out_col_num, output_cols, outs, pre_alloc_host_buf);
SplitTensorWithDifferentShape<T,
IndexT,
decltype(setter.array),
decltype(setter.val_array)>
<<<grid_dims, block_dims, 0, ctx.stream()>>>(
input_data, out_row, cumulative_col, setter.array, setter.val_array);
}
template <typename T, typename IndexT>
void SplitFunctorDispatchWithIndexType(
const phi::GPUContext& ctx,
int axis,
const phi::DenseTensor& input,
const std::vector<const phi::DenseTensor*>& ref_ins,
std::vector<phi::DenseTensor*>* outs) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int o_num = outputs->size(); int out_num = outs->size();
int64_t out_row = 1; IndexT out_row = 1;
auto dim_0 = ref_inputs[0]->dims(); auto ref_dim = ref_ins[0]->dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
out_row *= dim_0[i]; out_row *= ref_dim[i];
} }
IndexT out_col = ref_ins[0]->numel() / out_row;
int64_t out0_col = ref_inputs[0]->numel() / out_row; IndexT cumulative_col = 0;
int64_t in_col = 0, in_row = out_row;
bool has_same_shape = true; bool has_same_shape = true;
int outputs_cols_num = o_num + 1; int out_cols_num = out_num + 1;
std::vector<T*> outputs_data_vec(o_num); std::vector<IndexT> outputs_cols_vec(out_cols_num, 0);
std::vector<int64_t> outputs_cols_vec(outputs_cols_num); IndexT* outs_cols = outputs_cols_vec.data();
T** outputs_data = outputs_data_vec.data(); T** outs_data = nullptr;
int64_t* outputs_cols = outputs_cols_vec.data();
// There are some differences between hip runtime and NV runtime. // There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from // In NV, when the pageable memory data less than 64K is transferred from
...@@ -751,125 +789,87 @@ class SplitFunctor<phi::GPUContext, T> { ...@@ -751,125 +789,87 @@ class SplitFunctor<phi::GPUContext, T> {
paddle::memory::AllocationPtr data_alloc, cols_alloc; paddle::memory::AllocationPtr data_alloc, cols_alloc;
// TODO(chentianyu03): try to find a method to remove the Alloc function // TODO(chentianyu03): try to find a method to remove the Alloc function
data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(), data_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
o_num * sizeof(T*)); out_num * sizeof(T*));
outputs_data = reinterpret_cast<T**>(data_alloc->ptr()); outs_data = reinterpret_cast<T**>(data_alloc->ptr());
// TODO(chentianyu03): try to find a method to remove the Alloc function // TODO(chentianyu03): try to find a method to remove the Alloc function
cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(), cols_alloc = paddle::memory::Alloc(paddle::platform::CUDAPinnedPlace(),
(outputs_cols_num) * sizeof(int64_t)); (out_cols_num) * sizeof(IndexT));
outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr()); outs_cols = reinterpret_cast<IndexT*>(cols_alloc->ptr());
#endif #endif
outputs_cols[0] = 0; outs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) { for (int i = 0; i < out_num; ++i) {
int64_t t_col = ref_inputs.at(i)->numel() / out_row; IndexT t_col = ref_ins.at(i)->numel() / out_row;
if (has_same_shape) { if (has_same_shape) {
if (t_col != out0_col) has_same_shape = false; has_same_shape &= (t_col == cumulative_col);
}
in_col += t_col;
outputs_cols[i + 1] = in_col;
if (outputs->at(i) != nullptr) {
outputs_data[i] = outputs->at(i)->data<T>();
} else {
outputs_data[i] = nullptr;
}
} }
cumulative_col += t_col;
dim3 block_dims; outs_cols[i + 1] = cumulative_col;
dim3 grid_dims;
GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);
paddle::memory::allocation::AllocationPtr tmp_dev_outs_data;
T** dev_out_gpu_data = nullptr;
if (!has_same_shape || o_num < 2 || o_num > 4) {
// TODO(chentianyu03): try to find a method to remove the Alloc function
tmp_dev_outs_data = paddle::memory::Alloc(
context.GetPlace(),
o_num * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(context.stream())));
auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
outputs_data, o_num);
paddle::memory::Copy(context.GetPlace(),
tmp_dev_outs_data->ptr(),
phi::CPUPlace(),
restored,
o_num * sizeof(T*),
context.stream());
dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
} }
int limit_num = has_same_shape ? out_num : out_cols_num;
if (has_same_shape) { if (has_same_shape) {
if (o_num == 2) { switch (funcs::CalcArraySize(limit_num)) {
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>( SEGMENTED_ARRAY_KERNEL_HELPER(
input.data<T>(), SplitFunctionDispatchWithSameShape<T, IndexT, kArraySize>(
in_row, ctx,
in_col, out_col,
out0_col, out_row,
outputs_data[0], cumulative_col,
outputs_data[1]);
} else if (o_num == 3) {
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
out0_col,
outputs_data[0],
outputs_data[1],
outputs_data[2]);
} else if (o_num == 4) {
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), input.data<T>(),
in_row, outs,
in_col, outs_data));
out0_col,
outputs_data[0],
outputs_data[1],
outputs_data[2],
outputs_data[3]);
} else {
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
} }
} else { } else {
auto tmp_dev_ins_col_data = switch (funcs::CalcArraySize(limit_num)) {
// TODO(chentianyu03): try to find a method to remove the Alloc SEGMENTED_ARRAY_KERNEL_HELPER(
// function SplitFunctionDispatchWithDifferentShape<T, IndexT, kArraySize>(
paddle::memory::Alloc( ctx,
context.GetPlace(), out_cols_num,
outputs_cols_num * sizeof(int64_t), out_row,
phi::Stream(reinterpret_cast<phi::StreamId>(context.stream()))); cumulative_col,
auto* restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph(
outputs_cols, outputs_cols_num);
paddle::memory::Copy(context.GetPlace(),
tmp_dev_ins_col_data->ptr(),
phi::CPUPlace(),
restored,
outputs_cols_num * sizeof(int64_t),
context.stream());
int64_t* dev_outs_col_data =
reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), input.data<T>(),
in_row, outs,
in_col, outs_cols,
dev_outs_col_data, outs_data));
static_cast<int>(outputs_cols_num), }
dev_out_gpu_data);
} }
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// Prevent the pinned memory value from being covered and release the memory // Prevent pinned memory from being covered and release the memory after
// after the launch kernel of the stream is executed (reapply pinned memory // kernel launch of the stream is executed (reapply pinned memory next time)
// next time)
auto* data_alloc_released = data_alloc.release(); auto* data_alloc_released = data_alloc.release();
auto* cols_alloc_released = cols_alloc.release(); auto* cols_alloc_released = cols_alloc.release();
context.AddStreamCallback([data_alloc_released, cols_alloc_released] { ctx.AddStreamCallback([data_alloc_released, cols_alloc_released] {
paddle::memory::allocation::Allocator::AllocationDeleter( paddle::memory::allocation::Allocator::AllocationDeleter(
data_alloc_released); data_alloc_released);
paddle::memory::allocation::Allocator::AllocationDeleter( paddle::memory::allocation::Allocator::AllocationDeleter(
cols_alloc_released); cols_alloc_released);
}); });
#endif #endif
}
template <typename T>
class SplitFunctor<phi::GPUContext, T> {
public:
void operator()(const phi::GPUContext& context,
const phi::DenseTensor& input,
const std::vector<const phi::DenseTensor*>& ref_inputs,
int axis,
std::vector<phi::DenseTensor*>* outputs) {
int64_t numel = input.numel();
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in
// 3 tensors of shape [0,1,4]
if (input.numel() == 0) {
return;
}
if (numel < std::numeric_limits<int32_t>::max()) {
SplitFunctorDispatchWithIndexType<T, int32_t>(
context, axis, input, ref_inputs, outputs);
} else {
SplitFunctorDispatchWithIndexType<T, int64_t>(
context, axis, input, ref_inputs, outputs);
}
} }
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
namespace phi { namespace phi {
...@@ -34,6 +35,26 @@ enum class SegmentedArraySize { ...@@ -34,6 +35,26 @@ enum class SegmentedArraySize {
kFixed64 = 64, kFixed64 = 64,
}; };
template <typename T, SegmentedArraySize Size, int Num = static_cast<int>(Size)>
struct PADDLE_ALIGN(256) ValueArray {
public:
T data[Num];
void Set(T* ptr, const int num) {
for (auto i = 0; i < num; ++i) {
data[i] = ptr[i];
}
}
};
template <typename T>
struct PADDLE_ALIGN(256) ValueArray<T, SegmentedArraySize::kVariableLength, 0> {
public:
T* data{nullptr};
void Set(T* ptr, const int num) { data = ptr; }
};
template <typename T, SegmentedArraySize Size> template <typename T, SegmentedArraySize Size>
struct PADDLE_ALIGN(256) ConstPointerArray { struct PADDLE_ALIGN(256) ConstPointerArray {
public: public:
...@@ -62,8 +83,8 @@ struct PADDLE_ALIGN(256) PointerArray { ...@@ -62,8 +83,8 @@ struct PADDLE_ALIGN(256) PointerArray {
public: public:
T* data[static_cast<int>(Size)]; T* data[static_cast<int>(Size)];
void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) { void Set(T** ptrs, const int num, T** dev_ptr = nullptr) {
for (auto i = 0; i < ptrs.size(); ++i) { for (auto i = 0; i < num; ++i) {
data[i] = ptrs[i]; data[i] = ptrs[i];
} }
} }
...@@ -74,9 +95,7 @@ struct PADDLE_ALIGN(256) PointerArray<T, SegmentedArraySize::kVariableLength> { ...@@ -74,9 +95,7 @@ struct PADDLE_ALIGN(256) PointerArray<T, SegmentedArraySize::kVariableLength> {
public: public:
T** data{nullptr}; T** data{nullptr};
void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) { void Set(T** ptrs, const int num, T** dev_ptr = nullptr) { data = dev_ptr; }
data = dev_ptr;
}
}; };
#undef PADDLE_ALIGN #undef PADDLE_ALIGN
...@@ -84,13 +103,24 @@ struct PADDLE_ALIGN(256) PointerArray<T, SegmentedArraySize::kVariableLength> { ...@@ -84,13 +103,24 @@ struct PADDLE_ALIGN(256) PointerArray<T, SegmentedArraySize::kVariableLength> {
template <typename Context> template <typename Context>
struct ArraySetterBase { struct ArraySetterBase {
protected: protected:
void* AllocAndCopy(const Context& ctx, void* src, size_t num_bytes) { void* AllocAndCopy(const Context& ctx,
void* src,
size_t num_bytes,
bool use_cuda_graph = false) {
allocation = paddle::memory::Alloc( allocation = paddle::memory::Alloc(
ctx.GetPlace(), ctx.GetPlace(),
num_bytes, num_bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
int8_t* restored = reinterpret_cast<int8_t*>(src);
#ifdef PADDLE_WITH_CUDA
if (use_cuda_graph) {
restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph<int8_t>(
restored, num_bytes);
}
#endif
phi::backends::gpu::GpuMemcpyAsync(allocation->ptr(), phi::backends::gpu::GpuMemcpyAsync(allocation->ptr(),
src, restored,
num_bytes, num_bytes,
phi::gpuMemcpyHostToDevice, phi::gpuMemcpyHostToDevice,
ctx.stream()); ctx.stream());
...@@ -131,13 +161,28 @@ struct PointerArraySetter : public ArraySetterBase<Context> { ...@@ -131,13 +161,28 @@ struct PointerArraySetter : public ArraySetterBase<Context> {
public: public:
PointerArray<T, Size> array; PointerArray<T, Size> array;
PointerArraySetter(const Context& ctx, std::vector<DenseTensor*>* t) { // need_alloc : tensor data needs extra buffer or not.
// use_cuda_graph: tensor data shall be captured by cuda_graph or not.
// pre_alloc_host_buf: tensor data is temporaily stored by pinned memory or
// not.
PointerArraySetter(const Context& ctx,
std::vector<DenseTensor*>* t,
bool need_alloc = false,
bool use_cuda_graph = false,
T** pre_alloc_host_buf = nullptr) {
ptrs.resize(t->size()); ptrs.resize(t->size());
T** data_ptr = ptrs.data();
#ifdef PADDLE_WITH_HIP
if (pre_alloc_host_buf) {
data_ptr = pre_alloc_host_buf;
}
#endif
for (int i = 0; i < t->size(); ++i) { for (int i = 0; i < t->size(); ++i) {
if (t->at(i) && (t->at(i)->numel() > 0)) { if (t->at(i) && (t->at(i)->numel() > 0)) {
ptrs[i] = ctx.template Alloc<T>(t->at(i)); data_ptr[i] =
need_alloc ? ctx.template Alloc<T>(t->at(i)) : t->at(i)->data<T>();
} else { } else {
ptrs[i] = nullptr; data_ptr[i] = nullptr;
} }
} }
...@@ -145,10 +190,9 @@ struct PointerArraySetter : public ArraySetterBase<Context> { ...@@ -145,10 +190,9 @@ struct PointerArraySetter : public ArraySetterBase<Context> {
if (Size == SegmentedArraySize::kVariableLength) { if (Size == SegmentedArraySize::kVariableLength) {
size_t num_bytes = t->size() * sizeof(T*); size_t num_bytes = t->size() * sizeof(T*);
dev_ptr = reinterpret_cast<T**>(this->AllocAndCopy( dev_ptr = reinterpret_cast<T**>(this->AllocAndCopy(
ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes)); ctx, reinterpret_cast<void*>(data_ptr), num_bytes, use_cuda_graph));
} }
array.Set(data_ptr, t->size(), dev_ptr);
array.Set(ptrs, dev_ptr);
} }
private: private:
......
...@@ -192,7 +192,7 @@ void LaunchUnStackKernel(const Context& ctx, ...@@ -192,7 +192,7 @@ void LaunchUnStackKernel(const Context& ctx,
<< ", out_col=" << out_col << ", num_splits=" << num_splits; << ", out_col=" << out_col << ", num_splits=" << num_splits;
auto x_ptr = x.data<T>(); auto x_ptr = x.data<T>();
PointerArraySetter<Context, T, Size> setter(ctx, outs); PointerArraySetter<Context, T, Size> setter(ctx, outs, /*need_alloc=*/true);
if (out_col == 1) { if (out_col == 1) {
// For the case axis == (x.dims().size() - 1) // For the case axis == (x.dims().size() - 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册