提交 ef5dbd1d 编写于 作者: P Pei Yang 提交者: GitHub

add search_group_padding cuda kernel, test=develop (#2472)

上级 874a5af4
......@@ -5,6 +5,7 @@ endif()
message(STATUS "compile with lite CUDA kernels")
add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
......@@ -44,11 +45,12 @@ nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(search_group_padding_compute_cuda_test SRCS search_group_padding_compute_test.cc DEPS search_group_padding_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_group_padding_compute.h"
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
template <typename Dtype>
__global__ void ker_search_group_padding(Dtype* out_emb_padding_data,
Dtype* out_padding_data,
const Dtype* in_data,
const uint64_t* offset,
const int seq_num,
const int max_len,
const int emb_size,
const Dtype pad_id,
const int count) {
CUDA_KERNEL_LOOP(tid, count) {
int emb_id = tid % emb_size;
int word_id = tid / emb_size;
int seq_id = word_id / max_len;
int word_id_in_seq = word_id % max_len;
int cur_len = offset[seq_id + 1] - offset[seq_id];
if (word_id_in_seq < cur_len) {
out_emb_padding_data[tid] =
in_data[(offset[seq_id] + word_id_in_seq) * emb_size + emb_id];
} else {
out_emb_padding_data[tid] = 0.f;
out_padding_data[word_id] = pad_id;
}
}
}
void SearchGroupPaddingCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto cuda_stream = ctx.exec_stream();
const Tensor* x = param.x;
Tensor* out_emb_padding = param.out_emb_padding;
Tensor* out_new = param.out_new;
Tensor* out_padding = param.out_padding;
const float pad_id = static_cast<float>(param.pad_id);
const float* in_data = x->data<float>();
float* out_emb_padding_data =
out_emb_padding->mutable_data<float>(TARGET(kCUDA));
float* out_new_data = out_new->mutable_data<float>(TARGET(kCUDA));
float* out_padding_data = out_padding->mutable_data<float>(TARGET(kCUDA));
const auto& in_seq_offset = x->lod()[0];
int batch = in_seq_offset.size() - 1;
int max_seq = 0;
for (int i = 0; i < batch; ++i) {
if (in_seq_offset[i + 1] - in_seq_offset[i] > max_seq) {
max_seq = in_seq_offset[i + 1] - in_seq_offset[i];
}
}
std::vector<size_t> new_offset;
new_offset.resize(batch + 1);
for (int i = 0; i < batch + 1; ++i) {
new_offset[i] = i * max_seq;
}
std::vector<int64_t> x_dims = x->dims().Vectorize();
LoD out_emb_padding_lod;
out_emb_padding_lod.push_back(new_offset);
out_emb_padding->set_lod(out_emb_padding_lod);
out_emb_padding->Resize({batch * max_seq, x_dims[1]});
LoD out_new_lod;
out_new_lod.push_back(in_seq_offset);
out_new->set_lod(out_new_lod);
out_new->Resize({x_dims[0], 1});
LoD out_padding_lod;
out_padding_lod.push_back(new_offset);
out_padding->set_lod(out_padding_lod);
out_padding->Resize({batch * max_seq, 1});
const int count = out_emb_padding->numel();
const auto& out_emb_padding_seq_offset = out_emb_padding->lod()[0];
int max_len = out_emb_padding_seq_offset[1];
int seq_num = out_emb_padding_seq_offset.size() - 1;
int emb_size = x->dims()[1];
_in_seq_offset.Resize({seq_num + 1, 1, 1, 1});
uint64_t* offset_data = _in_seq_offset.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(offset_data,
in_seq_offset.data(),
sizeof(uint64_t) * in_seq_offset.size(),
IoDirection::HtoD,
cuda_stream);
TargetWrapperCuda::MemsetSync(
out_new_data, 0, out_new->dims()[0] * out_new->dims()[1] * sizeof(float));
ker_search_group_padding<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, cuda_stream>>>(
out_emb_padding_data,
out_padding_data,
in_data,
offset_data,
seq_num,
max_len,
emb_size,
pad_id,
count);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(search_group_padding,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SearchGroupPaddingCompute,
def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out_emb_padding",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out_new",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out_padding",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SearchGroupPaddingCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SearchGroupPaddingParam;
void Run() override;
virtual ~SearchGroupPaddingCompute() = default;
private:
lite::Tensor _in_seq_offset;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_group_padding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(search_group_padding_cuda, run_test) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
lite::Tensor x, x_cpu, x_ref;
lite::Tensor out_emb_padding, out_emb_padding_cpu, out_emb_padding_ref;
lite::Tensor out_new, out_new_cpu, out_new_ref;
lite::Tensor out_padding, out_padding_cpu, out_padding_ref;
int x_dims0 = 2;
int x_dims1 = 3;
x.Resize({x_dims0, x_dims1});
x_cpu.Resize({x_dims0, x_dims1});
x_ref.Resize({x_dims0, x_dims1});
out_emb_padding.Resize({1, x_dims1});
out_emb_padding_cpu.Resize({1, x_dims1});
out_emb_padding_ref.Resize({1, x_dims1});
out_new.Resize({x_dims0, 1});
out_new_cpu.Resize({x_dims0, 1});
out_new_ref.Resize({x_dims0, 1});
out_padding.Resize({1, 1});
out_padding_cpu.Resize({1, 1});
out_padding_ref.Resize({1, 1});
LoD x_lod{};
x_lod.push_back({0, 1});
x.set_lod(x_lod);
auto* x_cpu_data = x_cpu.mutable_data<float>();
auto* x_ref_data = x_ref.mutable_data<float>();
auto* out_emb_padding_data =
out_emb_padding.mutable_data<float>(TARGET(kCUDA));
auto* out_emb_padding_cpu_data = out_emb_padding_cpu.mutable_data<float>();
auto* out_emb_padding_ref_data = out_emb_padding_ref.mutable_data<float>();
auto* out_new_data = out_new.mutable_data<float>(TARGET(kCUDA));
auto* out_new_cpu_data = out_new_cpu.mutable_data<float>();
auto* out_new_ref_data = out_new_ref.mutable_data<float>();
auto* out_padding_data = out_padding.mutable_data<float>(TARGET(kCUDA));
auto* out_padding_cpu_data = out_padding_cpu.mutable_data<float>();
auto* out_padding_ref_data = out_padding_ref.mutable_data<float>();
for (int64_t i = 0; i < x_cpu.dims().production(); i++) {
x_cpu_data[i] = static_cast<float>(i);
x_ref_data[i] = static_cast<float>(i);
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
out_emb_padding_ref_data[0] = 0.f;
out_emb_padding_ref_data[1] = 1.f;
out_emb_padding_ref_data[2] = 2.f;
out_new_ref_data[0] = 0.f;
out_new_ref_data[1] = 0.f;
out_padding_ref_data[0] = 0.f;
SearchGroupPaddingCompute sgp_kernel;
operators::SearchGroupPaddingParam param;
param.x = &x;
param.out_emb_padding = &out_emb_padding;
param.out_new = &out_new;
param.out_padding = &out_padding;
sgp_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
sgp_kernel.SetContext(std::move(ctx));
sgp_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(out_emb_padding_cpu_data,
out_emb_padding_data,
sizeof(float) * out_emb_padding.numel(),
IoDirection::DtoH);
CopySync<TARGET(kCUDA)>(out_new_cpu_data,
out_new_data,
sizeof(float) * out_new.numel(),
IoDirection::DtoH);
CopySync<TARGET(kCUDA)>(out_padding_cpu_data,
out_padding_data,
sizeof(float) * out_padding.numel(),
IoDirection::DtoH);
for (int i = 0; i < out_emb_padding_cpu.dims().production(); i++) {
EXPECT_NEAR(out_emb_padding_cpu_data[i], out_emb_padding_ref_data[i], 1e-5);
}
for (int i = 0; i < out_new_cpu.dims().production(); i++) {
EXPECT_NEAR(out_new_cpu_data[i], out_new_ref_data[i], 1e-5);
}
for (int i = 0; i < out_padding_cpu.dims().production(); i++) {
EXPECT_NEAR(out_padding_cpu_data[i], out_padding_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(search_group_padding, kCUDA, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册