Unverified commit 3d73dea9, authored by: Pei Yang, committed by: GitHub

add sequence_pool cuda kernel, test=develop (#2430)

add sequence_pool cuda kernel
Parent 1e88d1e8
......@@ -9,6 +9,7 @@ add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_k
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
......@@ -38,6 +39,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_pool_compute.h"
const int CUDA_NUM_THREADS = 512;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
/// CUDA: number of blocks for threads.
inline int CUDA_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
inline int CUDA_GET_BLOCKS(const int N, const int base) {
return (N + base - 1) / base;
}
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
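// Each pooling kernel maps one thread to one output element:
// tid = out_batch_id * slice_size + out_id, where slice_size is the number of
// elements per time step and seq_offset holds the LoD boundaries of the batch.
// seq_pool_average_kernel averages all time steps of each sequence.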
template <typename Dtype>
__global__ void seq_pool_average_kernel(Dtype* dst,
const Dtype* src_in,
const int batch_size,
const uint64_t* seq_offset,
const int slice_size) {
int total = slice_size * batch_size;
CUDA_KERNEL_LOOP(tid, total) {
int out_batch_id = tid / slice_size;
int out_id = tid % slice_size;
int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
seq_offset[out_batch_id]);
int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
src_in += in_offset + out_id;
Dtype sum = (Dtype)0;
for (int i = 0; i < in_slice_num; ++i) {
sum += src_in[i * slice_size];
}
dst[out_batch_id * slice_size + out_id] = sum / in_slice_num;
}
}
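// Sum pooling: accumulate every time step of the sequence.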
template <typename Dtype>
__global__ void seq_pool_sum_kernel(Dtype* dst,
const Dtype* src_in,
const int batch_size,
const uint64_t* seq_offset,
const int slice_size) {
int total = slice_size * batch_size;
CUDA_KERNEL_LOOP(tid, total) {
int out_batch_id = tid / slice_size;
int out_id = tid % slice_size;
int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
seq_offset[out_batch_id]);
int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
src_in += in_offset + out_id;
Dtype sum = (Dtype)0;
for (int i = 0; i < in_slice_num; ++i) {
sum += src_in[i * slice_size];
}
dst[out_batch_id * slice_size + out_id] = sum;
}
}
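// Sqrt pooling: sequence sum scaled by 1 / sqrt(sequence length).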
template <typename Dtype>
__global__ void seq_pool_sqrt_kernel(Dtype* dst,
const Dtype* src_in,
const int batch_size,
const uint64_t* seq_offset,
const int slice_size) {
int total = slice_size * batch_size;
CUDA_KERNEL_LOOP(tid, total) {
int out_batch_id = tid / slice_size;
int out_id = tid % slice_size;
int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
seq_offset[out_batch_id]);
int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
src_in += in_offset + out_id;
Dtype sum = (Dtype)0;
for (int i = 0; i < in_slice_num; ++i) {
sum += src_in[i * slice_size];
}
dst[out_batch_id * slice_size + out_id] = sum * rsqrtf(in_slice_num);
}
}
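// Max pooling: element-wise maximum over the time steps of the sequence.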
template <typename Dtype>
__global__ void seq_pool_max_kernel(Dtype* dst,
const Dtype* src_in,
const int batch_size,
const uint64_t* seq_offset,
const int slice_size) {
int total = slice_size * batch_size;
CUDA_KERNEL_LOOP(tid, total) {
int out_batch_id = tid / slice_size;
int out_id = tid % slice_size;
int in_slice_num = static_cast<int>(seq_offset[out_batch_id + 1] -
seq_offset[out_batch_id]);
int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
src_in += in_offset + out_id;
Dtype max = src_in[0];
for (int i = 1; i < in_slice_num; ++i) {
Dtype val = src_in[i * slice_size];
if (val > max) {
max = val;
}
}
dst[out_batch_id * slice_size + out_id] = max;
}
}
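// Last pooling: copy the final time step of each sequence.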
template <typename Dtype>
__global__ void seq_pool_last_kernel(Dtype* dst,
const Dtype* src_in,
const int batch_size,
const uint64_t* seq_offset,
const int slice_size) {
int total = slice_size * batch_size;
CUDA_KERNEL_LOOP(tid, total) {
int out_batch_id = tid / slice_size;
int out_id = tid % slice_size;
int in_offset =
(static_cast<int>(seq_offset[out_batch_id + 1]) - 1) * slice_size;
dst[tid] = src_in[in_offset + out_id];
}
}
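// First pooling: copy the first time step of each sequence.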
template <typename Dtype>
__global__ void seq_pool_first_kernel(Dtype* dst,
const Dtype* src_in,
const int batch_size,
const uint64_t* seq_offset,
const int slice_size) {
int total = slice_size * batch_size;
CUDA_KERNEL_LOOP(tid, total) {
int out_batch_id = tid / slice_size;
int out_id = tid % slice_size;
int in_offset = static_cast<int>(seq_offset[out_batch_id] * slice_size);
dst[tid] = src_in[in_offset + out_id];
}
}
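// Host-side entry point: copies the LoD offsets of X to the device
// asynchronously, launches the pooling kernel selected by pool_type on the
// execution stream, and resets the output LoD to one step per sequence.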
void SequencePoolCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
std::vector<uint64_t> seq_offset = param.X->lod()[0];
int slice_size =
param.Out->dims()[1] * param.Out->dims()[2] * param.Out->dims()[3];
float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
const float* in_data = param.X->data<float>();
int batch_size = static_cast<int>(seq_offset.size()) - 1;
lite::Tensor seq_offset_D;
seq_offset_D.Resize({static_cast<int64_t>(seq_offset.size())});
TargetWrapperCuda::MemcpyAsync(seq_offset_D.mutable_data<uint64_t>(),
seq_offset.data(),
sizeof(uint64_t) * seq_offset.size(),
IoDirection::HtoD,
stream);
if (param.pool_type == "MAX") {
seq_pool_max_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
stream>>>(out_data,
in_data,
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else if (param.pool_type == "AVERAGE ") {
seq_pool_average_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
stream>>>(out_data,
in_data,
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else if (param.pool_type == "SUM") {
seq_pool_sum_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
stream>>>(out_data,
in_data,
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else if (param.pool_type == "SQRT") {
seq_pool_sqrt_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
stream>>>(out_data,
in_data,
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else if (param.pool_type == "FIRST") {
seq_pool_first_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
stream>>>(out_data,
in_data,
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else if (param.pool_type == "LAST") {
seq_pool_last_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
stream>>>(out_data,
in_data,
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else {
LOG(ERROR) << "pool type " << param.pool_type << " is not supoorted.";
}
std::vector<uint64_t> offset_new(static_cast<uint64_t>(batch_size + 1));
for (int i = 0; i <= batch_size; ++i) {
offset_new[i] = i;
}
std::vector<std::vector<uint64_t>> voffset_new;
voffset_new.push_back(offset_new);
param.Out->set_lod(voffset_new);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_pool,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequencePoolCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
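// CUDA kernel for the sequence_pool op (float, NCHW); the device kernels and
// the pool_type dispatch live in sequence_pool_compute.cu.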
class SequencePoolCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SequencePoolParam;
void Run() override;
virtual ~SequencePoolCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_pool_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/backends/x86/math/sequence_pooling.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
namespace {
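// CPU reference implementation: reuses the x86 SequencePoolFunctor to compute
// the expected output for comparison with the CUDA kernel.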
static void sequence_pool_ref(const operators::SequencePoolParam& param,
const lite::X86Context& context) {
auto* x = param.X;
auto* out = param.Out;
auto dims = x->dims();
auto lod = x->lod();
CHECK_EQ(lod.size(), 1UL);
CHECK_GE(dims[0], static_cast<int64_t>(lod[0].size() - 1));
dims[0] = lod[0].size() - 1;
out->Resize({dims});
out->mutable_data<float>();
lite::Tensor* index = nullptr;
const bool is_test = true;
float pad_value = 0.0;
lite::x86::math::SequencePoolFunctor<lite::TargetType::kX86, float> pool;
pool(context, param.pool_type, pad_value, *x, out, is_test, index);
}
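// PREPARE_INPUT_DATA resizes the device/host/reference copies of an input
// tensor, attaches its LoD, fills the host and reference buffers with a ramp
// starting at -2.0, and copies the host data into the device tensor.
// PREPARE_OUTPUT_INFO resizes the output tensors and exposes the host buffer.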
#define PREPARE_INPUT_DATA(name) \
name.Resize({name##_lod_len, feature_len}); \
name##_cpu.Resize({name##_lod_len, feature_len}); \
name##_ref.Resize({name##_lod_len, feature_len}); \
name.set_lod(lod_info_##name); \
name##_cpu.set_lod(lod_info_##name); \
name##_ref.set_lod(lod_info_##name); \
float* name##_cpu_data = name##_cpu.mutable_data<float>(); \
float* name##_ref_data = name##_ref.mutable_data<float>(); \
for (int i = 0; i < name##_cpu.numel(); ++i) { \
name##_cpu_data[i] = (i - 2.0) * 1.0; \
name##_ref_data[i] = (i - 2.0) * 1.0; \
} \
name.Assign<float, lite::DDim, TARGET(kCUDA)>(name##_cpu_data, \
name##_cpu.dims());
#define PREPARE_OUTPUT_INFO(name) \
name##_cpu.Resize({y_lod_len, feature_len}); \
name##_ref.Resize({y_lod_len, feature_len}); \
name.Resize({y_lod_len, feature_len}); \
float* name##_cpu_data = name##_cpu.mutable_data<float>();
} // namespace
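// Runs AVERAGE sequence pooling on the CUDA kernel for x1 and checks every
// output element against the x86 reference implementation.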
TEST(sequence_pool_cuda, normal) {
SequencePoolCompute seq_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
std::unique_ptr<KernelContext> ctx_ref(new KernelContext);
auto& context_ref = ctx_ref->As<X86Context>();
operators::SequencePoolParam param;
lite::Tensor x1, x2, x3, x1_cpu, x2_cpu, x3_cpu, x1_ref, x2_ref, x3_ref;
lite::Tensor y, y_cpu, y_ref;
int32_t x1_lod_len = 10, feature_len = 4;
int32_t x2_lod_len = 4, x3_lod_len = 8;
int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len;
LoD lod_info_x1{{0, 3, 5, 6, 10}};
LoD lod_info_x2{{0, 1, 2, 3, 4}};
LoD lod_info_x3{{0, 2, 4, 6, 8}};
LoD lod_info_y{{0, 0, 0, 0, 0}};
for (size_t i = 0; i < lod_info_x1[0].size(); ++i) {
lod_info_y[0][i] =
lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i];
}
PREPARE_INPUT_DATA(x1);
PREPARE_INPUT_DATA(x2);
PREPARE_INPUT_DATA(x3);
PREPARE_OUTPUT_INFO(y);
param.X = &x1;
param.Out = &y;
param.pool_type = "AVERAGE";
seq_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
seq_kernel.SetContext(std::move(ctx));
seq_kernel.Run();
cudaDeviceSynchronize();
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
param.X = &x1_ref;
param.Out = &y_ref;
sequence_pool_ref(param, context_ref);
float* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle