Unverified commit 6d1da405, authored by Pei Yang, committed by GitHub

Add concat and elementwise_add cuda kernel (#1979)

* add nearest_interp_cuda kernel, test=develop

* add concat op and elementwise_add op

* remove eigen dependency from nearest_interp cuda kernel, test=develop

* free cuda pointers, test=develop
Parent commit: da328594
@@ -159,6 +159,7 @@ USE_LITE_KERNEL(conv2d, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(nearest_interp, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(yolo_box, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kCUDA, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_OPENCL
@@ -5,17 +5,22 @@ endif()
message(STATUS "compile with lite CUDA kernels")
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
nv_library(yolo_box_compute_cuda SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
nv_library(nearest_interp_compute_cuda SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps}
${math_cuda})
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_library(conv2d_cuda SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
nv_library(concat_compute_cuda SRCS concat_compute.cu DEPS ${lite_kernel_deps})
nv_library(elementwise_add_compute_cuda SRCS elementwise_add_compute.cu DEPS ${lite_kernel_deps})
nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_library(calib_compute_cuda SRCS calib_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
@@ -26,6 +31,8 @@ mul_compute_cuda
io_copy_compute_cuda
leaky_relu_compute_cuda
nearest_interp_compute_cuda
concat_compute_cuda
elementwise_add_compute_cuda
yolo_box_compute_cuda
)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/concat_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
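// General concat kernel for inputs with different widths: every input is
// viewed as a 2D matrix sharing output_rows rows, and input_cols holds the
// prefix sums of the per-input column counts. Each thread scans the prefix
// sums to find the segment (input tensor) that owns its output column, then
// copies that column for all rows using grid-stride loops in x and y.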
template <typename T>
__global__ void ConcatKernel(const T** inputs,
const int* input_cols,
int col_size,
const int output_rows,
const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0;
int curr_offset = input_cols[0];
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int curr_col_offset = input_cols[curr_segment + 1];
while (curr_col_offset <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
curr_col_offset = input_cols[curr_segment + 1];
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
const T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
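// Fast path for inputs that all share the same width: with fixed_in_col
// columns per input, the owning input index is simply tid_x / fixed_in_col,
// so no per-thread search over column offsets is needed.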
template <typename T>
__device__ void ConcatKernelDetail(const T** inputs_data,
const int fixed_in_col,
const int out_rows,
const int out_cols,
T* output_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * 1.0 / fixed_in_col;
int in_offset = tid_x - split * fixed_in_col;
const T* input_ptr = inputs_data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
output_data[tid_y * out_cols + tid_x] =
input_ptr[tid_y * fixed_in_col + in_offset];
}
}
// for (int i = 0; i < 4; i++){
// printf("input[0][%d] = %.1f\n", i, inputs_data[0][i]);
// printf("output[%d] = %.1f\n", i, output_data[i]);
// }
}
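// Fixed-arity wrappers: when 2 to 4 equal-width inputs are concatenated,
// their pointers are passed as plain kernel arguments, which avoids copying
// a pointer table to device memory first.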
template <typename T>
__global__ void ConcatKernel(const T* input_addr0,
const T* input_addr1,
const int fixed_in_col,
const int out_rows,
const int out_cols,
T* output_data) {
const T* inputs_data[2];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void ConcatKernel(const T* input_addr0,
const T* input_addr1,
const T* input_addr2,
const int fixed_in_col,
const int out_rows,
const int out_cols,
T* output_data) {
const T* inputs_data[3];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
inputs_data[2] = input_addr2;
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void ConcatKernel(const T* input_addr0,
const T* input_addr1,
const T* input_addr2,
const T* input_addr3,
const int fixed_in_col,
const int out_rows,
const int out_cols,
T* output_data) {
const T* inputs_data[4];
inputs_data[0] = input_addr0;
inputs_data[1] = input_addr1;
inputs_data[2] = input_addr2;
inputs_data[3] = input_addr3;
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
template <typename T>
__global__ void ConcatKernel(const T** inputs_data,
const int in_num,
const int fixed_in_col,
const int out_rows,
const int out_cols,
T* output_data) {
ConcatKernelDetail<T>(
inputs_data, fixed_in_col, out_rows, out_cols, output_data);
}
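// Choose a 2D launch configuration for a [num_rows, num_cols] copy: at most
// 1024 threads per block, with block_cols rounded up to a multiple of 32 for
// narrow outputs and the remaining threads spent on rows. For example,
// num_cols = 36 gives block_cols = 64 and block_rows = 16.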
static inline void GetBlockDims(const CUDAContext& context,
int num_rows,
int num_cols,
dim3* block_dims,
dim3* grid_dims) {
// Set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (num_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((num_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
*block_dims = dim3(block_cols, block_rows, 1);
int grid_cols = (num_cols + block_cols - 1) / block_cols;
int grid_rows = std::max(num_rows / block_rows, 1);
*grid_dims = dim3(grid_cols, grid_rows, 1);
}
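// Host-side dispatch: flatten each input to [in_row, t_cols] around the
// concat axis. If all inputs have the same column count and there are 2 to 4
// of them, their pointers are passed directly as kernel arguments; otherwise
// the pointer table (and, for unequal widths, the column prefix sums) is
// first copied to device memory and freed after the launch.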
void ConcatCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
std::vector<Tensor*> input = param.x;
Tensor* output = param.output;
int axis = param.axis;
int in_num = input.size();
int in_row = 1;
auto dim_0 = input[0]->dims();
for (int i = 0; i < axis; ++i) {
in_row *= dim_0[i];
}
int in_col = input[0]->numel() / in_row;
int out_row = in_row, out_col = 0;
std::vector<const float*> inputs_data(in_num);
std::vector<int> inputs_col(in_num + 1);
inputs_col[0] = 0;
bool has_same_shape = true;
for (int i = 0; i < in_num; ++i) {
int t_cols = input[i]->numel() / in_row;
if (has_same_shape) {
if (t_cols != in_col) has_same_shape = false;
}
out_col += t_cols;
inputs_col[i + 1] = out_col;
inputs_data[i] = input[i]->data<float>();
}
dim3 block_dims;
dim3 grid_dims;
GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
const float** dev_ins_data = nullptr;
if (!has_same_shape || in_num < 2 || in_num > 4) {
float* tmp_dev_ins_data = nullptr;
CHECK(cudaSuccess ==
cudaMalloc(&tmp_dev_ins_data, inputs_data.size() * sizeof(float*)));
CHECK(cudaSuccess == cudaMemcpy(tmp_dev_ins_data,
static_cast<void*>(inputs_data.data()),
inputs_data.size() * sizeof(float*),
cudaMemcpyHostToDevice));
dev_ins_data = reinterpret_cast<const float**>(tmp_dev_ins_data);
}
if (has_same_shape) {
if (in_num == 2) {
ConcatKernel<float><<<grid_dims, block_dims, 0, stream>>>(
inputs_data[0],
inputs_data[1],
in_col,
out_row,
out_col,
output->mutable_data<float>());
} else if (in_num == 3) {
ConcatKernel<float><<<grid_dims, block_dims, 0, stream>>>(
inputs_data[0],
inputs_data[1],
inputs_data[2],
in_col,
out_row,
out_col,
output->mutable_data<float>());
} else if (in_num == 4) {
ConcatKernel<float><<<grid_dims, block_dims, 0, stream>>>(
inputs_data[0],
inputs_data[1],
inputs_data[2],
inputs_data[3],
in_col,
out_row,
out_col,
output->mutable_data<float>());
} else {
ConcatKernel<float><<<grid_dims, block_dims, 0, stream>>>(
dev_ins_data,
in_num,
in_col,
out_row,
out_col,
output->mutable_data<float>());
cudaFree(dev_ins_data);
}
} else {
int* tmp_dev_ins_col_data = nullptr;
CHECK(cudaSuccess ==
cudaMalloc(&tmp_dev_ins_col_data, inputs_col.size() * sizeof(int)));
CHECK(cudaSuccess == cudaMemcpy(tmp_dev_ins_col_data,
static_cast<void*>(inputs_col.data()),
inputs_col.size() * sizeof(int),
cudaMemcpyHostToDevice));
int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data);
ConcatKernel<float><<<grid_dims, block_dims, 0, stream>>>(
dev_ins_data,
dev_ins_col_data,
static_cast<int>(inputs_col.size()),
out_row,
out_col,
output->mutable_data<float>());
cudaFree(dev_ins_data);
cudaFree(dev_ins_col_data);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(concat,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ConcatCompute,
def)
.BindInput("x", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("axis", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("output", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ConcatCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ConcatParam;
void Run() override;
virtual ~ConcatCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/concat_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
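// Shape inference used by the reference implementation: all non-axis
// dimensions must match across inputs, and the axis dimension of the output
// is the sum of the inputs' axis dimensions.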
bool infer_shape(const operators::ConcatParam& param) {
std::vector<lite::DDim> input_dims;
for (auto p : param.x) {
input_dims.push_back(p->dims());
}
size_t axis = static_cast<size_t>(param.axis);
const size_t n = input_dims.size();
CHECK_GT_OR_FALSE(n, 0);
auto& out_dims = input_dims[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
out_dims[axis] += input_dims[i][j];
} else {
CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
}
}
}
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
// Set output dims
param.output->Resize(lite::DDim(out_dims));
return true;
}
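// CPU reference: after inferring the output shape, concat reduces to copying
// each input's rows into the corresponding column slice of the output with
// memcpy.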
void concat_compute_ref(const operators::ConcatParam& param) {
std::vector<lite::Tensor*> input = param.x;
int axis = param.axis;
infer_shape(param);
lite::Tensor* output = param.output;
int num = input.size();
int rows = 1;
auto dim_0 = input[0]->dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int input_i_numel = input[i]->dims().size() == 0 ? 0 : 1;
for (int didx = 0; didx < input[i]->dims().size(); ++didx) {
input_i_numel *= input[i]->dims()[didx];
}
int t_cols = input_i_numel / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto output_data = output->mutable_data<float>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = input[j]->data<float>();
for (int k = 0; k < out_rows; ++k) {
memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len,
sizeof(float) * col_len);
}
col_idx += col_len;
}
}
TEST(concat, init) {
ConcatCompute concat;
ASSERT_EQ(concat.precision(), PRECISION(kFloat));
ASSERT_EQ(concat.target(), TARGET(kCUDA));
}
TEST(concat, compute_input_multi) {
ConcatCompute concat_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ConcatParam param;
operators::ConcatParam param_ref;
LOG(INFO) << "test concat start";
// init param
std::vector<lite::Tensor*> x;
std::vector<lite::Tensor*> x_cpu;
std::vector<lite::Tensor*> x_ref;
lite::Tensor out;
lite::Tensor out_cpu;
lite::Tensor out_ref;
lite::Tensor tensorA;
lite::Tensor tensorB;
lite::Tensor tensorC;
lite::Tensor tensorD;
lite::Tensor tensorA_cpu;
lite::Tensor tensorB_cpu;
lite::Tensor tensorC_cpu;
lite::Tensor tensorD_cpu;
lite::Tensor tensorA_ref;
lite::Tensor tensorB_ref;
lite::Tensor tensorC_ref;
lite::Tensor tensorD_ref;
DDimLite ddimA({1, 3, 1, 2});
DDimLite ddimB({1, 4, 1, 2});
DDimLite ddimC({1, 5, 1, 2});
DDimLite ddimD({1, 6, 1, 2});
tensorA.Resize(ddimA);
tensorB.Resize(ddimB);
tensorC.Resize(ddimC);
tensorD.Resize(ddimD);
tensorA_cpu.Resize(ddimA);
tensorB_cpu.Resize(ddimB);
tensorC_cpu.Resize(ddimC);
tensorD_cpu.Resize(ddimD);
tensorA_ref.Resize(ddimA);
tensorB_ref.Resize(ddimB);
tensorC_ref.Resize(ddimC);
tensorD_ref.Resize(ddimD);
// Resize the outputs to the concat result along axis 1 (3 + 4 + 5 + 6 = 18
// channels) before taking raw pointers, so the buffers are large enough for
// the kernel and the reference to write into.
DDimLite ddim_out({1, 18, 1, 2});
out.Resize(ddim_out);
out_cpu.Resize(ddim_out);
out_ref.Resize(ddim_out);
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < tensorA_cpu.numel(); i++) {
tensorA_cpu.mutable_data<float>()[i] = i;
tensorA_ref.mutable_data<float>()[i] = i;
}
for (int i = 0; i < tensorB_cpu.numel(); i++) {
tensorB_cpu.mutable_data<float>()[i] = i + 3;
tensorB_ref.mutable_data<float>()[i] = i + 3;
}
for (int i = 0; i < tensorC_cpu.numel(); i++) {
tensorC_cpu.mutable_data<float>()[i] = i + 6;
tensorC_ref.mutable_data<float>()[i] = i + 6;
}
for (int i = 0; i < tensorD_cpu.numel(); i++) {
tensorD_cpu.mutable_data<float>()[i] = i + 9;
tensorD_ref.mutable_data<float>()[i] = i + 9;
}
tensorA.Assign<float, lite::DDim, TARGET(kCUDA)>(
tensorA_cpu.mutable_data<float>(), tensorA_cpu.dims());
tensorB.Assign<float, lite::DDim, TARGET(kCUDA)>(
tensorB_cpu.mutable_data<float>(), tensorB_cpu.dims());
tensorC.Assign<float, lite::DDim, TARGET(kCUDA)>(
tensorC_cpu.mutable_data<float>(), tensorC_cpu.dims());
tensorD.Assign<float, lite::DDim, TARGET(kCUDA)>(
tensorD_cpu.mutable_data<float>(), tensorD_cpu.dims());
x.push_back(&tensorA);
x.push_back(&tensorB);
x.push_back(&tensorC);
x.push_back(&tensorD);
x_cpu.push_back(&tensorA_cpu);
x_cpu.push_back(&tensorB_cpu);
x_cpu.push_back(&tensorC_cpu);
x_cpu.push_back(&tensorD_cpu);
x_ref.push_back(&tensorA_ref);
x_ref.push_back(&tensorB_ref);
x_ref.push_back(&tensorC_ref);
x_ref.push_back(&tensorD_ref);
for (int cur_axis : {1}) {
param.x = x;
param.axis = cur_axis;
param.output = &out;
concat_kernel.SetParam(param);
LOG(INFO) << "test concat start cur_axis:" << cur_axis;
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
concat_kernel.SetContext(std::move(ctx));
concat_kernel.Launch();
cudaDeviceSynchronize();
LOG(INFO) << "sync end";
CHECK(cudaSuccess == cudaMemcpy(out_cpu_data,
out_data,
sizeof(float) * out.numel(),
cudaMemcpyDeviceToHost));
LOG(INFO) << "concat.Run end";
param_ref.x = x_ref;
param_ref.axis = cur_axis;
param_ref.output = &out_ref;
LOG(INFO) << "concat_compute_ref start";
concat_compute_ref(param_ref);
LOG(INFO) << "concat_compute_ref end";
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/elementwise_add_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
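// Element-wise add over the flattened tensors: a 1D grid-stride loop lets a
// capped number of blocks cover any total element count.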
__global__ void KeElementwiseAdd(const float* x_data,
const float* y_data,
float* out_data,
const size_t total) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < total; tid += stride) {
out_data[tid] = x_data[tid] + y_data[tid];
}
}
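// Launch with 512 threads per block and at most 8 blocks; the grid-stride
// loop in KeElementwiseAdd covers the remaining elements.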
void ElementwiseAddCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const lite::Tensor* x = param.X;
const lite::Tensor* y = param.Y;
lite::Tensor* out = param.Out;
CHECK(x->dims() == y->dims());
const int n = x->dims()[0];
const int c = x->dims()[1];
const int h = x->dims()[2];
const int w = x->dims()[3];
auto* x_data = x->data<float>();
auto* y_data = y->data<float>();
auto out_data = out->mutable_data<float>(TARGET(kCUDA));
int pixel_num = x->numel();
int threads = 512;
int blocks = (pixel_num + threads - 1) / threads;
blocks = blocks > 8 ? 8 : blocks;
KeElementwiseAdd<<<blocks, threads, 0, stream>>>(
x_data, y_data, out_data, pixel_num);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseAddCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class ElementwiseAddCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseAddCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/elementwise_add_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
static void ElementwiseAddRef(float* x, float* y, float* out, int num) {
for (int i = 0; i < num; ++i) {
out[i] = x[i] + y[i];
// LOG(INFO) << x[i] << " + " << y[i] << " = " << out[i];
}
}
TEST(elementwise_add, normal) {
ElementwiseAddCompute elementwise_add_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ElementwiseParam param;
Tensor x, y, out;
Tensor x_cpu, y_cpu, out_cpu;
Tensor x_ref, y_ref, out_ref;
const int n = 1;
const int c = 3;
const int h = 2000;
const int w = 2000;
x.Resize({n, c, h, w});
y.Resize({n, c, h, w});
out.Resize({n, c, h, w});
x_cpu.Resize({n, c, h, w});
y_cpu.Resize({n, c, h, w});
out_cpu.Resize({n, c, h, w});
x_ref.Resize({n, c, h, w});
y_ref.Resize({n, c, h, w});
out_ref.Resize({n, c, h, w});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* y_cpu_data = y_cpu.mutable_data<float>();
auto* out_cpu_data = out_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* y_ref_data = y_ref.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = i + 5.0;
x_ref_data[i] = i + 5.0;
}
for (int i = 0; i < y_cpu.numel(); ++i) {
y_cpu_data[i] = i - 5.0;
y_ref_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
param.X = &x;
param.Y = &y;
param.Out = &out;
elementwise_add_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
elementwise_add_kernel.SetContext(std::move(ctx));
elementwise_add_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
ElementwiseAddRef(x_ref_data, y_ref_data, out_ref_data, out.numel());
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -16,91 +16,58 @@
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include "lite/fluid/eigen.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T,
size_t D,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = lite::fluid::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = lite::Tensor;
static void NearestNeighborInterpolate(const Tensor& input,
Tensor* output,
const float ratio_h,
const float ratio_w,
const int n,
const int c,
const int out_h,
const int out_w,
const bool align_corners) {
auto input_t = EigenTensor<float, 4>::From(input);
auto output_t = EigenTensor<float, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = (align_corners) ? static_cast<int>(ratio_h * k + 0.5)
: static_cast<int>(ratio_h * k);
for (int l = 0; l < out_w; l++) {
int in_l = (align_corners) ? static_cast<int>(ratio_w * l + 0.5)
: static_cast<int>(ratio_w * l);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
void NearestInterpRef(Tensor* input, Tensor* output, bool with_align) {
int hin = input->dims()[2];
int win = input->dims()[3];
int channels = input->dims()[1];
int num = input->dims()[0];
int hout = output->dims()[2];
int wout = output->dims()[3];
float scale_w = (with_align) ? (static_cast<float>(win - 1) / (wout - 1))
: (static_cast<float>(win) / (wout));
float scale_h = (with_align) ? (static_cast<float>(hin - 1) / (hout - 1))
: (static_cast<float>(hin) / (hout));
const float* src = input->data<float>();
float* dst = output->mutable_data<float>();
int dst_stride_w = 1;
int dst_stride_h = wout;
int dst_stride_c = wout * hout;
int dst_stride_batch = wout * hout * channels;
int src_stride_w = 1;
int src_stride_h = win;
int src_stride_c = win * hin;
int src_stride_batch = win * hin * channels;
for (int n = 0; n < num; ++n) {
for (int c = 0; c < channels; ++c) {
int src_index = n * src_stride_batch + c * src_stride_c;
for (int h = 0; h < hout; ++h) {
for (int w = 0; w < wout; ++w) {
int fw = (with_align) ? static_cast<int>(scale_w * w + 0.5)
: static_cast<int>(scale_w * w);
fw = (fw < 0) ? 0 : fw;
int fh = (with_align) ? static_cast<int>(scale_h * h + 0.5)
: static_cast<int>(scale_h * h);
fh = (fh < 0) ? 0 : fh;
int w_start = static_cast<int>(fw);
int h_start = static_cast<int>(fh);
int dst_index = n * dst_stride_batch + c * dst_stride_c +
h * dst_stride_h + w * dst_stride_w;
dst[dst_index] =
src[src_index + w_start * src_stride_w + h_start * src_stride_h];
}
}
}
}
}
static void NearestInterpRef(operators::InterpolateParam param,
Tensor* input,
const size_t scale,
const size_t n,
const size_t c,
const size_t in_h,
const size_t in_w,
Tensor* output_size,
Tensor* output,
size_t out_h,
size_t out_w) {
if (scale > 0) {
out_h = static_cast<int>(in_h * scale);
out_w = static_cast<int>(in_w * scale);
}
bool align_corners = param.align_corners;
if (output_size != nullptr) {
auto out_size_data = output_size->mutable_data<float>();
out_h = static_cast<int>(out_size_data[0]);
out_w = static_cast<int>(out_size_data[1]);
}
float* input_data = input->mutable_data<float>();
LOG(INFO) << *(input_data + 2);
float* output_data = output->mutable_data<float>();
LOG(INFO) << *(output_data + 2);
if (in_h == out_h && in_w == out_w) {
std::memcpy(output_data, input_data, sizeof(float) * n * c * in_h * in_w);
LOG(INFO) << *(output_data + 2);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
: static_cast<float>(in_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
: static_cast<float>(in_w) / out_w;
}
NearestNeighborInterpolate(
*input, output, ratio_h, ratio_w, n, c, out_h, out_w, align_corners);
}
TEST(nearest_interp, normal) {
NearestInterpCompute nearest_interp_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
@@ -112,9 +79,9 @@ TEST(nearest_interp, normal) {
Tensor x_cpu, osz_cpu, out_cpu;
Tensor x_ref, osz_ref, out_ref;
int n = 1, c = 3, in_h = 4, in_w = 4;
int n = 1, c = 3, in_h = 40, in_w = 40;
int in_chw = c * in_h * in_w;
int out_h = 4, out_w = 4;
int out_h = 80, out_w = 80;
float scale = 2.0;
param.out_h = out_h;
@@ -173,8 +140,7 @@ TEST(nearest_interp, normal) {
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
NearestInterpRef(
param, &x_ref, scale, n, c, in_h, in_w, &osz_ref, &out_ref, out_h, out_w);
NearestInterpRef(&x_ref, &out_ref, false);
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
}