Unverified commit b1faa562, authored by Yiqun Liu, committed by GitHub

Implement a common segmented array. (#49450)

* Implement a common PointerArray.

* Polish code.

* Add include of the header file.

* Add the branch for kFix8.

* Fix compilation error.

* Add alignas hint to fix the performance drop.

* Optimize the H2D copy in stack_grad.

* Rename the macro.

* Fix the alignment hint for different compilers.

* Polish the definition of PADDLE_ALIGN.

* Fix compilation error.

* Remove the alignment hint on Windows.
Parent 24f5c46e

paddle/phi/kernels/funcs/segmented_array.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/fast_divmod.h"
namespace phi {
namespace funcs {
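// Computes (val / d, val % d). The primary template wraps FastDivMod for
// 32-bit indices; the int64_t specialization below falls back to plain
// division, since FastDivMod only handles 32-bit divisors.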
template <typename IndexT>
struct GeneralDivMod {
 public:
  explicit GeneralDivMod(IndexT d) { divmoder = phi::funcs::FastDivMod(d); }
  __device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
    return divmoder.Divmod(val);
  }

  phi::funcs::FastDivMod divmoder;
};

template <>
struct GeneralDivMod<int64_t> {
 public:
  using DivModT = phi::AlignedVector<int64_t, 2>;

  explicit GeneralDivMod(int64_t d) { divisor = d; }
  __device__ inline DivModT div_mod(int64_t val) {
    DivModT data;
    data[0] = val / divisor;
    data[1] = val - data[0] * divisor;
    return data;
  }

  int64_t divisor;
};
#if !defined(_WIN32)
#define PADDLE_ALIGN(x) __attribute__((aligned(x)))
#else
#define PADDLE_ALIGN(x)
#endif
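// Capacities supported by the fixed-size pointer arrays below. kVariableLength
// (0) means the pointers are kept in a separate device-side buffer instead.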
enum class SegmentedArraySize {
  kVariableLength = 0,
  kFixed4 = 4,
  kFixed8 = 8,
  kFixed16 = 16,
  kFixed32 = 32,
  kFixed64 = 64,
};
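// ConstPointerArray/PointerArray hold up to `Size` tensor data pointers
// directly in the struct, so the whole array can be passed by value as a
// kernel argument. The kVariableLength specializations store only a pointer
// to a device-side pointer table (dev_ptr), which the fixed-size versions
// ignore.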
template <typename T, SegmentedArraySize Size>
struct PADDLE_ALIGN(256) ConstPointerArray {
 public:
  const T* data[static_cast<int>(Size)];

  void Set(const std::vector<const T*>& ptrs, const T** dev_ptr = nullptr) {
    for (auto i = 0; i < ptrs.size(); ++i) {
      data[i] = ptrs[i];
    }
  }
};

template <typename T>
struct PADDLE_ALIGN(256)
    ConstPointerArray<T, SegmentedArraySize::kVariableLength> {
 public:
  const T** data{nullptr};

  void Set(const std::vector<const T*>& ptrs, const T** dev_ptr = nullptr) {
    data = dev_ptr;
  }
};

template <typename T, SegmentedArraySize Size>
struct PADDLE_ALIGN(256) PointerArray {
 public:
  T* data[static_cast<int>(Size)];

  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
    for (auto i = 0; i < ptrs.size(); ++i) {
      data[i] = ptrs[i];
    }
  }
};

template <typename T>
struct PADDLE_ALIGN(256) PointerArray<T, SegmentedArraySize::kVariableLength> {
 public:
  T** data{nullptr};

  void Set(const std::vector<T*>& ptrs, T** dev_ptr = nullptr) {
    data = dev_ptr;
  }
};
#undef PADDLE_ALIGN
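// Owns a temporary device allocation and asynchronously copies a host-side
// pointer table into it on the context's stream. Only needed for the
// kVariableLength case.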
template <typename Context>
struct ArraySetterBase {
 protected:
  void* AllocAndCopy(const Context& ctx, void* src, size_t num_bytes) {
    allocation = paddle::memory::Alloc(
        ctx.GetPlace(),
        num_bytes,
        phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
    paddle::memory::Copy(ctx.GetPlace(),
                         allocation->ptr(),
                         phi::CPUPlace(),
                         src,
                         num_bytes,
                         ctx.stream());
    return allocation->ptr();
  }

  phi::Allocator::AllocationPtr allocation{nullptr};
};
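// The setters gather the data pointers of a list of DenseTensors on the host
// and fill `array`; for kVariableLength the pointer table is additionally
// copied to device memory via AllocAndCopy. PointerArraySetter also allocates
// the output tensors and stores nullptr for absent or empty ones.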
template <typename Context, typename T, SegmentedArraySize Size>
struct ConstPointerArraySetter : public ArraySetterBase<Context> {
 public:
  ConstPointerArray<T, Size> array;

  ConstPointerArraySetter(const Context& ctx,
                          const std::vector<const DenseTensor*>& t) {
    ptrs.resize(t.size());
    for (int i = 0; i < t.size(); ++i) {
      ptrs[i] = t[i]->data<T>();
    }

    const T** dev_ptr = nullptr;
    if (Size == SegmentedArraySize::kVariableLength) {
      size_t num_bytes = t.size() * sizeof(T*);
      dev_ptr =
          reinterpret_cast<const T**>(ArraySetterBase<Context>::AllocAndCopy(
              ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
    }
    array.Set(ptrs, dev_ptr);
  }

 private:
  std::vector<const T*> ptrs;
};

template <typename Context, typename T, SegmentedArraySize Size>
struct PointerArraySetter : public ArraySetterBase<Context> {
 public:
  PointerArray<T, Size> array;

  PointerArraySetter(const Context& ctx, std::vector<DenseTensor*>* t) {
    ptrs.resize(t->size());
    for (int i = 0; i < t->size(); ++i) {
      if (t->at(i) && (t->at(i)->numel() > 0)) {
        ptrs[i] = ctx.template Alloc<T>(t->at(i));
      } else {
        ptrs[i] = nullptr;
      }
    }

    T** dev_ptr = nullptr;
    if (Size == SegmentedArraySize::kVariableLength) {
      size_t num_bytes = t->size() * sizeof(T*);
      dev_ptr = reinterpret_cast<T**>(ArraySetterBase<Context>::AllocAndCopy(
          ctx, reinterpret_cast<void*>(ptrs.data()), num_bytes));
    }
    array.Set(ptrs, dev_ptr);
  }

 private:
  std::vector<T*> ptrs;
};
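// Maps a runtime segment count to the smallest fixed capacity that can hold
// it, falling back to kVariableLength when n exceeds 64.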
inline SegmentedArraySize CalcArraySize(int n) {
  if (n <= 4) {
    return SegmentedArraySize::kFixed4;
  } else if (n <= 8) {
    return SegmentedArraySize::kFixed8;
  } else if (n <= 16) {
    return SegmentedArraySize::kFixed16;
  } else if (n <= 32) {
    return SegmentedArraySize::kFixed32;
  } else if (n <= 64) {
    return SegmentedArraySize::kFixed64;
  } else {
    return SegmentedArraySize::kVariableLength;
  }
}
} // namespace funcs
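// Dispatch helpers: each case binds the compile-time constant kArraySize to one
// SegmentedArraySize value and then expands the caller-supplied launch
// expression, so the capacity chosen at runtime by CalcArraySize selects a
// fixed-size template instantiation.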
#define _SEGMENTED_ARRAY_KERNEL_CASE(size, ...) \
  case (size): {                                \
    constexpr auto kArraySize = (size);         \
    __VA_ARGS__;                                \
  } break

#define _SEGMENTED_ARRAY_KERNEL_DEFAULT(size, ...) \
  default: {                                       \
    constexpr auto kArraySize = (size);            \
    __VA_ARGS__;                                   \
  } break

#define SEGMENTED_ARRAY_KERNEL_HELPER(...)                                    \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed4,            \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed8,            \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed16,           \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed32,           \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_CASE(funcs::SegmentedArraySize::kFixed64,           \
                               ##__VA_ARGS__);                                \
  _SEGMENTED_ARRAY_KERNEL_DEFAULT(funcs::SegmentedArraySize::kVariableLength, \
                                  ##__VA_ARGS__);
} // namespace phi
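A minimal usage sketch (not part of the commit; ExampleLaunch and ExampleGpuKernel are hypothetical names) of how the pieces fit together, mirroring the stack kernels changed below:

#include <vector>

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"

namespace phi {

template <typename Context, typename T, funcs::SegmentedArraySize Size>
void ExampleLaunch(const Context& ctx,
                   const std::vector<const DenseTensor*>& x) {
  // Gathers x[i]->data<T>() into an array that can be passed by value to a
  // __global__ kernel; for kVariableLength the pointers are staged in device
  // memory instead.
  funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
  // ... launch a __global__ kernel templated on decltype(setter.array),
  // reading setter.array.data[i] on the device ...
}

template <typename T, typename Context>
void ExampleGpuKernel(const Context& ctx,
                      const std::vector<const DenseTensor*>& x) {
  // CalcArraySize() picks the smallest fixed capacity that holds x.size();
  // each generated case binds kArraySize at compile time.
  switch (funcs::CalcArraySize(static_cast<int>(x.size()))) {
    SEGMENTED_ARRAY_KERNEL_HELPER(
        ExampleLaunch<Context, T, kArraySize>(ctx, x));
  }
}

}  // namespace phi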
paddle/phi/kernels/gpu/stack_grad_kernel.cu
@@ -16,16 +16,17 @@
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"
namespace phi {
template <typename T, typename IndexT>
__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
int pre_dim_size,
int split_dim_size,
int suf_dim_size,
int num_split,
T** output_ptrs) {
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernel(const T* __restrict__ input,
IndexT pre_dim_size,
IndexT split_dim_size,
IndexT suf_dim_size,
IndexT num_split,
ArrayT array) {
assert(blockDim.y == 1);
assert(blockDim.z == 1);
// In this case they are equal
@@ -40,7 +41,7 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size;
IndexT k = offset % suf_dim_size;
T* output = output_ptrs[j / each_dim_size];
T* output = array.data[j / each_dim_size];
if (output == nullptr) {
return;
}
@@ -50,12 +51,12 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
}
}
template <typename T, typename IndexT>
__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
const IndexT cols,
const IndexT rows,
const IndexT tile_x_num,
T** out_datas) {
template <typename T, typename IndexT, typename ArrayT>
__global__ void UnStackCudaKernelForLastDim(const T* __restrict__ in_data,
const IndexT cols,
const IndexT rows,
const IndexT tile_x_num,
ArrayT array) {
constexpr int buffer_size = 512;
__shared__ T s_buf[buffer_size];
@@ -71,112 +72,112 @@ __global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
}
__syncthreads();
if (is_valid) {
if (out_datas[col_idx] != nullptr) {
out_datas[col_idx][row_idx] = s_buf[s_idx];
if (array.data[col_idx]) {
array.data[col_idx][row_idx] = s_buf[s_idx];
}
}
}
}
template <typename Context, typename T, typename IndexT>
void LaunchStackGradCUDAKernel(const Context& ctx,
const DenseTensor& out,
std::vector<DenseTensor*>* x_grad_ptr,
const int axis,
const int64_t dy_pre) {
auto x_grad = *x_grad_ptr;
int out_num = out.dims()[axis];
PADDLE_ENFORCE_EQ(
out_num,
x_grad.size(),
phi::errors::InvalidArgument(
"Output x_grad size shall be equal to output num, but output num "
"received in stack_grad op is:%d, and x_grad size is:%d.",
out_num,
x_grad.size()));
std::vector<T*> outputs(out_num);
for (size_t j = 0; j < out_num; ++j) {
if (x_grad[j] == nullptr || x_grad[j]->numel() == 0UL) {
outputs[j] = nullptr;
} else {
outputs[j] = ctx.template Alloc<T>(x_grad[j]);
}
}
auto tmp_out_data = paddle::memory::Alloc(
ctx.GetPlace(),
out_num * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
paddle::memory::Copy(ctx.GetPlace(),
tmp_out_data->ptr(),
phi::CPUPlace(),
reinterpret_cast<void*>(outputs.data()),
out_num * sizeof(T*),
ctx.stream());
if (axis == (out.dims().size() - 1)) {
template <typename Context,
typename T,
typename IndexT,
funcs::SegmentedArraySize Size>
void LaunchUnStackKernel(const Context& ctx,
const IndexT pre_dim,
const IndexT split_dim,
const IndexT suf_dim,
const IndexT num_splits,
const DenseTensor& out_grad,
std::vector<DenseTensor*>* x_grad) {
// each x_grad should have same shape
auto dout_ptr = out_grad.data<T>();
funcs::PointerArraySetter<Context, T, Size> setter(ctx, x_grad);
if (suf_dim == 1) {
// For the case axis == (out_grad.dims().size() - 1)
constexpr int kThreads = 512;
constexpr int kWarpSize = 32;
constexpr int kMaxOut = 16;
int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
bool is_small_num = out_num < kMaxOut;
if (is_small_num) {
tid_y = out_num;
int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
if (split_dim < kMaxOut) {
tid_y = split_dim;
tid_x =
std::min(backends::gpu::RoundToNextHighPowOfTwo(dy_pre, kWarpSize),
std::min(backends::gpu::RoundToNextHighPowOfTwo(pre_dim, kWarpSize),
kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y));
} else {
tid_y = kMaxOut;
tid_x = kWarpSize;
bid_y = backends::gpu::DivUp<int>(out_num, kMaxOut);
bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
}
int tile_x_num = backends::gpu::DivUp<int>(dy_pre, tid_x);
int tile_x_num = backends::gpu::DivUp<int>(pre_dim, tid_x);
bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
dim3 blocks(tid_x, tid_y, 1);
dim3 grids(bid_x, bid_y, 1);
StackGradKernelForLastDim<T, IndexT><<<grids, blocks, 0, ctx.stream()>>>(
out.data<T>(),
out_num,
dy_pre,
tile_x_num,
reinterpret_cast<T**>(tmp_out_data->ptr()));
UnStackCudaKernelForLastDim<T, IndexT, decltype(setter.array)>
<<<grids, blocks, 0, ctx.stream()>>>(
dout_ptr, split_dim, pre_dim, tile_x_num, setter.array);
} else {
int dy_suf = out.numel() / (out_num * dy_pre);
auto config =
backends::gpu::GetGpuLaunchConfig1D(ctx, dy_pre * out_num * dy_suf);
UnStackHelperCUDAKernel<T, IndexT>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
out.data<T>(),
dy_pre,
out_num,
dy_suf,
out_num,
reinterpret_cast<T**>(tmp_out_data->ptr()));
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
ctx, pre_dim * split_dim * suf_dim);
UnStackCudaKernel<T, IndexT, decltype(setter.array)>
<<<config.block_per_grid.x,
config.thread_per_block.x,
0,
ctx.stream()>>>(
dout_ptr, pre_dim, split_dim, suf_dim, num_splits, setter.array);
}
}
template <typename T, typename Context>
void StackGradKernel(const Context& dev_ctx,
const DenseTensor& out,
void StackGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int axis,
std::vector<DenseTensor*> x_grad) {
const auto& dy_dims = out.dims();
int actual_axis = axis < 0 ? axis + dy_dims.size() : axis;
bool use_int32 = out.numel() < std::numeric_limits<int32_t>::max();
if (axis < 0) axis += out_grad.dims().size();
int64_t split_dim = out_grad.dims()[axis];
PADDLE_ENFORCE_EQ(
split_dim,
x_grad.size(),
phi::errors::InvalidArgument(
"Output x_grad size should be equal to the split_dim, but"
" received split_dim is:%d x_grad size is:%d.",
split_dim,
x_grad.size()));
int64_t dy_pre = 1;
for (int i = 0; i < actual_axis; ++i) {
dy_pre *= dy_dims[i];
auto dout_dims = out_grad.dims();
int64_t dout_pre = 1;
for (int i = 0; i < axis; ++i) {
dout_pre *= dout_dims[i];
}
if (use_int32) {
LaunchStackGradCUDAKernel<Context, T, int32_t>(
dev_ctx, out, &x_grad, actual_axis, dy_pre);
int64_t dout_suf = out_grad.numel() / (split_dim * dout_pre);
if (out_grad.numel() < std::numeric_limits<int32_t>::max()) {
switch (funcs::CalcArraySize(split_dim)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchUnStackKernel<Context, T, int32_t, kArraySize>(ctx,
dout_pre,
split_dim,
dout_suf,
split_dim,
out_grad,
&x_grad));
}
} else {
LaunchStackGradCUDAKernel<Context, T, int64_t>(
dev_ctx, out, &x_grad, actual_axis, dy_pre);
switch (funcs::CalcArraySize(split_dim)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchUnStackKernel<Context, T, int64_t, kArraySize>(ctx,
dout_pre,
split_dim,
dout_suf,
split_dim,
out_grad,
&x_grad));
}
}
}
......
paddle/phi/kernels/gpu/stack_kernel.cu
@@ -15,86 +15,15 @@
#include "paddle/phi/kernels/stack_kernel.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/fast_divmod.h"
#include "paddle/phi/kernels/funcs/segmented_array.h"
namespace phi {
template <typename IndexT>
struct DivmodWarpper {
public:
void SetDivisor(IndexT divisor) {
divmoder = phi::funcs::FastDivMod(divisor);
}
__device__ inline phi::funcs::FastDivMod::DivModT div_mod(IndexT val) {
return divmoder.Divmod(val);
}
private:
phi::funcs::FastDivMod divmoder;
};
template <>
struct DivmodWarpper<int64_t> {
public:
using DivModT = phi::AlignedVector<int64_t, 2>;
void SetDivisor(int64_t divisor) { dividen_ = divisor; }
__device__ inline DivModT div_mod(int64_t val) {
DivModT data;
data[0] = val / dividen_;
data[1] = val - data[0] * dividen_;
return data;
}
private:
int64_t dividen_;
};
template <typename T, typename IndexT, int Size>
struct PointerArray : public DivmodWarpper<IndexT> {
public:
const T* data[Size];
PointerArray(const std::vector<const DenseTensor*>& x,
int num,
IndexT divisor) {
this->SetDivisor(divisor);
for (auto i = 0; i < num; ++i) {
data[i] = x[i]->data<T>();
}
}
};
template <typename Context, typename T, typename IndexT>
struct PointerToPointer : public DivmodWarpper<IndexT> {
public:
T** data{nullptr};
PointerToPointer(const Context& ctx,
const std::vector<const DenseTensor*>& x,
IndexT num,
IndexT divisor,
paddle::memory::AllocationPtr* dev_ins_ptr) {
this->SetDivisor(divisor);
std::vector<const T*> x_datas(num);
for (int i = 0; i < num; ++i) {
x_datas[i] = x[i]->data<T>();
}
*dev_ins_ptr = paddle::memory::Alloc(
ctx.GetPlace(),
num * sizeof(T*),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
paddle::memory::Copy(ctx.GetPlace(),
(*dev_ins_ptr)->ptr(),
phi::CPUPlace(),
reinterpret_cast<void*>(x_datas.data()),
num * sizeof(T*),
ctx.stream());
data = reinterpret_cast<T**>((*dev_ins_ptr)->ptr());
}
};
template <typename T, typename IndexT, typename WrapT>
__global__ void StackCUDAKernel(WrapT input_warpper,
template <typename T, typename IndexT, typename ArrayT>
__global__ void StackCUDAKernel(ArrayT array,
funcs::GeneralDivMod<IndexT> divmoder,
IndexT split_size,
IndexT rows,
IndexT cols,
@@ -106,85 +35,69 @@ __global__ void StackCUDAKernel(WrapT input_warpper,
for (; grid_x < cols; grid_x += grid_x_stride) {
IndexT grid_y = static_cast<IndexT>(blockIdx.y) * blockDim.y + threadIdx.y;
auto divmod_rslt = input_warpper.div_mod(grid_x);
const T* input_ptr = input_warpper.data[divmod_rslt[0]];
auto divmod_rslt = divmoder.div_mod(grid_x);
IndexT split = divmod_rslt[0]; // grid_x / split_size
IndexT col_offset = divmod_rslt[1]; // grid_x % split_size
const T* input_ptr = array.data[split];
#pragma unroll
for (; grid_y < rows; grid_y += grid_y_stride) {
output[grid_y * cols + grid_x] =
input_ptr[grid_y * split_size + divmod_rslt[1]];
input_ptr[grid_y * split_size + col_offset];
}
}
}
template <typename T, typename IndexT, typename Context>
void LaunchStackCUDAKernelWithIndexType(
const Context& ctx,
const IndexT x_col,
const IndexT x_row,
const IndexT out_col,
const phi::backends::gpu::GpuLaunchConfig& cfg,
const std::vector<const DenseTensor*>& x,
T* dst_data) {
int num = static_cast<int>(x.size());
#define IMPL_STACK_CUDA_KERNEL_CASE(size_, ...) \
case size_: { \
PointerArray<T, IndexT, size_> ptr_array(x, num, x_col); \
__VA_ARGS__; \
} break;
#define IMPL_STACK_CUDA_KERNEL_HELPER(...) \
IMPL_STACK_CUDA_KERNEL_CASE(4, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(8, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(16, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(32, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(64, ##__VA_ARGS__); \
IMPL_STACK_CUDA_KERNEL_CASE(128, ##__VA_ARGS__);
switch (phi::backends::gpu::RoundToNextHighPowOfTwo(num, 4)) {
IMPL_STACK_CUDA_KERNEL_HELPER(
StackCUDAKernel<T, IndexT, decltype(ptr_array)>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
ptr_array, x_col, x_row, out_col, dst_data));
default: {
paddle::memory::AllocationPtr dev_ins_ptr{nullptr};
PointerToPointer<Context, T, IndexT> ptr_array(
ctx, x, num, x_col, &dev_ins_ptr);
StackCUDAKernel<T, IndexT, decltype(ptr_array)>
<<<cfg.block_per_grid, cfg.thread_per_block, 0, ctx.stream()>>>(
ptr_array, x_col, x_row, out_col, dst_data);
}
}
#undef IMPL_STACK_CUDA_KERNEL_HELPER
#undef IMPL_STACK_CUDA_KERNEL_CASE
template <typename Context,
typename T,
typename IndexT,
funcs::SegmentedArraySize Size>
void LaunchStackKernel(const Context& ctx,
const IndexT x_col,
const IndexT x_row,
const IndexT out_col,
const std::vector<const DenseTensor*>& x,
DenseTensor* out) {
T* out_ptr = ctx.template Alloc<T>(out);
auto config = phi::backends::gpu::GetGpuLaunchConfig2D(ctx, out_col, x_row);
funcs::ConstPointerArraySetter<Context, T, Size> setter(ctx, x);
funcs::GeneralDivMod<IndexT> divmoder(x_col);
StackCUDAKernel<T, IndexT, decltype(setter.array)>
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
setter.array, divmoder, x_col, x_row, out_col, out_ptr);
}
template <typename T, typename Context>
void StackKernel(const Context& dev_ctx,
void StackKernel(const Context& ctx,
const std::vector<const DenseTensor*>& x,
int axis,
DenseTensor* out) {
if (axis < 0) axis += (x[0]->dims().size() + 1);
int num = static_cast<int>(x.size());
T* dst_data = dev_ctx.template Alloc<T>(out);
// Split x dim from axis to matrix
int64_t x_row = 1, x_col = 1;
int64_t x_row = 1;
for (int i = 0; i < axis; ++i) {
x_row *= x[0]->dims()[i];
}
x_col = x[0]->numel() / x_row;
int64_t x_col = x[0]->numel() / x_row;
int64_t out_col = x_col * num;
auto config =
phi::backends::gpu::GetGpuLaunchConfig2D(dev_ctx, out_col, x_row);
if (out->numel() < std::numeric_limits<int32_t>::max()) {
LaunchStackCUDAKernelWithIndexType<T, int32_t, Context>(
dev_ctx, x_col, x_row, out_col, config, x, dst_data);
switch (funcs::CalcArraySize(num)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchStackKernel<Context, T, int32_t, kArraySize>(
ctx, x_col, x_row, out_col, x, out));
}
} else {
LaunchStackCUDAKernelWithIndexType<T, int64_t, Context>(
dev_ctx, x_col, x_row, out_col, config, x, dst_data);
switch (funcs::CalcArraySize(num)) {
SEGMENTED_ARRAY_KERNEL_HELPER(
LaunchStackKernel<Context, T, int64_t, kArraySize>(
ctx, x_col, x_row, out_col, x, out));
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(stack,
......