未验证 提交 5782ddda 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the concat and split kernel for specical cases when the number of...

Optimize the concat and split kernel for specical cases when the number of inputs/outputs is 2 (#17415)

* Optimize the concat and split kernel for special cases that the number of inputs/outputs is 2.
test=develop

* Refine codes.
test=develop

* Correct the condition.
test=develop

* Move the define of tmp_data outside the if statement.

* Print the cudnn minor version.
test=develop

* Fix the case when in_num/o_num is 1 in concat/split op.
test=develop

* Remove const_cast.
test=develop
上级 acbb4bf3
......@@ -96,7 +96,7 @@ if(CUDNN_FOUND)
endif()
message(STATUS "Current cuDNN header is ${CUDNN_INCLUDE_DIR}/cudnn.h. "
"Current cuDNN version is v${CUDNN_MAJOR_VERSION}. ")
"Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ")
endif()
endif()
......@@ -24,9 +24,9 @@ namespace operators {
namespace math {
template <typename T>
__global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
const int output_rows, const int output_cols,
T* output) {
__global__ void ConcatKernel(const T** inputs, const int* input_cols,
int col_size, const int output_rows,
const int output_cols, T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0;
int curr_offset = input_cols[0];
......@@ -41,7 +41,7 @@ __global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* input_ptr = inputs[curr_segment];
const T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
......@@ -50,14 +50,14 @@ __global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
}
template <typename T>
__global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
const int out_rows, const int out_cols,
T* output_data) {
__device__ void ConcatKernelDetail(const T** inputs_data,
const int fixed_in_col, const int out_rows,
const int out_cols, T* output_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * 1.0 / fixed_in_col;
int in_offset = tid_x - split * fixed_in_col;
T* input_ptr = inputs_data[split];
const T* input_ptr = inputs_data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
output_data[tid_y * out_cols + tid_x] =
......@@ -66,6 +66,25 @@ __global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
}
}
template <typename T>
__global__ void ConcatKernel(const T* input_addr0, const T* input_addr1,
const int fixed_in_col, const int out_rows,
const int out_cols, T* output_data) {
const T* inputs_data[2];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
output_data);
}
template <typename T>
__global__ void ConcatKernel(const T** inputs_data, const int in_num,
const int fixed_in_col, const int out_rows,
const int out_cols, T* output_data) {
ConcatKernelDetail<T>(inputs_data, fixed_in_col, out_rows, out_cols,
output_data);
}
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int* out_cols,
......@@ -94,7 +113,7 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
}
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
__device__ void SplitKernelDetail(const T* input_data, const int in_row,
const int in_col, const int fixed_out_col,
T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -111,6 +130,45 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
}
}
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int fixed_out_col,
T** outputs_data) {
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
template <typename T>
__global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int fixed_out_col,
T* outputs_addr0, T* outputs_addr1) {
T* outputs_data[2];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
static inline void GetBlockDims(const platform::CUDADeviceContext& context,
int num_rows, int num_cols, dim3* block_dims,
dim3* grid_dims) {
// Set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((num_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
*block_dims = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
*grid_dims = dim3(grid_cols, grid_rows, 1);
}
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
......@@ -131,53 +189,47 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
int in_col = input[0].numel() / in_row;
int out_row = in_row, out_col = 0;
std::vector<const T*> inputs_data;
std::vector<const T*> inputs_data(in_num);
std::vector<int> inputs_col(in_num + 1);
inputs_data.reserve(in_num);
inputs_col[0] = 0;
bool sameShape = true;
bool has_same_shape = true;
for (int i = 0; i < in_num; ++i) {
int t_cols = input[i].numel() / in_row;
if (sameShape) {
if (t_cols != in_col) sameShape = false;
if (has_same_shape) {
if (t_cols != in_col) has_same_shape = false;
}
out_col += t_cols;
inputs_col[i + 1] = out_col;
inputs_data.emplace_back(input[i].data<T>());
inputs_data[i] = input[i].data<T>();
}
// computation
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (out_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((out_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(context, out_row, out_col, &block_dims, &grid_dims);
int grid_cols =
std::min((out_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
auto tmp_dev_ins_data =
memory::allocation::AllocationPtr tmp_dev_ins_data;
const T** dev_ins_data = nullptr;
if (!has_same_shape || (in_num != 2)) {
tmp_dev_ins_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_data.data()),
inputs_data.size() * sizeof(T*), context.stream());
T** dev_ins_data = reinterpret_cast<T**>(tmp_dev_ins_data->ptr());
dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr());
}
if (sameShape) {
ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
dev_ins_data, in_col, out_row, out_col, output->data<T>());
if (has_same_shape) {
if (in_num == 2) {
ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
inputs_data[0], inputs_data[1], in_col, out_row, out_col,
output->data<T>());
} else {
ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
dev_ins_data, in_num, in_col, out_row, out_col, output->data<T>());
}
} else {
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
......@@ -188,7 +240,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
inputs_col.size() * sizeof(int), context.stream());
int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr());
ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
out_row, out_col, output->data<T>());
}
......@@ -216,7 +268,7 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
int out0_col = ref_inputs[0]->numel() / out_row;
int in_col = 0, in_row = out_row;
bool sameShape = true;
bool has_same_shape = true;
std::vector<T*> outputs_data(o_num);
std::vector<int> outputs_cols(o_num + 1);
......@@ -224,8 +276,8 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) {
int t_col = ref_inputs.at(i)->numel() / out_row;
if (sameShape) {
if (t_col != out0_col) sameShape = false;
if (has_same_shape) {
if (t_col != out0_col) has_same_shape = false;
}
in_col += t_col;
outputs_cols[i + 1] = in_col;
......@@ -236,36 +288,32 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
}
}
// computation
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (in_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((in_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((in_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(context, out_row, in_col, &block_dims, &grid_dims);
auto tmp_dev_outs_data =
memory::allocation::AllocationPtr tmp_dev_outs_data;
T** dev_out_gpu_data = nullptr;
if (!has_same_shape || (o_num != 2)) {
tmp_dev_outs_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_outs_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_data.data()),
outputs_data.size() * sizeof(T*), context.stream());
T** dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr());
}
if (sameShape) {
SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
if (has_same_shape) {
if (o_num == 2) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, outputs_data[0],
outputs_data[1]);
} else {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
}
} else {
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
......@@ -277,7 +325,7 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
int* dev_outs_col_data =
reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());
SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, dev_outs_col_data,
static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
}
......
......@@ -17,26 +17,24 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
/**
* case 1:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [3, 3, 4]
* output:
* out.shape: [5, 3, 4]
*/
template <typename DeviceContext, typename Place>
void testConcat() {
void ConcatCase1(DeviceContext* context) {
paddle::framework::Tensor input_a_cpu;
paddle::framework::Tensor input_b_cpu;
paddle::framework::Tensor out_cpu;
paddle::framework::Tensor input_a;
paddle::framework::Tensor input_b;
paddle::framework::Tensor out;
DeviceContext* context = new DeviceContext(Place());
// DeviceContext context(Place());
/**
* cast1:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [3, 3, 4]
* output:
* out.shape: [5, 3, 4]
*/
auto dim_a = paddle::framework::make_ddim({2, 3, 4});
auto dim_b = paddle::framework::make_ddim({3, 3, 4});
auto dim_out = paddle::framework::make_ddim({5, 3, 4});
......@@ -51,8 +49,8 @@ void testConcat() {
out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
}
int* a_ptr;
int* b_ptr;
int* a_ptr = nullptr;
int* b_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
......@@ -84,7 +82,7 @@ void testConcat() {
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
int* out_ptr;
int* out_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu);
......@@ -104,28 +102,42 @@ void testConcat() {
++idx_a;
}
}
//
/**
* cast2:
}
/**
* case 2:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 4, 4]
* output:
* out.shape: [2, 7, 4]
*/
dim_a = paddle::framework::make_ddim({2, 3, 4});
dim_b = paddle::framework::make_ddim({2, 4, 4});
dim_out = paddle::framework::make_ddim({2, 7, 4});
template <typename DeviceContext, typename Place>
void ConcatCase2(DeviceContext* context) {
paddle::framework::Tensor input_a_cpu;
paddle::framework::Tensor input_b_cpu;
paddle::framework::Tensor out_cpu;
paddle::framework::Tensor input_a;
paddle::framework::Tensor input_b;
paddle::framework::Tensor out;
auto dim_a = paddle::framework::make_ddim({2, 3, 4});
auto dim_b = paddle::framework::make_ddim({2, 4, 4});
auto dim_out = paddle::framework::make_ddim({2, 7, 4});
input_a.mutable_data<int>(dim_a, Place());
input_b.mutable_data<int>(dim_b, Place());
out.mutable_data<int>(dim_out, Place());
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
input_a_cpu.mutable_data<int>(dim_a, paddle::platform::CPUPlace());
input_b_cpu.mutable_data<int>(dim_b, paddle::platform::CPUPlace());
out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
}
int* a_ptr = nullptr;
int* b_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
......@@ -146,16 +158,18 @@ void testConcat() {
paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
}
input.clear();
std::vector<paddle::framework::Tensor> input;
input.push_back(input_a);
input.push_back(input_b);
paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
concat_functor(*context, input, 1, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
int* out_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu);
......@@ -164,8 +178,8 @@ void testConcat() {
out_ptr = out.data<int>();
}
cols = 3 * 4;
idx_a = 0, idx_b = 0;
int cols = 3 * 4;
int idx_a = 0, idx_b = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 28; ++j) {
if (j >= cols) {
......@@ -177,28 +191,42 @@ void testConcat() {
}
}
}
}
/**
* cast3:
/**
* case 3:
* inputs:
* t_a.shape: [2, 3, 5]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 3, 9]
*/
dim_a = paddle::framework::make_ddim({2, 3, 4});
dim_b = paddle::framework::make_ddim({2, 3, 5});
dim_out = paddle::framework::make_ddim({2, 3, 9});
template <typename DeviceContext, typename Place>
void ConcatCase3(DeviceContext* context) {
paddle::framework::Tensor input_a_cpu;
paddle::framework::Tensor input_b_cpu;
paddle::framework::Tensor out_cpu;
paddle::framework::Tensor input_a;
paddle::framework::Tensor input_b;
paddle::framework::Tensor out;
auto dim_a = paddle::framework::make_ddim({2, 3, 4});
auto dim_b = paddle::framework::make_ddim({2, 3, 5});
auto dim_out = paddle::framework::make_ddim({2, 3, 9});
input_a.mutable_data<int>(dim_a, Place());
input_b.mutable_data<int>(dim_b, Place());
out.mutable_data<int>(dim_out, Place());
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
input_a_cpu.mutable_data<int>(dim_a, paddle::platform::CPUPlace());
input_b_cpu.mutable_data<int>(dim_b, paddle::platform::CPUPlace());
out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
}
int* a_ptr = nullptr;
int* b_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
......@@ -219,16 +247,18 @@ void testConcat() {
paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
}
input.clear();
std::vector<paddle::framework::Tensor> input;
input.push_back(input_a);
input.push_back(input_b);
paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
concat_functor(*context, input, 2, &out);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
int* out_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu);
......@@ -238,8 +268,8 @@ void testConcat() {
}
// check the data
cols = 4;
idx_a = 0, idx_b = 0;
int cols = 4;
int idx_a = 0, idx_b = 0;
for (int i = 0; i < 6; ++i) {
for (int j = 0; j < 9; ++j) {
if (j >= cols) {
......@@ -251,9 +281,10 @@ void testConcat() {
}
}
}
}
/**
* cast4:
/**
* case 4:
* inputs:
* axis = 1
* t_a.shape: [2, 3, 4]
......@@ -261,19 +292,32 @@ void testConcat() {
* output:
* out.shape: [2, 6, 4]
*/
dim_a = paddle::framework::make_ddim({2, 3, 4});
dim_b = paddle::framework::make_ddim({2, 3, 4});
dim_out = paddle::framework::make_ddim({2, 6, 4});
template <typename DeviceContext, typename Place>
void ConcatCase4(DeviceContext* context) {
paddle::framework::Tensor input_a_cpu;
paddle::framework::Tensor input_b_cpu;
paddle::framework::Tensor out_cpu;
paddle::framework::Tensor input_a;
paddle::framework::Tensor input_b;
paddle::framework::Tensor out;
auto dim_a = paddle::framework::make_ddim({2, 3, 4});
auto dim_b = paddle::framework::make_ddim({2, 3, 4});
auto dim_out = paddle::framework::make_ddim({2, 6, 4});
input_a.mutable_data<int>(dim_a, Place());
input_b.mutable_data<int>(dim_b, Place());
out.mutable_data<int>(dim_out, Place());
input_a.Resize(dim_a);
input_b.Resize(dim_b);
out.Resize(dim_out);
if (paddle::platform::is_gpu_place(Place())) {
input_a_cpu.Resize(dim_a);
input_b_cpu.Resize(dim_b);
out_cpu.Resize(dim_out);
input_a_cpu.mutable_data<int>(dim_a, paddle::platform::CPUPlace());
input_b_cpu.mutable_data<int>(dim_b, paddle::platform::CPUPlace());
out_cpu.mutable_data<int>(dim_out, paddle::platform::CPUPlace());
}
int* a_ptr = nullptr;
int* b_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
a_ptr = input_a_cpu.data<int>();
b_ptr = input_b_cpu.data<int>();
......@@ -294,16 +338,19 @@ void testConcat() {
paddle::framework::TensorCopySync(input_b_cpu, Place(), &input_b);
}
input.clear();
std::vector<paddle::framework::Tensor> input;
input.push_back(input_a);
input.push_back(input_b);
paddle::operators::math::ConcatFunctor<DeviceContext, int> concat_functor;
concat_functor(*context, input, 1, &out);
context->Wait();
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ(input_a.dims(), dim_a);
PADDLE_ENFORCE_EQ(input_b.dims(), dim_b);
int* out_ptr = nullptr;
if (paddle::platform::is_gpu_place(Place())) {
paddle::framework::TensorCopySync(out, paddle::platform::CPUPlace(),
&out_cpu);
......@@ -313,8 +360,8 @@ void testConcat() {
}
// check the data
cols = 12;
idx_a = 0, idx_b = 0;
int cols = 12;
int idx_a = 0, idx_b = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 24; ++j) {
if (j >= cols) {
......@@ -328,10 +375,21 @@ void testConcat() {
}
}
template <typename DeviceContext, typename Place>
void TestConcatMain() {
DeviceContext* context = new DeviceContext(Place());
ConcatCase1<DeviceContext, Place>(context);
ConcatCase2<DeviceContext, Place>(context);
ConcatCase3<DeviceContext, Place>(context);
ConcatCase4<DeviceContext, Place>(context);
}
TEST(math, concat) {
testConcat<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
TestConcatMain<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
testConcat<paddle::platform::CUDADeviceContext,
TestConcatMain<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
#endif
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册