From 3132ad0301a8959501a2399ec23390d50e05f97f Mon Sep 17 00:00:00 2001
From: Wilber
Date: Wed, 13 Nov 2019 18:39:27 +0800
Subject: [PATCH] add sequence_reverse op and kernel for x86 and cuda
 test=develop (#2397)

- add sequence_reverse op
- add sequence_reverse kernel for x86 and cuda
- add sequence_reverse_test for x86 and cuda
---
 lite/kernels/cuda/CMakeLists.txt              |   2 +
 lite/kernels/cuda/sequence_reverse_compute.cu | 125 ++++++++++++++++++
 lite/kernels/cuda/sequence_reverse_compute.h  |  38 ++++++
 .../cuda/sequence_reverse_compute_test.cc     | 105 +++++++++++++++
 lite/kernels/x86/CMakeLists.txt               |   2 +
 lite/kernels/x86/sequence_reverse_compute.cc  |  25 ++++
 lite/kernels/x86/sequence_reverse_compute.h   |  64 +++++++++
 .../x86/sequence_reverse_compute_test.cc      | 108 +++++++++++++++
 lite/operators/CMakeLists.txt                 |   1 +
 lite/operators/op_params.h                    |   5 +
 lite/operators/sequence_reverse_op.cc         |  55 ++++++++
 lite/operators/sequence_reverse_op.h          |  41 ++++++
 12 files changed, 571 insertions(+)
 create mode 100644 lite/kernels/cuda/sequence_reverse_compute.cu
 create mode 100644 lite/kernels/cuda/sequence_reverse_compute.h
 create mode 100644 lite/kernels/cuda/sequence_reverse_compute_test.cc
 create mode 100644 lite/kernels/x86/sequence_reverse_compute.cc
 create mode 100644 lite/kernels/x86/sequence_reverse_compute.h
 create mode 100644 lite/kernels/x86/sequence_reverse_compute_test.cc
 create mode 100644 lite/operators/sequence_reverse_op.cc
 create mode 100644 lite/operators/sequence_reverse_op.h

diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt
index 8e0400cab8..24e306b420 100644
--- a/lite/kernels/cuda/CMakeLists.txt
+++ b/lite/kernels/cuda/CMakeLists.txt
@@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
 add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
+add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
 add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
 
@@ -39,6 +40,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
 nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
 nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
 nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
+nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
 nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
 if(LITE_BUILD_EXTRA)
     nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
diff --git a/lite/kernels/cuda/sequence_reverse_compute.cu b/lite/kernels/cuda/sequence_reverse_compute.cu
new file mode 100644
index 0000000000..ee2550cd96
--- /dev/null
+++ b/lite/kernels/cuda/sequence_reverse_compute.cu
@@ -0,0 +1,125 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/op_registry.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/kernels/cuda/sequence_reverse_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+template <typename T>
+__host__ __device__ inline size_t UpperBound(const T* x,
+                                             size_t num,
+                                             const T& val) {
+  // The following code is from
+  // https://en.cppreference.com/w/cpp/algorithm/upper_bound
+  auto* first = x;
+  int64_t count = static_cast<int64_t>(num);
+  while (count > 0) {
+    auto step = (count >> 1);
+    auto* it = first + step;
+    if (val < *it) {
+      count = step;
+    } else {
+      first = ++it;
+      count -= (step + 1);
+    }
+  }
+  return static_cast<size_t>(first - x);
+}
+
+__global__ void SequenceReverseKernelGridIsOne(const float* x,
+                                               float* y,
+                                               const int64_t* lod,
+                                               size_t lod_count,
+                                               int64_t row_numel) {
+  int64_t idx = static_cast<int64_t>(threadIdx.x);
+  auto row_idx_x = idx / row_numel;
+  auto lod_idx = UpperBound(lod, lod_count, row_idx_x);
+  auto row_idx_y = lod[lod_idx - 1] + (lod[lod_idx] - 1 - row_idx_x);
+  auto idx_y = row_idx_y * row_numel + idx % row_numel;
+  y[idx_y] = x[idx];
+}
+
+__global__ void SequenceReverseKernel(const float* x,
+                                      float* y,
+                                      const int64_t* lod,
+                                      size_t lod_count,
+                                      int64_t row_numel,
+                                      size_t limit) {
+  int64_t idx = static_cast<int64_t>(blockIdx.x * blockDim.x + threadIdx.x);
+  if (idx < limit) {
+    auto row_idx_x = idx / row_numel;
+    auto lod_idx = UpperBound(lod, lod_count, row_idx_x);
+    auto row_idx_y = lod[lod_idx - 1] + (lod[lod_idx] - 1 - row_idx_x);
+    auto idx_y = row_idx_y * row_numel + idx % row_numel;
+    y[idx_y] = x[idx];
+  }
+}
+
+void SequenceReverseCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  size_t limit = static_cast<size_t>(param.X->numel());
+  int64_t row_numel = static_cast<int64_t>(limit / param.X->dims()[0]);
+  const auto* x_data = param.X->data<float>();
+  auto y_data = param.Out->mutable_data<float>(TARGET(kCUDA));
+  CHECK_NE(x_data, y_data)
+      << "SequenceReverse Op does not support in-place operation";
+  const auto lod = param.X->lod()[param.X->lod().size() - 1];
+  const size_t lod_count = lod.size();
+
+  lod_cuda.Resize({static_cast<int64_t>(lod.size())});
+  int64_t* lod_data = lod_cuda.mutable_data<int64_t>(TARGET(kCUDA));
+  TargetWrapperCuda::MemcpyAsync(lod_data,
+                                 lod.data(),
+                                 sizeof(int64_t) * lod.size(),
+                                 IoDirection::HtoD,
+                                 stream);
+
+  constexpr int num_threads = 1024;
+  int block_size = limit <= num_threads ? limit : num_threads;
+  int grid_size = (limit + num_threads - 1) / num_threads;
+
+  if (grid_size == 1) {
+    SequenceReverseKernelGridIsOne<<<1, block_size, 0, stream>>>(
+        x_data, y_data, lod_data, lod_count, row_numel);
+  } else {
+    SequenceReverseKernel<<<grid_size, block_size, 0, stream>>>(
+        x_data, y_data, lod_data, lod_count, row_numel, limit);
+  }
+
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(sequence_reverse,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::SequenceReverseCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
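Both kernels above scatter each input element straight to its mirrored slot: for a flattened element index idx, the owning row is idx / row_numel, UpperBound finds the enclosing sequence k in the offset table (the last LoD level), and the destination row is lod[k - 1] + (lod[k] - 1 - row). The host-side sketch below is an illustration only, not part of the patch: plain C++ with a std::vector standing in for the device-side LoD buffer. It reproduces that index arithmetic for the offset table the tests use and checks it against the row order one expects from reversing each sequence.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Same upper_bound-style search the CUDA kernel performs on the offset table:
// returns the index of the first offset strictly greater than val.
static size_t UpperBound(const std::vector<int64_t>& lod, int64_t val) {
  size_t lo = 0, hi = lod.size();
  while (lo < hi) {
    size_t mid = lo + (hi - lo) / 2;
    if (val < lod[mid]) hi = mid; else lo = mid + 1;
  }
  return lo;
}

int main() {
  // Offset table of the last LoD level used in the tests: 4 sequences of
  // lengths 3, 2, 1 and 4 over 10 rows.
  std::vector<int64_t> lod = {0, 3, 5, 6, 10};
  std::vector<int64_t> dst(10);
  for (int64_t row = 0; row < 10; ++row) {
    size_t k = UpperBound(lod, row);             // enclosing sequence
    dst[row] = lod[k - 1] + (lod[k] - 1 - row);  // mirrored row index
  }
  // Expected destination rows: 2 1 0 | 4 3 | 5 | 9 8 7 6
  const int64_t expected[] = {2, 1, 0, 4, 3, 5, 9, 8, 7, 6};
  for (int i = 0; i < 10; ++i) assert(dst[i] == expected[i]);
  std::printf("per-row destination mapping verified\n");
  return 0;
}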
diff --git a/lite/kernels/cuda/sequence_reverse_compute.h b/lite/kernels/cuda/sequence_reverse_compute.h
new file mode 100644
index 0000000000..ba85f08563
--- /dev/null
+++ b/lite/kernels/cuda/sequence_reverse_compute.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+class SequenceReverseCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SequenceReverseParam;
+
+  void Run() override;
+  virtual ~SequenceReverseCompute() = default;
+
+ private:
+  lite::Tensor lod_cuda;
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+ +#include "lite/kernels/cuda/sequence_reverse_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace cuda { + +static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) { + const auto* x_data = x->data(); + auto seq_offset = x->lod()[x->lod().size() - 1]; + int width = x->numel() / x->dims()[0]; + auto* y_data = y->mutable_data(); + for (int i = 0; i < static_cast(seq_offset.size()) - 1; ++i) { + auto start_pos = seq_offset[i]; + auto end_pos = seq_offset[i + 1]; + for (auto pos = start_pos; pos < end_pos; ++pos) { + auto cur_pos = end_pos - pos - 1 + start_pos; + std::memcpy(y_data + pos * width, + x_data + cur_pos * width, + width * sizeof(float)); + } + } +} + +TEST(sequence_reverse_cuda, normal) { + SequenceReverseCompute seq_kernel; + std::unique_ptr ctx(new KernelContext); + auto& context = ctx->As(); + + operators::SequenceReverseParam param; + lite::Tensor x, x_cpu, x_ref; + lite::Tensor y, y_cpu, y_ref; + + int32_t lod_len = 10, feature_len = 4; + LoD lod_info{{0, 2, 4}, {0, 3, 5, 6, 10}}; + + x.Resize({lod_len, feature_len}); + x_cpu.Resize({lod_len, feature_len}); + x_ref.Resize({lod_len, feature_len}); + y.Resize({lod_len, feature_len}); + y_cpu.Resize({lod_len, feature_len}); + y_ref.Resize({lod_len, feature_len}); + x.set_lod(lod_info); + x_cpu.set_lod(lod_info); + x_ref.set_lod(lod_info); + y.set_lod(lod_info); + y_cpu.set_lod(lod_info); + y_ref.set_lod(lod_info); + + auto* y_data = y.mutable_data(TARGET(kCUDA)); + + float* x_cpu_data = x_cpu.mutable_data(); + float* x_ref_data = x_ref.mutable_data(); + float* y_cpu_data = y_cpu.mutable_data(); + float* y_ref_data = y_ref.mutable_data(); + + for (int i = 0; i < x_cpu.numel(); ++i) { + x_cpu_data[i] = (i - 2.0) * 1.0; + x_ref_data[i] = (i - 2.0) * 1.0; + } + + x.Assign(x_cpu_data, x_cpu.dims()); + + param.X = &x; + param.Out = &y; + seq_kernel.SetParam(param); + + cudaStream_t stream; + cudaStreamCreate(&stream); + context.SetExecStream(stream); + + seq_kernel.SetContext(std::move(ctx)); + seq_kernel.Run(); + cudaDeviceSynchronize(); + + CopySync( + y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH); + + sequence_reverse_ref(&x_ref, &y_ref); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5); + } +} + +} // namespace cuda +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 56441bbc2c..5fa9e20a02 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) +add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps}) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) @@ -55,6 +56,7 @@ lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_b lite_cc_test(test_reshape_compute_x86 SRCS 
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index 56441bbc2c..5fa9e20a02 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
 add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
+add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
 add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
@@ -55,6 +56,7 @@ lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_b
 lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86)
 lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
 lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
+lite_cc_test(test_sequence_reverse_compute_x86 SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_x86)
 lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
 lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
 lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
diff --git a/lite/kernels/x86/sequence_reverse_compute.cc b/lite/kernels/x86/sequence_reverse_compute.cc
new file mode 100644
index 0000000000..7d4cb8402f
--- /dev/null
+++ b/lite/kernels/x86/sequence_reverse_compute.cc
@@ -0,0 +1,25 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/sequence_reverse_compute.h"
+
+REGISTER_LITE_KERNEL(sequence_reverse,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::SequenceReverseCompute<float>,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/sequence_reverse_compute.h b/lite/kernels/x86/sequence_reverse_compute.h
new file mode 100644
index 0000000000..85072e8010
--- /dev/null
+++ b/lite/kernels/x86/sequence_reverse_compute.h
@@ -0,0 +1,64 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <cstring>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+template <typename T>
+class SequenceReverseCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SequenceReverseParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+    auto* output = param.Out;
+    const auto* din = param.X->data<T>();
+
+    T* dout = output->mutable_data<T>();
+    CHECK_NE(din, dout)
+        << "SequenceReverse Op does not support in-place operation";
+    const auto lod = param.X->lod()[param.X->lod().size() - 1];
+    const size_t lod_count = lod.size();
+
+    size_t limit = static_cast<size_t>(param.X->numel());
+    size_t row_numel = static_cast<size_t>(limit / param.X->dims()[0]);
+
+    for (size_t idx = 0; idx < lod_count - 1; ++idx) {
+      auto start_pos = lod[idx];
+      auto end_pos = lod[idx + 1];
+      for (auto pos = start_pos; pos < end_pos; ++pos) {
+        auto cur_pos = end_pos - pos - 1 + start_pos;
+        std::memcpy(dout + pos * row_numel,
+                    din + cur_pos * row_numel,
+                    row_numel * sizeof(T));
+      }
+    }
+    output->set_lod(param.X->lod());
+  }
+
+  virtual ~SequenceReverseCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
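The x86 kernel takes the complementary, row-wise form: instead of scattering individual elements, it walks each sequence [lod[idx], lod[idx + 1]) and memcpy's whole rows from the mirrored position, then propagates the input LoD to the output via set_lod (the CUDA path above leaves the output LoD to the caller, which is why its test sets it on y explicitly). Below is a minimal standalone sketch of that loop on raw float buffers; it is an illustration under those assumptions, not part of the patch.

#include <cassert>
#include <cstring>
#include <vector>

// Reverse each sequence of a (rows x width) row-major buffer in the gather
// style used by the x86 kernel: destination row `pos` receives source row
// `end - 1 - pos + start` of the same sequence.
static void SequenceReverse(const float* src, float* dst, int width,
                            const std::vector<size_t>& lod) {
  for (size_t k = 0; k + 1 < lod.size(); ++k) {
    size_t start = lod[k], end = lod[k + 1];
    for (size_t pos = start; pos < end; ++pos) {
      size_t cur = end - 1 - pos + start;
      std::memcpy(dst + pos * width, src + cur * width, width * sizeof(float));
    }
  }
}

int main() {
  const int rows = 10, width = 4;
  std::vector<float> x(rows * width), y(rows * width);
  for (int i = 0; i < rows * width; ++i) x[i] = i - 2.0f;  // same ramp as the tests

  // Last-level offsets from the tests: sequences [0,3), [3,5), [5,6), [6,10).
  SequenceReverse(x.data(), y.data(), width, {0, 3, 5, 6, 10});

  // Output row 0 must equal input row 2, i.e. the values 6, 7, 8, 9.
  for (int c = 0; c < width; ++c) assert(y[c] == x[2 * width + c]);
  return 0;
}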
+ +#include "lite/kernels/x86/sequence_reverse_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +namespace { +static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) { + const auto* x_data = x->data(); + auto seq_offset = x->lod()[x->lod().size() - 1]; + int width = x->numel() / x->dims()[0]; + auto* y_data = y->mutable_data(); + for (int i = 0; i < seq_offset.size() - 1; ++i) { + auto start_pos = seq_offset[i]; + auto end_pos = seq_offset[i + 1]; + for (auto pos = start_pos; pos < end_pos; ++pos) { + auto cur_pos = end_pos - pos - 1 + start_pos; + std::memcpy(y_data + pos * width, + x_data + cur_pos * width, + width * sizeof(float)); + } + } +} +} // namespace + +TEST(sequence_reverse_x86, retrive_op) { + auto sequence_reverse = + KernelRegistry::Global().Create( + "sequence_reverse"); + ASSERT_FALSE(sequence_reverse.empty()); + ASSERT_TRUE(sequence_reverse.front()); +} + +TEST(sequence_reverse_x86, init) { + SequenceReverseCompute sequence_reverse; + ASSERT_EQ(sequence_reverse.precision(), PRECISION(kFloat)); + ASSERT_EQ(sequence_reverse.target(), TARGET(kX86)); +} + +TEST(sequence_reverse_x86, run_test) { + SequenceReverseCompute seq_kernel; + std::unique_ptr ctx(new KernelContext); + + operators::SequenceReverseParam param; + lite::Tensor x, x_ref; + lite::Tensor y, y_ref; + + int32_t lod_len = 10, feature_len = 4; + LoD lod_info{{0, 2, 4}, {0, 3, 5, 6, 10}}; + + x.Resize({lod_len, feature_len}); + x_ref.Resize({lod_len, feature_len}); + y.Resize({lod_len, feature_len}); + y_ref.Resize({lod_len, feature_len}); + x.set_lod(lod_info); + x_ref.set_lod(lod_info); + y.set_lod(lod_info); + y_ref.set_lod(lod_info); + + auto* y_data = y.mutable_data(); + float* x_data = x.mutable_data(); + float* x_ref_data = x_ref.mutable_data(); + float* y_ref_data = y_ref.mutable_data(); + + for (int i = 0; i < x.numel(); ++i) { + x_ref_data[i] = (i - 2.0) * 1.0; + x_data[i] = (i - 2.0) * 1.0; + } + + param.X = &x; + param.Out = &y; + seq_kernel.SetParam(param); + + seq_kernel.SetContext(std::move(ctx)); + seq_kernel.Run(); + + sequence_reverse_ref(&x_ref, &y_ref); + for (int i = 0; i < y.numel(); i++) { + EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(sequence_reverse, kX86, kFloat, kNCHW, def); diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 5b868a3d7e..b56606ba3c 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -78,6 +78,7 @@ add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS}) add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS}) add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS}) +add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS}) add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 4eb4c4f688..cbe0054ba3 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -751,6 +751,11 @@ struct 
diff --git a/lite/operators/sequence_reverse_op.cc b/lite/operators/sequence_reverse_op.cc
new file mode 100644
index 0000000000..dd8fa2e8fd
--- /dev/null
+++ b/lite/operators/sequence_reverse_op.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/sequence_reverse_op.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SequenceReverseOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Out);
+  CHECK_EQ(param_.X->lod().empty(), false)
+      << "Input(X) Tensor of SequenceReverseOp does not contain "
+         "LoD information.";
+  CHECK_GE(param_.X->dims().size(), 2)
+      << "Rank of Input(X) must be not less than 2.";
+  return true;
+}
+
+bool SequenceReverseOp::InferShape() const {
+  const auto *input = param_.X;
+  auto out_dims = input->dims();
+  param_.Out->Resize(out_dims);
+  return true;
+}
+
+bool SequenceReverseOp::AttachImpl(const cpp::OpDesc &opdesc,
+                                   lite::Scope *scope) {
+  param_.X = const_cast<lite::Tensor *>(
+      &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
+  param_.Out =
+      scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
+  CHECK(param_.X);
+  CHECK(param_.Out);
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(sequence_reverse, paddle::lite::operators::SequenceReverseOp);
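The op definition only enforces the contract the kernels rely on: X must carry LoD information, its rank must be at least 2, and Out takes exactly X's dims, since the kernels reorder rows without changing the shape; AttachImpl then binds the op's "X" input and "Y" output (the same names the kernel registrations bind). The toy sketch below restates that contract with plain structs; the types are hypothetical stand-ins for illustration, not Paddle-Lite's.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a LoD tensor: dims plus a (possibly empty) LoD.
struct ToyTensor {
  std::vector<int64_t> dims;
  std::vector<std::vector<uint64_t>> lod;
};

// Mirrors SequenceReverseOp::CheckShape + InferShape: X must have LoD and
// rank >= 2; Out simply takes X's dims.
static bool CheckAndInfer(const ToyTensor& x, ToyTensor* out) {
  if (x.lod.empty()) return false;   // no LoD information
  if (x.dims.size() < 2) return false;  // rank must be at least 2
  out->dims = x.dims;
  return true;
}

int main() {
  ToyTensor x{{10, 4}, {{0, 3, 5, 6, 10}}};
  ToyTensor y;
  assert(CheckAndInfer(x, &y) && y.dims == x.dims);

  ToyTensor no_lod{{10, 4}, {}};
  assert(!CheckAndInfer(no_lod, &y));  // rejected: X carries no LoD
  return 0;
}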
diff --git a/lite/operators/sequence_reverse_op.h b/lite/operators/sequence_reverse_op.h
new file mode 100644
index 0000000000..326d0f6892
--- /dev/null
+++ b/lite/operators/sequence_reverse_op.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SequenceReverseOp : public OpLite {
+ public:
+  SequenceReverseOp() {}
+  explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {}
+  bool CheckShape() const override;
+  bool InferShape() const override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "sequence_reverse"; }
+
+ private:
+  mutable SequenceReverseParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
--
GitLab