未验证 提交 393b3bd6 编写于 作者: T Thunderbrook 提交者: GitHub

fix split core (#31892)

* fix split core

* format
上级 3a95a0bc
...@@ -114,8 +114,8 @@ __global__ void ConcatKernel(const T** inputs_data, const int in_num, ...@@ -114,8 +114,8 @@ __global__ void ConcatKernel(const T** inputs_data, const int in_num,
} }
template <typename T> template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int64_t in_row,
const int in_col, const int* out_cols, const int64_t in_col, const int64_t* out_cols,
int out_cols_size, T** outputs_data) { int out_cols_size, T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0; int curr_segment = 0;
...@@ -159,15 +159,15 @@ __device__ void SplitKernelDetail(const T* input_data, const int in_row, ...@@ -159,15 +159,15 @@ __device__ void SplitKernelDetail(const T* input_data, const int in_row,
} }
template <typename T> template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int64_t in_row,
const int in_col, const int fixed_out_col, const int64_t in_col, const int64_t fixed_out_col,
T** outputs_data) { T** outputs_data) {
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data); SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
} }
template <typename T> template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int64_t in_row,
const int in_col, const int fixed_out_col, const int64_t in_col, const int64_t fixed_out_col,
T* outputs_addr0, T* outputs_addr1) { T* outputs_addr0, T* outputs_addr1) {
T* outputs_data[2]; T* outputs_data[2];
outputs_data[0] = outputs_addr0; outputs_data[0] = outputs_addr0;
...@@ -176,8 +176,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row, ...@@ -176,8 +176,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
} }
template <typename T> template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int64_t in_row,
const int in_col, const int fixed_out_col, const int64_t in_col, const int64_t fixed_out_col,
T* outputs_addr0, T* outputs_addr1, T* outputs_addr0, T* outputs_addr1,
T* outputs_addr2) { T* outputs_addr2) {
T* outputs_data[3]; T* outputs_data[3];
...@@ -188,8 +188,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row, ...@@ -188,8 +188,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
} }
template <typename T> template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int64_t in_row,
const int in_col, const int fixed_out_col, const int64_t in_col, const int64_t fixed_out_col,
T* outputs_addr0, T* outputs_addr1, T* outputs_addr0, T* outputs_addr1,
T* outputs_addr2, T* outputs_addr3) { T* outputs_addr2, T* outputs_addr3) {
T* outputs_data[4]; T* outputs_data[4];
...@@ -201,8 +201,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row, ...@@ -201,8 +201,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
} }
static inline void GetBlockDims(const platform::CUDADeviceContext& context, static inline void GetBlockDims(const platform::CUDADeviceContext& context,
int num_rows, int num_cols, dim3* block_dims, int64_t num_rows, int64_t num_cols,
dim3* grid_dims) { dim3* block_dims, dim3* grid_dims) {
// Set the thread block and grid according to CurrentDeviceId // Set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024; const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock; int block_cols = kThreadsPerBlock;
...@@ -213,12 +213,12 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context, ...@@ -213,12 +213,12 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context,
*block_dims = dim3(block_cols, block_rows, 1); *block_dims = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount(); int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols = int grid_cols =
std::min((num_cols + block_cols - 1) / block_cols, max_blocks); std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows = int grid_rows = std::min(max_blocks / grid_cols,
std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1)); std::max(num_rows / block_rows, (int64_t)1));
*grid_dims = dim3(grid_cols, grid_rows, 1); *grid_dims = dim3(grid_cols, grid_rows, 1);
} }
...@@ -319,22 +319,22 @@ class SplitFunctor<platform::CUDADeviceContext, T> { ...@@ -319,22 +319,22 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
int axis, std::vector<framework::Tensor*>* outputs) { int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int o_num = outputs->size(); int o_num = outputs->size();
int out_row = 1; int64_t out_row = 1;
auto dim_0 = ref_inputs[0]->dims(); auto dim_0 = ref_inputs[0]->dims();
for (int i = 0; i < axis; ++i) { for (int i = 0; i < axis; ++i) {
out_row *= dim_0[i]; out_row *= dim_0[i];
} }
int out0_col = ref_inputs[0]->numel() / out_row; int64_t out0_col = ref_inputs[0]->numel() / out_row;
int in_col = 0, in_row = out_row; int64_t in_col = 0, in_row = out_row;
bool has_same_shape = true; bool has_same_shape = true;
std::vector<T*> outputs_data(o_num); std::vector<T*> outputs_data(o_num);
std::vector<int> outputs_cols(o_num + 1); std::vector<int64_t> outputs_cols(o_num + 1);
outputs_cols[0] = 0; outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) { for (int i = 0; i < o_num; ++i) {
int t_col = ref_inputs.at(i)->numel() / out_row; int64_t t_col = ref_inputs.at(i)->numel() / out_row;
if (has_same_shape) { if (has_same_shape) {
if (t_col != out0_col) has_same_shape = false; if (t_col != out0_col) has_same_shape = false;
} }
...@@ -384,13 +384,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> { ...@@ -384,13 +384,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
auto tmp_dev_ins_col_data = auto tmp_dev_ins_col_data =
memory::Alloc(context, memory::Alloc(context,
outputs_cols.size() * sizeof(int)); outputs_cols.size() * sizeof(int64_t));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_cols.data()), reinterpret_cast<void*>(outputs_cols.data()),
outputs_cols.size() * sizeof(int), context.stream()); outputs_cols.size() * sizeof(int64_t), context.stream());
int* dev_outs_col_data = int64_t* dev_outs_col_data =
reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr()); reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>( SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, dev_outs_col_data, input.data<T>(), in_row, in_col, dev_outs_col_data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册