未验证 提交 b094b2b6 编写于 作者: P Pei Yang 提交者: GitHub

fix sequence pool cuda (#2457)

* add sequence_pool cuda kernel, test=develop

* fix sequence_pool cuda,test=develop

* fix and complete unittest, test=develop
上级 2621af0e
......@@ -43,7 +43,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling)
nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
......
......@@ -163,20 +163,20 @@ void SequencePoolCompute::Run() {
auto stream = ctx.exec_stream();
std::vector<uint64_t> seq_offset = param.X->lod()[0];
int slice_size =
param.Out->dims()[1] * param.Out->dims()[2] * param.Out->dims()[3];
int batch_size = param.X->lod()[0].size() - 1;
int slice_size = param.Out->dims().production() / batch_size;
float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
const float* in_data = param.X->data<float>();
int batch_size = param.X->lod().size() - 1;
lite::Tensor seq_offset_D;
seq_offset_D.Resize({static_cast<int64_t>(seq_offset.size())});
TargetWrapperCuda::MemcpyAsync(seq_offset_D.mutable_data<uint64_t>(),
seq_offset.data(),
sizeof(uint64_t) * seq_offset.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(
seq_offset_D.mutable_data<uint64_t>(TARGET(kCUDA)),
seq_offset.data(),
sizeof(uint64_t) * seq_offset.size(),
IoDirection::HtoD,
stream);
if (param.pool_type == "MAX") {
seq_pool_max_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
......@@ -187,7 +187,7 @@ void SequencePoolCompute::Run() {
batch_size,
seq_offset_D.data<uint64_t>(),
slice_size);
} else if (param.pool_type == "AVERAGE ") {
} else if (param.pool_type == "AVERAGE") {
seq_pool_average_kernel<float><<<CUDA_GET_BLOCKS(batch_size * slice_size),
CUDA_NUM_THREADS,
0,
......
......@@ -14,117 +14,87 @@
#include "lite/kernels/cuda/sequence_pool_compute.h"
#include <gtest/gtest.h>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/backends/x86/math/sequence_pooling.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
namespace {
static void sequence_pool_ref(const operators::SequencePoolParam& param, ) {
auto* x = param.X;
auto* out = param.Out;
auto dims = x->dims();
auto lod = x->lod();
CHECK_EQ(lod.size(), 1UL);
CHECK_GE(dims[0], static_cast<int64_t>(lod[0].size() - 1));
dims[0] = lod[0].size() - 1;
out->Resize({dims});
out->mutable_data<float>();
lite::Tensor* index = nullptr;
const bool is_test = true;
float pad_value = 0.0;
lite::x86::math::SequencePoolFunctor<lite::TargetType::kX86, float> pool;
pool(context, param.pool_type, pad_value, *x, out, is_test, index);
}
#define PREPARE_INPUT_DATA(name) \
name.Resize({name##_lod_len, feature_len}); \
name##_cpu.Resize({name##_lod_len, feature_len}); \
name##_ref.Resize({name##_lod_len, feature_len}); \
name.set_lod(lod_info_##name); \
name##_cpu.set_lod(lod_info_##name); \
name##_ref.set_lod(lod_info_##name); \
float* name##_cpu_data = name##_cpu.mutable_data<float>(); \
float* name##_ref_data = name##_ref.mutable_data<float>(); \
for (int i = 0; i < name##_cpu.numel(); ++i) { \
name##_cpu_data[i] = (i - 2.0) * 1.0; \
name##_ref_data[i] = (i - 2.0) * 1.0; \
} \
name.Assign<float, lite::DDim, TARGET(kCUDA)>(name##_cpu_data, \
name##_cpu.dims());
#define PREPARE_OUTPUT_INFO(name) \
name##_cpu.Resize({y_lod_len, feature_len}); \
name##_ref.Resize({y_lod_len, feature_len}); \
name.Resize({y_lod_len, feature_len}); \
float* name##_cpu_data = name##_cpu.mutable_data<float>();
} // namespace
TEST(sequence_pool_cuda, normal) {
SequencePoolCompute seq_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
std::unique_ptr<KernelContext> ctx_ref(new KernelContext);
auto& context_ref = ctx_ref->As<X86Context>();
operators::SequencePoolParam param;
lite::Tensor x1, x2, x3, x1_cpu, x2_cpu, x3_cpu, x1_ref, x2_ref, x3_ref;
lite::Tensor y, y_cpu, y_ref;
int32_t x1_lod_len = 10, feature_len = 4;
int32_t x2_lod_len = 4, x3_lod_len = 8;
int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len;
LoD lod_info_x1{{0, 3, 5, 6, 10}};
LoD lod_info_x2{{0, 1, 2, 3, 4}};
LoD lod_info_x3{{0, 2, 4, 6, 8}};
LoD lod_info_y{{0, 0, 0, 0, 0}};
for (size_t i = 0; i < lod_info_x1[0].size(); ++i) {
lod_info_y[0][i] =
lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i];
lite::Tensor x, x_cpu, out, out_cpu;
lite::LoD lod;
lod.push_back(std::vector<uint64_t>{0, 10});
x.set_lod(lod);
x_cpu.set_lod(lod);
const size_t second_dim = 8u;
std::vector<int64_t> input_shape{static_cast<int64_t>(lod[0].back()),
static_cast<int64_t>(second_dim)};
lite::DDim in_dims(input_shape);
x.Resize(in_dims);
x_cpu.Resize(in_dims);
const size_t out_first_dim = lod[0].size() - 1;
std::vector<int64_t> output_shape{static_cast<int64_t>(out_first_dim),
static_cast<int64_t>(second_dim)};
lite::DDim out_dims(output_shape);
out.Resize(out_dims);
out_cpu.Resize(out_dims);
auto x_cpu_data = x_cpu.mutable_data<float>();
auto out_data = out.mutable_data<float>(TARGET(kCUDA));
auto out_cpu_data = out_cpu.mutable_data<float>();
for (int64_t i = 0; i < x_cpu.dims().production(); i++) {
x_cpu_data[i] = 1.1f * i;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
PREPARE_INPUT_DATA(x1);
PREPARE_INPUT_DATA(x2);
PREPARE_INPUT_DATA(x3);
PREPARE_OUTPUT_INFO(y);
param.X = &x1;
param.Out = &y;
param.pool_type = "AVERAGE";
seq_kernel.SetParam(param);
operators::SequencePoolParam param;
param.X = &x;
param.Out = &out;
std::vector<std::string> pool_types(
{"MAX", "AVERAGE", "SUM", "SQRT", "FIRST", "LAST"});
std::map<std::string, std::vector<float>> type_map;
type_map["MAX"] = {79.2, 80.3, 81.4, 82.5, 83.6, 84.7, 85.8, 86.9};
type_map["AVERAGE"] = {39.6, 40.7, 41.8, 42.9, 44, 45.1, 46.2, 47.3};
type_map["SUM"] = {396, 407, 418, 429, 440, 451, 462, 473};
type_map["SQRT"] = {
125.226, 128.705, 132.183, 135.662, 139.14, 142.619, 146.097, 149.576};
type_map["FIRST"] = {0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7};
type_map["LAST"] = {79.2, 80.3, 81.4, 82.5, 83.6, 84.7, 85.8, 86.9};
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
seq_kernel.SetContext(std::move(ctx));
seq_kernel.Run();
cudaDeviceSynchronize();
for (std::string pool_type : pool_types) {
param.pool_type = pool_type;
seq_kernel.SetParam(param);
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
seq_kernel.Run();
cudaDeviceSynchronize();
param.X = &x1_ref;
param.Out = &y_ref;
sequence_pool_ref(param);
CopySync<TARGET(kCUDA)>(out_cpu_data,
out_data,
sizeof(float) * out_cpu.numel(),
IoDirection::DtoH);
lite::x86::math::SequencePoolFunctor<lite::TargetType::kX86, float> pool;
pool(context, param.pool_type, pad_value, *x, out, is_test, index);
std::vector<float> ref_results = type_map[pool_type];
float* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5);
for (int i = 0; i < out_cpu.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-3);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册