diff --git a/lite/api/_paddle_use_kernels.h b/lite/api/_paddle_use_kernels.h index 9b16781935c2921b62b2502cb2f1794fc44a2c23..efeb75509422a2dce1658fdab61b58f68765066b 100644 --- a/lite/api/_paddle_use_kernels.h +++ b/lite/api/_paddle_use_kernels.h @@ -159,6 +159,7 @@ USE_LITE_KERNEL(conv2d, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(nearest_interp, kCUDA, kFloat, kNCHW, def); USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(concat, kCUDA, kFloat, kNCHW, def); #endif #ifdef LITE_WITH_OPENCL diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index a0c79465ec3154d5e346a0a46f9e8cfdaba0476f..25c7cee185591b11b4a35ab87112dc402b5705ef 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -5,17 +5,22 @@ endif() message(STATUS "compile with lite CUDA kernels") nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context) + lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps}) + nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps}) nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps}) nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps}) -nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} -${math_cuda}) -nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) +nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda}) +nv_library(concat_compute_cuda SRCS concat_compute.cu DEPS ${lite_kernel_deps}) +nv_library(elementwise_add_compute_cuda SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps}) +nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda) nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda) nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda) nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda) +nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda) +nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda) nv_library(calib_compute_cuda SRCS calib_compute.cu DEPS ${lite_kernel_deps}) lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) @@ -26,6 +31,8 @@ mul_compute_cuda io_copy_compute_cuda leaky_relu_compute_cuda nearest_interp_compute_cuda +concat_compute_cuda +elementwise_add_compute_cuda yolo_box_compute_cuda ) diff --git a/lite/kernels/cuda/concat_compute.cu b/lite/kernels/cuda/concat_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..10a9414935d2d7bf552490aa2aeab8446eca47a4 --- /dev/null +++ b/lite/kernels/cuda/concat_compute.cu @@ -0,0 +1,276 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/concat_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { +using Tensor = lite::Tensor; + +template +__global__ void ConcatKernel(const T** inputs, + const int* input_cols, + int col_size, + const int output_rows, + const int output_cols, + T* output) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + int curr_segment = 0; + int curr_offset = input_cols[0]; + for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) { + int curr_col_offset = input_cols[curr_segment + 1]; + while (curr_col_offset <= tid_x) { + curr_offset = curr_col_offset; + ++curr_segment; + curr_col_offset = input_cols[curr_segment + 1]; + } + + int local_col = tid_x - curr_offset; + int segment_width = curr_col_offset - curr_offset; + + const T* input_ptr = inputs[curr_segment]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) + output[tid_y * output_cols + tid_x] = + input_ptr[tid_y * segment_width + local_col]; + } +} + +template +__device__ void ConcatKernelDetail(const T** inputs_data, + const int fixed_in_col, + const int out_rows, + const int out_cols, + T* output_data) { + int tid_x = blockIdx.x * blockDim.x + threadIdx.x; + for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) { + int split = tid_x * 1.0 / fixed_in_col; + int in_offset = tid_x - split * fixed_in_col; + const T* input_ptr = inputs_data[split]; + int tid_y = blockIdx.y * blockDim.y + threadIdx.y; + for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) { + output_data[tid_y * out_cols + tid_x] = + input_ptr[tid_y * fixed_in_col + in_offset]; + } + } + // for (int i = 0; i < 4; i++){ + // printf("input[0][%d] = %.1f\n", i, inputs_data[0][i]); + // printf("output[%d] = %.1f\n", i, output_data[i]); + // } +} + +template +__global__ void ConcatKernel(const T* input_addr0, + const T* input_addr1, + const int fixed_in_col, + const int out_rows, + const int out_cols, + T* output_data) { + const T* inputs_data[2]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void ConcatKernel(const T* input_addr0, + const T* input_addr1, + const T* input_addr2, + const int fixed_in_col, + const int out_rows, + const int out_cols, + T* output_data) { + const T* inputs_data[3]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + inputs_data[2] = input_addr2; + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void ConcatKernel(const T* input_addr0, + const T* input_addr1, + const T* input_addr2, + const T* input_addr3, + const int fixed_in_col, + const int out_rows, + const int out_cols, + T* output_data) { + const T* inputs_data[4]; + inputs_data[0] = input_addr0; + inputs_data[1] = input_addr1; + inputs_data[2] = input_addr2; + inputs_data[3] = input_addr3; + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +template +__global__ void ConcatKernel(const T** inputs_data, + const int in_num, + const int fixed_in_col, + const int out_rows, + const int out_cols, + T* output_data) { + ConcatKernelDetail( + inputs_data, fixed_in_col, out_rows, out_cols, output_data); +} + +static inline void GetBlockDims(const 
CUDAContext& context, + int num_rows, + int num_cols, + dim3* block_dims, + dim3* grid_dims) { + // Set the thread block and grid according to CurrentDeviceId + const int kThreadsPerBlock = 1024; + int block_cols = kThreadsPerBlock; + if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32. + block_cols = ((num_cols + 31) >> 5) << 5; + } + int block_rows = kThreadsPerBlock / block_cols; + *block_dims = dim3(block_cols, block_rows, 1); + + int grid_cols = (num_cols + block_cols - 1) / block_cols; + int grid_rows = std::max(num_rows / block_rows, 1); + *grid_dims = dim3(grid_cols, grid_rows, 1); +} + +void ConcatCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + std::vector input = param.x; + Tensor* output = param.output; + int axis = param.axis; + + int in_num = input.size(); + int in_row = 1; + auto dim_0 = input[0]->dims(); + for (int i = 0; i < axis; ++i) { + in_row *= dim_0[i]; + } + int in_col = input[0]->numel() / in_row; + int out_row = in_row, out_col = 0; + + std::vector inputs_data(in_num); + std::vector inputs_col(in_num + 1); + inputs_col[0] = 0; + bool has_same_shape = true; + for (int i = 0; i < in_num; ++i) { + int t_cols = input[i]->numel() / in_row; + if (has_same_shape) { + if (t_cols != in_col) has_same_shape = false; + } + out_col += t_cols; + inputs_col[i + 1] = out_col; + inputs_data[i] = input[i]->data(); + } + dim3 block_dims; + dim3 grid_dims; + GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims); + const float** dev_ins_data = nullptr; + if (!has_same_shape || in_num < 2 || in_num > 4) { + float* tmp_dev_ins_data = nullptr; + CHECK(cudaSuccess == + cudaMalloc(&tmp_dev_ins_data, inputs_data.size() * sizeof(float*))); + CHECK(cudaSuccess == cudaMemcpy(tmp_dev_ins_data, + static_cast(inputs_data.data()), + inputs_data.size() * sizeof(float*), + cudaMemcpyHostToDevice)); + dev_ins_data = reinterpret_cast(tmp_dev_ins_data); + } + if (has_same_shape) { + if (in_num == 2) { + ConcatKernel<<>>( + inputs_data[0], + inputs_data[1], + in_col, + out_row, + out_col, + output->mutable_data()); + } else if (in_num == 3) { + ConcatKernel<<>>( + inputs_data[0], + inputs_data[1], + inputs_data[2], + in_col, + out_row, + out_col, + output->mutable_data()); + } else if (in_num == 4) { + ConcatKernel<<>>( + inputs_data[0], + inputs_data[1], + inputs_data[2], + inputs_data[3], + in_col, + out_row, + out_col, + output->mutable_data()); + } else { + ConcatKernel<<>>( + dev_ins_data, + in_num, + in_col, + out_row, + out_col, + output->mutable_data()); + cudaFree(dev_ins_data); + } + } else { + int* tmp_dev_ins_col_data = nullptr; + + CHECK(cudaSuccess == + cudaMalloc(&tmp_dev_ins_col_data, inputs_col.size() * sizeof(int))); + CHECK(cudaSuccess == cudaMemcpy(tmp_dev_ins_col_data, + static_cast(inputs_col.data()), + inputs_col.size() * sizeof(int), + cudaMemcpyHostToDevice)); + int* dev_ins_col_data = static_cast(tmp_dev_ins_col_data); + ConcatKernel<<>>( + dev_ins_data, + dev_ins_col_data, + static_cast(inputs_col.size()), + out_row, + out_col, + output->mutable_data()); + cudaFree(dev_ins_data); + cudaFree(dev_ins_col_data); + } + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(concat, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ConcatCompute, + def) + .BindInput("x", 
{LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("axis", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("output", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/concat_compute.h b/lite/kernels/cuda/concat_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..342ab5cba7d88ff4a04822e4c135a903f3dbe406 --- /dev/null +++ b/lite/kernels/cuda/concat_compute.h @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ConcatCompute : public KernelLite { + public: + using param_t = operators::ConcatParam; + + void Run() override; + virtual ~ConcatCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/concat_compute_test.cc b/lite/kernels/cuda/concat_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8dc097be48c1ccc55a8e068c0aca1f52effede5d --- /dev/null +++ b/lite/kernels/cuda/concat_compute_test.cc @@ -0,0 +1,227 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
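+
+// The reference implementation below (concat_compute_ref) mirrors the layout
+// used by the CUDA kernels in concat_compute.cu: each input tensor is viewed
+// as a 2-D matrix with rows = product of the dims before `axis` and
+// cols = numel / rows, and the inputs' column blocks are copied side by side
+// into the output, row by row.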
+ +#include "lite/kernels/cuda/concat_compute.h" +#include +#include +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +bool infer_shape(const operators::ConcatParam& param) { + std::vector input_dims; + for (auto p : param.x) { + input_dims.push_back(p->dims()); + } + size_t axis = static_cast(param.axis); + const size_t n = input_dims.size(); + CHECK_GT_OR_FALSE(n, 0); + auto& out_dims = input_dims[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + out_dims[axis] += input_dims[i][j]; + } else { + CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]); + } + } + } + if (out_dims[axis] < 0) { + out_dims[axis] = -1; + } + // Set output dims + param.output->Resize(lite::DDim(out_dims)); + return true; +} + +void concat_compute_ref(const operators::ConcatParam& param) { + std::vector input = param.x; + int axis = param.axis; + infer_shape(param); + + lite::Tensor* output = param.output; + int num = input.size(); + int rows = 1; + auto dim_0 = input[0]->dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(input.size()); + for (int i = 0; i < num; ++i) { + int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1; + for (int didx = 0; didx < input[i]->dims().size(); ++didx) { + input_i_numel *= input[i]->dims()[didx]; + } + int t_cols = input_i_numel / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + + auto output_data = output->mutable_data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = input[j]->data(); + for (int k = 0; k < out_rows; ++k) { + memcpy(output_data + k * out_cols + col_idx, + input_data + k * col_len, + sizeof(float) * col_len); + } + col_idx += col_len; + } +} + +TEST(concat, init) { + ConcatCompute concat; + ASSERT_EQ(concat.precision(), PRECISION(kFloat)); + ASSERT_EQ(concat.target(), TARGET(kCUDA)); +} + +TEST(concat, compute_input_multi) { + ConcatCompute concat_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ConcatParam param; + operators::ConcatParam param_ref; + + LOG(INFO) << "test concat start"; + // init param + std::vector x; + std::vector x_cpu; + std::vector x_ref; + lite::Tensor out; + lite::Tensor out_cpu; + lite::Tensor out_ref; + lite::Tensor tensorA; + lite::Tensor tensorB; + lite::Tensor tensorC; + lite::Tensor tensorD; + lite::Tensor tensorA_cpu; + lite::Tensor tensorB_cpu; + lite::Tensor tensorC_cpu; + lite::Tensor tensorD_cpu; + lite::Tensor tensorA_ref; + lite::Tensor tensorB_ref; + lite::Tensor tensorC_ref; + lite::Tensor tensorD_ref; + + DDimLite ddimA({1, 3, 1, 2}); + DDimLite ddimB({1, 4, 1, 2}); + DDimLite ddimC({1, 5, 1, 2}); + DDimLite ddimD({1, 6, 1, 2}); + + tensorA.Resize(ddimA); + tensorB.Resize(ddimB); + tensorC.Resize(ddimC); + tensorD.Resize(ddimD); + tensorA_cpu.Resize(ddimA); + tensorB_cpu.Resize(ddimB); + tensorC_cpu.Resize(ddimC); + tensorD_cpu.Resize(ddimD); + tensorA_ref.Resize(ddimA); + tensorB_ref.Resize(ddimB); + tensorC_ref.Resize(ddimC); + tensorD_ref.Resize(ddimD); + + auto* out_data = out.mutable_data(TARGET(kCUDA)); + auto* out_cpu_data = out_cpu.mutable_data(); + auto* out_ref_data = out_ref.mutable_data(); + for (int i = 0; i < tensorA_cpu.numel(); i++) { + tensorA_cpu.mutable_data()[i] = i; + tensorA_ref.mutable_data()[i] = i; + } + for (int i = 0; i < 
tensorB_cpu.numel(); i++) { + tensorB_cpu.mutable_data()[i] = i + 3; + tensorB_ref.mutable_data()[i] = i + 3; + } + for (int i = 0; i < tensorC_cpu.numel(); i++) { + tensorC_cpu.mutable_data()[i] = i + 6; + tensorC_ref.mutable_data()[i] = i + 6; + } + for (int i = 0; i < tensorD_cpu.numel(); i++) { + tensorD_cpu.mutable_data()[i] = i + 9; + tensorD_ref.mutable_data()[i] = i + 9; + } + tensorA.Assign( + tensorA_cpu.mutable_data(), tensorA_cpu.dims()); + tensorB.Assign( + tensorB_cpu.mutable_data(), tensorB_cpu.dims()); + tensorC.Assign( + tensorC_cpu.mutable_data(), tensorC_cpu.dims()); + tensorD.Assign( + tensorD_cpu.mutable_data(), tensorD_cpu.dims()); + + x.push_back(&tensorA); + x.push_back(&tensorB); + x.push_back(&tensorC); + x.push_back(&tensorD); + x_cpu.push_back(&tensorA_cpu); + x_cpu.push_back(&tensorB_cpu); + x_cpu.push_back(&tensorC_cpu); + x_cpu.push_back(&tensorD_cpu); + x_ref.push_back(&tensorA_ref); + x_ref.push_back(&tensorB_ref); + x_ref.push_back(&tensorC_ref); + x_ref.push_back(&tensorD_ref); + + for (int cur_axis : {1}) { + param.x = x; + param.axis = cur_axis; + param.output = &out; + + concat_kernel.SetParam(param); + LOG(INFO) << "test concat start cur_axis:" << cur_axis; + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + concat_kernel.SetContext(std::move(ctx)); + concat_kernel.Launch(); + cudaDeviceSynchronize(); + LOG(INFO) << "sync end"; + CHECK(cudaSuccess == cudaMemcpy(out_cpu_data, + out_data, + sizeof(float) * out.numel(), + cudaMemcpyDeviceToHost)); + LOG(INFO) << "concat.Run end"; + + param_ref.x = x_ref; + param_ref.axis = cur_axis; + param_ref.output = &out_ref; + + LOG(INFO) << "concat_compute_ref start"; + concat_compute_ref(param_ref); + LOG(INFO) << "concat_compute_ref end"; + + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_add_compute.cu b/lite/kernels/cuda/elementwise_add_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..fd3b74fb6e20d98fb04244ebc709e91f732d3e7a --- /dev/null +++ b/lite/kernels/cuda/elementwise_add_compute.cu @@ -0,0 +1,79 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include "lite/core/op_registry.h" +#include "lite/kernels/cuda/elementwise_add_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +__global__ void KeElementwiseAdd(const float* x_data, + const float* y_data, + float* out_data, + const size_t total) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < total; tid += stride) { + out_data[tid] = x_data[tid] + y_data[tid]; + } +} + +void ElementwiseAddCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + + const lite::Tensor* x = param.X; + const lite::Tensor* y = param.Y; + lite::Tensor* out = param.Out; + + CHECK(x->dims() == y->dims()); + + const int n = x->dims()[0]; + const int c = x->dims()[1]; + const int h = x->dims()[2]; + const int w = x->dims()[3]; + + auto* x_data = x->data(); + auto* y_data = y->data(); + auto out_data = out->mutable_data(TARGET(kCUDA)); + + int pixel_num = x->numel(); + int threads = 512; + int blocks = (pixel_num + threads - 1) / threads; + blocks = blocks > 8 ? 8 : blocks; + + KeElementwiseAdd<<>>( + x_data, y_data, out_data, pixel_num); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(elementwise_add, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::ElementwiseAddCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/elementwise_add_compute.h b/lite/kernels/cuda/elementwise_add_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..772dda8aba3a443dc94affae01fddfaf68e889e5 --- /dev/null +++ b/lite/kernels/cuda/elementwise_add_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class ElementwiseAddCompute + : public KernelLite { + public: + using param_t = operators::ElementwiseParam; + + void Run() override; + virtual ~ElementwiseAddCompute() = default; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/elementwise_add_compute_test.cc b/lite/kernels/cuda/elementwise_add_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ba67e8eb73c8c96c433fe805529e35aa2fa8df0 --- /dev/null +++ b/lite/kernels/cuda/elementwise_add_compute_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/elementwise_add_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +using Tensor = lite::Tensor; + +static void ElementwiseAddRef(float* x, float* y, float* out, int num) { + for (int i = 0; i < num; ++i) { + out[i] = x[i] + y[i]; + // LOG(INFO) << x[i] << " + " << y[i] << " = " << out[i]; + } +} + +TEST(elementwise_add, normal) { + ElementwiseAddCompute elementwise_add_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::ElementwiseParam param; + Tensor x, y, out; + Tensor x_cpu, y_cpu, out_cpu; + Tensor x_ref, y_ref, out_ref; + + const int n = 1; + const int c = 3; + const int h = 2000; + const int w = 2000; + + x.Resize({n, c, h, w}); + y.Resize({n, c, h, w}); + out.Resize({n, c, h, w}); + x_cpu.Resize({n, c, h, w}); + y_cpu.Resize({n, c, h, w}); + out_cpu.Resize({n, c, h, w}); + x_ref.Resize({n, c, h, w}); + y_ref.Resize({n, c, h, w}); + out_ref.Resize({n, c, h, w}); + + auto* x_data = x.mutable_data(TARGET(kCUDA)); + auto* y_data = y.mutable_data(TARGET(kCUDA)); + auto* out_data = out.mutable_data(TARGET(kCUDA)); + + auto* x_cpu_data = x_cpu.mutable_data(); + auto* y_cpu_data = y_cpu.mutable_data(); + auto* out_cpu_data = out_cpu.mutable_data(); + + auto* x_ref_data = x_ref.mutable_data(); + auto* y_ref_data = y_ref.mutable_data(); + auto* out_ref_data = out_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = i + 5.0; + x_ref_data[i] = i + 5.0; + } + for (int i = 0; i < y_cpu.numel(); ++i) { + y_cpu_data[i] = i - 5.0; + y_ref_data[i] = i - 5.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + y.Assign(y_cpu_data, y_cpu.dims()); + + param.X = &x; + param.Y = &y; + param.Out = &out; + elementwise_add_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + elementwise_add_kernel.SetContext(std::move(ctx)); + elementwise_add_kernel.Launch(); + cudaDeviceSynchronize(); + + CopySync( + out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH); + ElementwiseAddRef(x_ref_data, y_ref_data, out_ref_data, out.numel()); + for (int i = 0; i < out.numel(); i++) { + EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/nearest_interp_compute_test.cc b/lite/kernels/cuda/nearest_interp_compute_test.cc index 4aec6db1a21eba59439ff4d6601bd1d220c4e804..6b98bf143b6c5803f18370af0170cc4dfb5cf5ec 100644 --- a/lite/kernels/cuda/nearest_interp_compute_test.cc +++ b/lite/kernels/cuda/nearest_interp_compute_test.cc @@ -16,91 +16,58 @@ #include #include #include -#include "lite/fluid/eigen.h" namespace paddle { namespace lite { namespace kernels { namespace cuda { -template -using EigenTensor = lite::fluid::EigenTensor; using Tensor = lite::Tensor; -static void NearestNeighborInterpolate(const 
Tensor& input, - Tensor* output, - const float ratio_h, - const float ratio_w, - const int n, - const int c, - const int out_h, - const int out_w, - const bool align_corners) { - auto input_t = EigenTensor::From(input); - auto output_t = EigenTensor::From(*output); - for (int k = 0; k < out_h; k++) { // loop for images - int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) - : static_cast(ratio_h * k); - for (int l = 0; l < out_w; l++) { - int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) - : static_cast(ratio_w * l); - for (int i = 0; i < n; i++) { // loop for batches - for (int j = 0; j < c; j++) { // loop for channels - output_t(i, j, k, l) = input_t(i, j, in_k, in_l); +void NearestInterpRef(Tensor* input, Tensor* output, bool with_align) { + int hin = input->dims()[2]; + int win = input->dims()[3]; + int channels = input->dims()[1]; + int num = input->dims()[0]; + int hout = output->dims()[2]; + int wout = output->dims()[3]; + float scale_w = (with_align) ? (static_cast(win - 1) / (wout - 1)) + : (static_cast(win) / (wout)); + float scale_h = (with_align) ? (static_cast(hin - 1) / (hout - 1)) + : (static_cast(hin) / (hout)); + const float* src = input->data(); + float* dst = output->mutable_data(); + int dst_stride_w = 1; + int dst_stride_h = wout; + int dst_stride_c = wout * hout; + int dst_stride_batch = wout * hout * channels; + int src_stride_w = 1; + int src_stride_h = win; + int src_stride_c = win * hin; + int src_stride_batch = win * hin * channels; + for (int n = 0; n < num; ++n) { + for (int c = 0; c < channels; ++c) { + int src_index = n * src_stride_batch + c * src_stride_c; + for (int h = 0; h < hout; ++h) { + for (int w = 0; w < wout; ++w) { + int fw = (with_align) ? static_cast(scale_w * w + 0.5) + : static_cast(scale_w * w); + fw = (fw < 0) ? 0 : fw; + int fh = (with_align) ? static_cast(scale_h * h + 0.5) + : static_cast(scale_h * h); + fh = (fh < 0) ? 0 : fh; + int w_start = static_cast(fw); + int h_start = static_cast(fh); + int dst_index = n * dst_stride_batch + c * dst_stride_c + + h * dst_stride_h + w * dst_stride_w; + dst[dst_index] = + src[src_index + w_start * src_stride_w + h_start * src_stride_h]; } } } } } -static void NearestInterpRef(operators::InterpolateParam param, - Tensor* input, - const size_t scale, - const size_t n, - const size_t c, - const size_t in_h, - const size_t in_w, - Tensor* output_size, - Tensor* output, - size_t out_h, - size_t out_w) { - if (scale > 0) { - out_h = static_cast(in_h * scale); - out_w = static_cast(in_w * scale); - } - bool align_corners = param.align_corners; - if (output_size != nullptr) { - auto out_size_data = output_size->mutable_data(); - out_h = static_cast(out_size_data[0]); - out_w = static_cast(out_size_data[1]); - } - - float* input_data = input->mutable_data(); - LOG(INFO) << *(input_data + 2); - float* output_data = output->mutable_data(); - LOG(INFO) << *(output_data + 2); - if (in_h == out_h && in_w == out_w) { - std::memcpy(output_data, input_data, sizeof(float) * n * c * in_h * in_w); - LOG(INFO) << *(output_data + 2); - return; - } - float ratio_h = 0.f; - float ratio_w = 0.f; - if (out_h > 1) { - ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) - : static_cast(in_h) / out_h; - } - if (out_w > 1) { - ratio_w = (align_corners) ? 
                              static_cast<float>(in_w - 1) / (out_w - 1)
-                              : static_cast<float>(in_w) / out_w;
-  }
-  NearestNeighborInterpolate<float>(
-      *input, output, ratio_h, ratio_w, n, c, out_h, out_w, align_corners);
-}
-
 TEST(nearest_interp, normal) {
   NearestInterpCompute nearest_interp_kernel;
   std::unique_ptr<KernelContext> ctx(new KernelContext);
@@ -112,9 +79,9 @@ TEST(nearest_interp, normal) {
   Tensor x_cpu, osz_cpu, out_cpu;
   Tensor x_ref, osz_ref, out_ref;
 
-  int n = 1, c = 3, in_h = 4, in_w = 4;
+  int n = 1, c = 3, in_h = 40, in_w = 40;
   int in_chw = c * in_h * in_w;
-  int out_h = 4, out_w = 4;
+  int out_h = 80, out_w = 80;
   float scale = 2.0;
 
   param.out_h = out_h;
@@ -173,8 +140,7 @@ TEST(nearest_interp, normal) {
   CopySync<TARGET(kCUDA)>(
       out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
 
-  NearestInterpRef(
-      param, &x_ref, scale, n, c, in_h, in_w, &osz_ref, &out_ref, out_h, out_w);
+  NearestInterpRef(&x_ref, &out_ref, false);
   for (int i = 0; i < out.numel(); i++) {
     EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
   }
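
For readers tracing the indexing in the general ConcatKernel of concat_compute.cu, the following is a minimal host-side sketch of the same column-segment lookup. The function name ConcatCpuModel and its exact signature are illustrative only and are not part of this patch.

// Host-side model of ConcatKernel's indexing (illustrative; not part of the
// patch). inputs_col holds prefix sums of each input's column width, as built
// in ConcatCompute::Run(); an output column `col` belongs to the segment whose
// [inputs_col[seg], inputs_col[seg + 1]) range contains it.
#include <vector>

void ConcatCpuModel(const std::vector<const float*>& inputs,
                    const std::vector<int>& inputs_col,  // size = inputs.size() + 1, inputs_col[0] == 0
                    int out_rows,
                    float* output) {
  const int out_cols = inputs_col.back();
  for (int col = 0; col < out_cols; ++col) {
    int seg = 0;
    while (inputs_col[seg + 1] <= col) ++seg;  // find the input that owns this column
    const int local_col = col - inputs_col[seg];
    const int seg_width = inputs_col[seg + 1] - inputs_col[seg];
    for (int row = 0; row < out_rows; ++row) {
      output[row * out_cols + col] = inputs[seg][row * seg_width + local_col];
    }
  }
}

When all inputs share the same column width, the specialized kernels in concat_compute.cu skip this prefix-sum lookup and compute the owning segment directly as tid_x / fixed_in_col.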