diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index b33fc8f6bb0a5616ab87c01d55f9d81a9fe7032b..8e0400cab86f40de4f9b0af496a4e4f1ee9a67ec 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda) @@ -38,6 +39,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda) nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda ) nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) +nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) if(LITE_BUILD_EXTRA) nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda) endif() diff --git a/lite/kernels/cuda/sequence_concat_compute.cu b/lite/kernels/cuda/sequence_concat_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..3488c829ce2469e4193601145c0e1eb459bfb1a4 --- /dev/null +++ b/lite/kernels/cuda/sequence_concat_compute.cu @@ -0,0 +1,131 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/core/op_registry.h" +#include "lite/core/target_wrapper.h" +#include "lite/kernels/cuda/sequence_concat_compute.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +const int CUDA_NUM_THREADS = 512; + +template +__global__ void ker_sequence_concat(Dtype* out_data, + const uint64_t* in_locate_data, + const int* o2i_map, + const int* o2i_w_map, + const int seq_num, + const int emb_size, + const int count) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + for (int tid = idx; tid < count; tid += blockDim.x * gridDim.x) { + int emb_id = tid % emb_size; + int word_id = tid / emb_size; + int input_id = o2i_map[word_id]; + int cur_work_id = o2i_w_map[word_id]; + const Dtype* in_data = reinterpret_cast( + reinterpret_cast(in_locate_data[input_id])); + out_data[tid] = in_data[cur_work_id * emb_size + emb_id]; + } +} + +void SequenceConcatCompute::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + auto stream = ctx.exec_stream(); + float* out_data = param.Out->mutable_data(TARGET(kCUDA)); + + int seq_num = param.X[0]->lod()[0].size() - 1; + const int emb_size = param.X[0]->numel() / param.X[0]->dims()[0]; + std::vector in_locate_vec; + for (size_t i = 0; i < param.X.size(); ++i) { + in_locate_vec.push_back( + reinterpret_cast(param.X[i]->data())); + } + in_locate_tensor.Resize({static_cast(in_locate_vec.size())}); + + std::vector out2in_map; + std::vector out2in_word_map; + for (int i = 0; i < seq_num; ++i) { + for (int j = 0; j < param.X.size(); ++j) { + auto offset = param.X[j]->lod()[0]; + int cur_len = offset[i + 1] - offset[i]; + for (int k = 0; k < cur_len; ++k) { + out2in_map.push_back(j); + out2in_word_map.push_back(offset[i] + k); + } + } + } + int word_num = out2in_map.size(); + out2in_map_tensor.Resize({word_num}); + out2in_word_map_tensor.Resize({word_num}); + int* gpu_o2i_map_data = out2in_map_tensor.mutable_data(TARGET(kCUDA)); + int* gpu_o2i_w_map_data = + out2in_word_map_tensor.mutable_data(TARGET(kCUDA)); + uint64_t* gpu_in_locate_data = + in_locate_tensor.mutable_data(TARGET(kCUDA)); + + TargetWrapperCuda::MemcpyAsync(gpu_o2i_map_data, + out2in_map.data(), + sizeof(int) * out2in_map.size(), + IoDirection::HtoD, + stream); + TargetWrapperCuda::MemcpyAsync(gpu_o2i_w_map_data, + out2in_word_map.data(), + sizeof(int) * out2in_word_map.size(), + IoDirection::HtoD, + stream); + TargetWrapperCuda::MemcpyAsync(gpu_in_locate_data, + in_locate_vec.data(), + sizeof(uint64_t) * in_locate_vec.size(), + IoDirection::HtoD, + stream); + + int count = param.X[0]->numel(); + for (int i = 1; i < param.X.size(); ++i) { + count += param.X[i]->numel(); + } + + int blocks = (count + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; + ker_sequence_concat<<>>( + out_data, + gpu_in_locate_data, + gpu_o2i_map_data, + gpu_o2i_w_map_data, + seq_num, + emb_size, + count); + + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error); +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(sequence_concat, + kCUDA, + kFloat, + kNCHW, + paddle::lite::kernels::cuda::SequenceConcatCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) + .Finalize(); diff --git a/lite/kernels/cuda/sequence_concat_compute.h b/lite/kernels/cuda/sequence_concat_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..1737c18dd35976572efa1b62fadefed906b0ceb5 --- /dev/null +++ b/lite/kernels/cuda/sequence_concat_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +class SequenceConcatCompute + : public KernelLite { + public: + using param_t = operators::SequenceConcatParam; + + void Run() override; + virtual ~SequenceConcatCompute() = default; + + private: + lite::Tensor out2in_map_tensor; + lite::Tensor out2in_word_map_tensor; + lite::Tensor in_locate_tensor; +}; + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/sequence_concat_compute_test.cc b/lite/kernels/cuda/sequence_concat_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..477dc48dbbdfe7a1453bbb5c811d6897347fee53 --- /dev/null +++ b/lite/kernels/cuda/sequence_concat_compute_test.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/cuda/sequence_concat_compute.h" +#include +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +namespace { +inline LoD ConcatLoD(const std::vector& xs, + std::vector* xs_in_order) { + std::vector result; + result.resize(xs[0]->lod()[0].size()); + + for (size_t i = 1; i < result.size(); ++i) { + size_t sum = 0; + for (size_t j = 0; j < xs.size(); ++j) { + auto& x_lod = xs[j]->lod()[0]; + if (x_lod[i - 1] < x_lod[i]) { + xs_in_order->emplace_back(xs[j]->Slice(x_lod[i - 1], x_lod[i])); + } + sum += x_lod[i]; + } + result[i] = sum; + } + LoD lod; + lod.emplace_back(result); + return lod; +} + +static void sequence_concat_ref(const std::vector& xs, + lite::Tensor* out) { + std::vector out_dims; + int64_t batch_size = 0; + int64_t feature_size = 0; + for (const auto& tensor : xs) { + const auto x_dims = tensor->dims(); + if (out_dims.empty()) { + out_dims = x_dims.Vectorize(); + } + batch_size += x_dims[0]; + if (feature_size == 0) { + feature_size = x_dims.production() / x_dims[0]; + } else { + CHECK_EQ(feature_size, x_dims.production() / x_dims[0]) + << "Inputs of sequence concat must have same feature size"; + } + } + out_dims[0] = batch_size; + out->Resize(out_dims); + std::vector x_in_order; + out->set_lod(ConcatLoD(xs, &x_in_order)); + + int num = x_in_order.size(); + std::vector input_cols(num); + for (int i = 0; i < num; ++i) { + input_cols[i] = x_in_order[i].numel(); + } + float* out_data = out->mutable_data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = x_in_order[j].data(); + memcpy(out_data + col_idx, input_data, sizeof(float) * col_len); + col_idx += col_len; + } +} + +#define PREPARE_INPUT_DATA(name) \ + name.Resize({name##_lod_len, feature_len}); \ + name##_cpu.Resize({name##_lod_len, feature_len}); \ + name##_ref.Resize({name##_lod_len, feature_len}); \ + name.set_lod(lod_info_##name); \ + name##_cpu.set_lod(lod_info_##name); \ + name##_ref.set_lod(lod_info_##name); \ + float* name##_cpu_data = name##_cpu.mutable_data(); \ + float* name##_ref_data = name##_ref.mutable_data(); \ + for (int i = 0; i < name##_cpu.numel(); ++i) { \ + name##_cpu_data[i] = (i - 2.0) * 1.0; \ + name##_ref_data[i] = (i - 2.0) * 1.0; \ + } \ + name.Assign(name##_cpu_data, \ + name##_cpu.dims()); + +#define PREPARE_OUTPUT_INFO(name) \ + name##_cpu.Resize({y_lod_len, feature_len}); \ + name##_ref.Resize({y_lod_len, feature_len}); \ + name.Resize({y_lod_len, feature_len}); \ + float* name##_cpu_data = name##_cpu.mutable_data(); + +} // namespace + +TEST(sequence_concat_cuda, normal) { + SequenceConcatCompute seq_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::SequenceConcatParam param; + lite::Tensor x1, x2, x3, x1_cpu, x2_cpu, x3_cpu, x1_ref, x2_ref, x3_ref; + lite::Tensor y, y_cpu, y_ref; + + int32_t x1_lod_len = 10, feature_len = 4; + int32_t x2_lod_len = 4, x3_lod_len = 8; + int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len; + LoD lod_info_x1{{0, 3, 5, 6, 10}}; + LoD lod_info_x2{{0, 1, 2, 3, 4}}; + LoD lod_info_x3{{0, 2, 4, 6, 8}}; + LoD lod_info_y{{0, 0, 0, 0, 0}}; + for (size_t i = 0; i < lod_info_x1[0].size(); ++i) { + lod_info_y[0][i] = + lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i]; + } + + PREPARE_INPUT_DATA(x1); + PREPARE_INPUT_DATA(x2); + PREPARE_INPUT_DATA(x3); + PREPARE_OUTPUT_INFO(y); + + param.X = std::vector({&x1, &x2, &x3}); + param.Out = &y; + seq_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + seq_kernel.SetContext(std::move(ctx)); + seq_kernel.Run(); + cudaDeviceSynchronize(); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + + std::vector input_ref({&x1_ref, &x2_ref, &x3_ref}); + sequence_concat_ref(input_ref, &y_ref); + float* y_ref_data = y_ref.mutable_data(); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index da955e4fd5902373cd881f85a8bc715eef7cec94..46d32da3872113f265f7ddef2e70f460c0379d13 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -39,6 +39,7 @@ add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${li add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps}) if(NOT LITE_WITH_X86) return() @@ -67,3 +68,4 @@ lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_com lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86) lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86) +lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86) diff --git a/lite/kernels/x86/sequence_concat_compute.cc b/lite/kernels/x86/sequence_concat_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..facdad39d383c3a2134599e1490c89e9d5afa543 --- /dev/null +++ b/lite/kernels/x86/sequence_concat_compute.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_concat_compute.h" + +REGISTER_LITE_KERNEL(sequence_concat, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::SequenceConcatCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/sequence_concat_compute.h b/lite/kernels/x86/sequence_concat_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..553e2e8b0667106f25685a9ef155d7e61a672f31 --- /dev/null +++ b/lite/kernels/x86/sequence_concat_compute.h @@ -0,0 +1,84 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +inline LoD ConcatLoD(const std::vector& xs, + std::vector* xs_in_order) { + std::vector result; + result.resize(xs[0]->lod()[0].size()); + + for (size_t i = 1; i < result.size(); ++i) { + size_t sum = 0; + for (size_t j = 0; j < xs.size(); ++j) { + auto& x_lod = xs[j]->lod()[0]; + if (x_lod[i - 1] < x_lod[i]) { + xs_in_order->emplace_back(xs[j]->Slice(x_lod[i - 1], x_lod[i])); + } + sum += x_lod[i]; + } + result[i] = sum; + } + LoD lod; + lod.emplace_back(result); + return lod; +} + +template +class SequenceConcatCompute + : public KernelLite { + public: + using param_t = operators::SequenceConcatParam; + + void Run() override { + auto& param = *param_.get_mutable(); + // auto& param = Param(); + T* dout = param.Out->mutable_data(); + + std::vector x_in_order; + param.Out->set_lod(ConcatLoD(param.X, &x_in_order)); + + int num = x_in_order.size(); + int out_rows = 1; + + std::vector input_cols(num); + for (int i = 0; i < num; ++i) { + input_cols[i] = x_in_order[i].numel() / out_rows; + } + + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = x_in_order[j].data(); + memcpy(dout + col_idx, input_data, sizeof(T) * col_len); + col_idx += col_len; + } + } + + virtual ~SequenceConcatCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/sequence_concat_compute_test.cc b/lite/kernels/x86/sequence_concat_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be1f86a5c848b5c03634ea2a1aed0d57f2283879 --- /dev/null +++ b/lite/kernels/x86/sequence_concat_compute_test.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/sequence_concat_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +namespace { +inline LoD ConcatLoD(const std::vector& xs, + std::vector* xs_in_order) { + std::vector result; + result.resize(xs[0]->lod()[0].size()); + + for (size_t i = 1; i < result.size(); ++i) { + size_t sum = 0; + for (size_t j = 0; j < xs.size(); ++j) { + auto& x_lod = xs[j]->lod()[0]; + if (x_lod[i - 1] < x_lod[i]) { + xs_in_order->emplace_back(xs[j]->Slice(x_lod[i - 1], x_lod[i])); + } + sum += x_lod[i]; + } + result[i] = sum; + } + LoD lod; + lod.emplace_back(result); + return lod; +} + +static void sequence_concat_ref(const std::vector& xs, + lite::Tensor* out) { + std::vector out_dims; + int64_t batch_size = 0; + int64_t feature_size = 0; + for (const auto& tensor : xs) { + const auto x_dims = tensor->dims(); + if (out_dims.empty()) { + out_dims = x_dims.Vectorize(); + } + batch_size += x_dims[0]; + if (feature_size == 0) { + feature_size = x_dims.production() / x_dims[0]; + } else { + CHECK_EQ(feature_size, x_dims.production() / x_dims[0]) + << "Inputs of sequence concat must have same feature size"; + } + } + out_dims[0] = batch_size; + out->Resize(out_dims); + std::vector x_in_order; + out->set_lod(ConcatLoD(xs, &x_in_order)); + + int num = x_in_order.size(); + std::vector input_cols(num); + for (int i = 0; i < num; ++i) { + input_cols[i] = x_in_order[i].numel(); + } + float* out_data = out->mutable_data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = x_in_order[j].data(); + memcpy(out_data + col_idx, input_data, sizeof(float) * col_len); + col_idx += col_len; + } +} + +#define PREPARE_INPUT(name) \ + name.Resize({name##_lod_len, feature_len}); \ + name.set_lod(lod_info_##name); \ + float* name##_data = name.mutable_data(); \ + for (int i = 0; i < name.numel(); ++i) { \ + name##_data[i] = (i - 2.0) * 1.0; \ + } + +} // namespace + +TEST(sequence_concat_x86, retrive_op) { + auto sequence_concat = + KernelRegistry::Global().Create( + "sequence_concat"); + ASSERT_FALSE(sequence_concat.empty()); + ASSERT_TRUE(sequence_concat.front()); +} + +TEST(sequence_concat_x86, init) { + SequenceConcatCompute sequence_concat; + ASSERT_EQ(sequence_concat.precision(), PRECISION(kFloat)); + ASSERT_EQ(sequence_concat.target(), TARGET(kX86)); +} + +TEST(sequence_concat_x86, run_test) { + SequenceConcatCompute seq_kernel; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + + operators::SequenceConcatParam param; + lite::Tensor x1, x2, x3; + lite::Tensor y, y_ref; + + int32_t x1_lod_len = 10, feature_len = 4; + int32_t x2_lod_len = 4, x3_lod_len = 8; + int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len; + LoD lod_info_x1{{0, 3, 5, 6, 10}}; + LoD lod_info_x2{{0, 1, 2, 3, 4}}; + LoD lod_info_x3{{0, 2, 4, 6, 8}}; + LoD lod_info_y{{0, 0, 0, 0, 0}}; + for (size_t i = 0; i < lod_info_x1[0].size(); ++i) { + lod_info_y[0][i] = + lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i]; + } + + PREPARE_INPUT(x1); + PREPARE_INPUT(x2); + PREPARE_INPUT(x3); + + y_ref.Resize({y_lod_len, feature_len}); + y.Resize({y_lod_len, feature_len}); + y_ref.set_lod(lod_info_y); + y.set_lod(lod_info_y); + + std::vector xs{&x1, &x2, &x3}; + + param.X = xs; + param.Out = &y; + seq_kernel.SetParam(param); + + seq_kernel.SetContext(std::move(ctx)); + seq_kernel.Run(); + + auto* y_data = y.mutable_data(); + sequence_concat_ref(xs, &y_ref); + float* y_ref_data = y_ref.mutable_data(); + + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(sequence_concat, kX86, kFloat, kNCHW, def); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 49badbb27b00979117f9e75d1c66763a7be99837..5b868a3d7e087e1d7547ccf8b8de7ad6f939d5c3 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -79,6 +79,7 @@ add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quan add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS}) add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS}) add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) +add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 474c97559041d069ccdaa2e149c83cea4ea9ae2c..4eb4c4f68896b2cbe4cac2418a7b4cf255602264 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -751,6 +751,11 @@ struct SequenceExpandAsParam { lite::Tensor* out{nullptr}; }; +struct SequenceConcatParam { + std::vector X{}; + lite::Tensor* Out{}; +}; + struct ReduceMaxParam { const lite::Tensor* X{}; lite::Tensor* Out{}; diff --git a/lite/operators/sequence_concat_op.cc b/lite/operators/sequence_concat_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7c842d49e54a6a567abd4b733307942f90176dce --- /dev/null +++ b/lite/operators/sequence_concat_op.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/sequence_concat_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool SequenceConcatOp::CheckShape() const { + CHECK_GT(param_.X.size(), 1) + << "The number of input sequences is at least two."; + CHECK_OR_FALSE(param_.Out); + size_t lod_size = 0; + for (const auto &t : param_.X) { + CHECK_EQ(t->lod().empty(), false) + << "Input Tensor of X does not contain LoD information."; + CHECK_EQ(t->lod().size(), 1) << "Only support one level sequence now."; + if (lod_size == 0) { + lod_size = t->lod()[0].size(); + } else { + CHECK_EQ(t->lod()[0].size(), lod_size) + << "The number of sequence must be same between each input"; + } + } + CHECK_NE(lod_size, 0) << "Each input must have sequence information"; + return true; +} + +bool SequenceConcatOp::InferShape() const { + int64_t batch_size = 0; + int64_t feature_size = 0; + std::vector out_dims; + for (const auto &tensor : param_.X) { + const auto x_dims = tensor->dims(); + if (out_dims.empty()) { + out_dims = x_dims.Vectorize(); + } + batch_size += x_dims[0]; + if (feature_size == 0) { + feature_size = x_dims.production() / x_dims[0]; + } else { + CHECK_EQ(feature_size, x_dims.production() / x_dims[0]) + << "Inputs of sequence concat must have same feature size"; + } + } + if (batch_size < 0) { + batch_size = -1; // Normalize batch size for compile time. + } + out_dims[0] = batch_size; + param_.Out->Resize(out_dims); + // LoD info will be computed in Kernel. + return true; +} + +bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc, + lite::Scope *scope) { + auto input_list = opdesc.Input("X"); + param_.X.clear(); + for (auto var : input_list) { + param_.X.push_back(scope->FindVar(var)->GetMutable()); + } + param_.Out = + scope->FindVar(opdesc.Output("Out").front())->GetMutable(); + CHECK(param_.Out) << "Output(Out) of Sequence Concat Op should not be null."; + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(sequence_concat, paddle::lite::operators::SequenceConcatOp); diff --git a/lite/operators/sequence_concat_op.h b/lite/operators/sequence_concat_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8cdc07ebca83b9c400b00a0f40556a788c5854e6 --- /dev/null +++ b/lite/operators/sequence_concat_op.h @@ -0,0 +1,41 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" + +namespace paddle { +namespace lite { +namespace operators { + +class SequenceConcatOp : public OpLite { + public: + SequenceConcatOp() {} + explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {} + bool CheckShape() const override; + bool InferShape() const override; + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "sequence_concat"; } + + private: + mutable SequenceConcatParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle