diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index a61aebb8f8e2f2ed5ceda3640c442dcc09c9a8a6..ac2e62f8cfff295c418c315d4fb521d69fe26a18 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -43,7 +43,7 @@ nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_c
 nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
 nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
 nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
-nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda sequence_pooling)
+nv_test(sequence_pool_compute_cuda_test SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_cuda)
 nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
 #nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
diff --git a/lite/kernels/cuda/sequence_pool_compute.cu b/lite/kernels/cuda/sequence_pool_compute.cu
index 34853adf927ac7de641b377c21be6d881711f70e..0b0376859f3a8b2eea2c5ff6af4dd77e94dc362d 100644
--- a/lite/kernels/cuda/sequence_pool_compute.cu
+++ b/lite/kernels/cuda/sequence_pool_compute.cu
@@ -163,20 +163,20 @@ void SequencePoolCompute::Run() {
   auto stream = ctx.exec_stream();
   std::vector<uint64_t> seq_offset = param.X->lod()[0];
-  int slice_size =
-      param.Out->dims()[1] * param.Out->dims()[2] * param.Out->dims()[3];
+  int batch_size = param.X->lod()[0].size() - 1;
+  int slice_size = param.Out->dims().production() / batch_size;
   float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
   const float* in_data = param.X->data<float>();
-  int batch_size = param.X->lod().size() - 1;
   lite::Tensor seq_offset_D;
   seq_offset_D.Resize({static_cast<int64_t>(seq_offset.size())});
-  TargetWrapperCuda::MemcpyAsync(seq_offset_D.mutable_data<uint64_t>(),
-                                 seq_offset.data(),
-                                 sizeof(uint64_t) * seq_offset.size(),
-                                 IoDirection::HtoD,
-                                 stream);
+  TargetWrapperCuda::MemcpyAsync(
+      seq_offset_D.mutable_data<uint64_t>(TARGET(kCUDA)),
+      seq_offset.data(),
+      sizeof(uint64_t) * seq_offset.size(),
+      IoDirection::HtoD,
+      stream);
   if (param.pool_type == "MAX") {
     seq_pool_max_kernel<<<CUDA_GET_BLOCKS(batch_size), CUDA_NUM_THREADS, 0, stream>>>(
         out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(), slice_size);
-  } else if (param.pool_type == "AVERAGE ") {
+  } else if (param.pool_type == "AVERAGE") {
     seq_pool_average_kernel<<<CUDA_GET_BLOCKS(batch_size), CUDA_NUM_THREADS, 0, stream>>>(
        out_data, in_data, batch_size, seq_offset_D.data<uint64_t>(), slice_size);
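A note on the slice_size change above: the old expression param.Out->dims()[1] * dims()[2] * dims()[3] silently assumed a 4-D output tensor, while the rewritten test below drives the kernel with a 2-D output of shape {1, 8}. Dividing dims().production() by the number of sequences (lod()[0].size() - 1) gives the per-sequence slice width for any output rank. A minimal host-only sketch of that arithmetic (standalone C++, not part of the patch; Production and the shapes shown are illustrative stand-ins):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    // Stand-in for DDim::production(): product of all dimensions.
    int64_t Production(const std::vector<int64_t>& dims) {
      return std::accumulate(
          dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
    }

    int main() {
      // Output shape used by the new unit test: one sequence, feature width 8.
      std::vector<int64_t> out_dims{1, 8};
      int64_t batch_size = 1;  // lod[0].size() - 1 for lod {{0, 10}}
      // production() / batch_size == 8, the per-sequence slice width,
      // regardless of whether the output is 2-D or 4-D.
      std::cout << Production(out_dims) / batch_size << "\n";
      return 0;
    }

With the old 4-D formula, a {1, 8} output would have indexed dimensions that do not exist; the production-based form avoids that.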
diff --git a/lite/kernels/cuda/sequence_pool_compute_test.cc b/lite/kernels/cuda/sequence_pool_compute_test.cc
--- a/lite/kernels/cuda/sequence_pool_compute_test.cc
+++ b/lite/kernels/cuda/sequence_pool_compute_test.cc
 #include "lite/kernels/cuda/sequence_pool_compute.h"
 #include <gtest/gtest.h>
+#include <map>
 #include <memory>
+#include <string>
 #include <utility>
 #include <vector>
-#include "lite/backends/x86/math/sequence_pooling.h"

 namespace paddle {
 namespace lite {
 namespace kernels {
 namespace cuda {
-namespace {
-
-static void sequence_pool_ref(const operators::SequencePoolParam& param,
-                              const lite::X86Context& context) {
-  auto* x = param.X;
-  auto* out = param.Out;
-  auto dims = x->dims();
-  auto lod = x->lod();
-  CHECK_EQ(lod.size(), 1UL);
-  CHECK_GE(dims[0], static_cast<int64_t>(lod[0].size() - 1));
-
-  dims[0] = lod[0].size() - 1;
-  out->Resize({dims});
-  out->mutable_data<float>();
-  lite::Tensor* index = nullptr;
-
-  const bool is_test = true;
-  float pad_value = 0.0;
-
-  lite::x86::math::SequencePoolFunctor<TARGET(kX86), float> pool;
-  pool(context, param.pool_type, pad_value, *x, out, is_test, index);
-}
-
-#define PREPARE_INPUT_DATA(name)                                  \
-  name.Resize({name##_lod_len, feature_len});                     \
-  name##_cpu.Resize({name##_lod_len, feature_len});               \
-  name##_ref.Resize({name##_lod_len, feature_len});               \
-  name.set_lod(lod_info_##name);                                  \
-  name##_cpu.set_lod(lod_info_##name);                            \
-  name##_ref.set_lod(lod_info_##name);                            \
-  float* name##_cpu_data = name##_cpu.mutable_data<float>();      \
-  float* name##_ref_data = name##_ref.mutable_data<float>();      \
-  for (int i = 0; i < name##_cpu.numel(); ++i) {                  \
-    name##_cpu_data[i] = (i - 2.0) * 1.0;                         \
-    name##_ref_data[i] = (i - 2.0) * 1.0;                         \
-  }                                                               \
-  name.Assign<float, lite::DDim, TARGET(kCUDA)>(name##_cpu_data,  \
-                                                name##_cpu.dims());
-
-#define PREPARE_OUTPUT_INFO(name)               \
-  name##_cpu.Resize({y_lod_len, feature_len});  \
-  name##_ref.Resize({y_lod_len, feature_len});  \
-  name.Resize({y_lod_len, feature_len});        \
-  float* name##_cpu_data = name##_cpu.mutable_data<float>();
-
-}  // namespace
-
 TEST(sequence_pool_cuda, normal) {
   SequencePoolCompute seq_kernel;
   std::unique_ptr<KernelContext> ctx(new KernelContext);
   auto& context = ctx->As<CUDAContext>();
-  std::unique_ptr<KernelContext> ctx_ref(new KernelContext);
-  auto& context_ref = ctx_ref->As<X86Context>();
-  operators::SequencePoolParam param;
-  lite::Tensor x1, x2, x3, x1_cpu, x2_cpu, x3_cpu, x1_ref, x2_ref, x3_ref;
-  lite::Tensor y, y_cpu, y_ref;
-
-  int32_t x1_lod_len = 10, feature_len = 4;
-  int32_t x2_lod_len = 4, x3_lod_len = 8;
-  int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len;
-  LoD lod_info_x1{{0, 3, 5, 6, 10}};
-  LoD lod_info_x2{{0, 1, 2, 3, 4}};
-  LoD lod_info_x3{{0, 2, 4, 6, 8}};
-  LoD lod_info_y{{0, 0, 0, 0, 0}};
-  for (size_t i = 0; i < lod_info_x1[0].size(); ++i) {
-    lod_info_y[0][i] =
-        lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i];
+  lite::Tensor x, x_cpu, out, out_cpu;
+  lite::LoD lod;
+  lod.push_back(std::vector<uint64_t>{0, 10});
+
+  x.set_lod(lod);
+  x_cpu.set_lod(lod);
+  const size_t second_dim = 8u;
+  std::vector<int64_t> input_shape{static_cast<int64_t>(lod[0].back()),
+                                   static_cast<int64_t>(second_dim)};
+  lite::DDim in_dims(input_shape);
+  x.Resize(in_dims);
+  x_cpu.Resize(in_dims);
+
+  const size_t out_first_dim = lod[0].size() - 1;
+  std::vector<int64_t> output_shape{static_cast<int64_t>(out_first_dim),
+                                    static_cast<int64_t>(second_dim)};
+  lite::DDim out_dims(output_shape);
+  out.Resize(out_dims);
+  out_cpu.Resize(out_dims);
+
+  auto x_cpu_data = x_cpu.mutable_data<float>();
+  auto out_data = out.mutable_data<float>(TARGET(kCUDA));
+  auto out_cpu_data = out_cpu.mutable_data<float>();
+
+  for (int64_t i = 0; i < x_cpu.dims().production(); i++) {
+    x_cpu_data[i] = 1.1f * i;
   }
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());

-  PREPARE_INPUT_DATA(x1);
-  PREPARE_INPUT_DATA(x2);
-  PREPARE_INPUT_DATA(x3);
-  PREPARE_OUTPUT_INFO(y);
-
-  param.X = &x1;
-  param.Out = &y;
-  param.pool_type = "AVERAGE";
-  seq_kernel.SetParam(param);
+  operators::SequencePoolParam param;
+  param.X = &x;
+  param.Out = &out;
+  std::vector<std::string> pool_types(
+      {"MAX", "AVERAGE", "SUM", "SQRT", "FIRST", "LAST"});
+  std::map<std::string, std::vector<float>> type_map;
+  type_map["MAX"] = {79.2, 80.3, 81.4, 82.5, 83.6, 84.7, 85.8, 86.9};
+  type_map["AVERAGE"] = {39.6, 40.7, 41.8, 42.9, 44, 45.1, 46.2, 47.3};
+  type_map["SUM"] = {396, 407, 418, 429, 440, 451, 462, 473};
+  type_map["SQRT"] = {
+      125.226, 128.705, 132.183, 135.662, 139.14, 142.619, 146.097, 149.576};
+  type_map["FIRST"] = {0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7};
+  type_map["LAST"] = {79.2, 80.3, 81.4, 82.5, 83.6, 84.7, 85.8, 86.9};

   cudaStream_t stream;
   cudaStreamCreate(&stream);
   context.SetExecStream(stream);
   seq_kernel.SetContext(std::move(ctx));

-  seq_kernel.Run();
-  cudaDeviceSynchronize();
+  for (std::string pool_type : pool_types) {
+    param.pool_type = pool_type;
+    seq_kernel.SetParam(param);

-  auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
-  CopySync<TARGET(kCUDA)>(
-      y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
+    seq_kernel.Run();
+    cudaDeviceSynchronize();

-  param.X = &x1_ref;
-  param.Out = &y_ref;
-  sequence_pool_ref(param);
+    CopySync<TARGET(kCUDA)>(out_cpu_data,
+                            out_data,
+                            sizeof(float) * out_cpu.numel(),
+                            IoDirection::DtoH);

-  lite::x86::math::SequencePoolFunctor<TARGET(kX86), float> pool;
-  pool(context, param.pool_type, pad_value, *x, out, is_test, index);
+    std::vector<float> ref_results = type_map[pool_type];

-  float* y_ref_data = y_ref.mutable_data<float>();
-  for (int i = 0; i < y.numel(); i++) {
-    EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5);
+    for (int i = 0; i < out_cpu.numel(); i++) {
+      EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-3);
+    }
   }
 }
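For reference, the hard-coded vectors in type_map follow directly from the test input x_cpu_data[i] = 1.1f * i with one sequence of length 10 and feature width 8: for column j, SUM = 396 + 11j, AVERAGE = SUM / 10, SQRT = SUM / sqrt(10), FIRST = 1.1j, and MAX = LAST = 1.1 * (72 + j), since the input grows monotonically. A small host-only sketch (standalone C++, not part of the patch) that reproduces those numbers:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const int seq_len = 10, width = 8;  // lod {{0, 10}}, second_dim = 8
      std::vector<float> x(seq_len * width);
      for (int i = 0; i < seq_len * width; ++i) x[i] = 1.1f * i;  // same fill as the test

      for (int j = 0; j < width; ++j) {
        float sum = 0.f, max_v = x[j];
        for (int i = 0; i < seq_len; ++i) {
          float v = x[i * width + j];
          sum += v;
          max_v = v > max_v ? v : max_v;
        }
        // One line per column j: the MAX, AVERAGE, SUM, SQRT, FIRST, LAST entries of type_map.
        std::printf("j=%d max=%.3f avg=%.3f sum=%.3f sqrt=%.3f first=%.3f last=%.3f\n",
                    j, max_v, sum / seq_len, sum,
                    sum / std::sqrt(static_cast<float>(seq_len)),
                    x[j], x[(seq_len - 1) * width + j]);
      }
      return 0;
    }

The 1e-3 tolerance in EXPECT_NEAR matches the three-decimal rounding of the SQRT constants.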