diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 83ae221b86e5402c22bc7a1f18e92103ee74b915..979e7f2730cc9040126781c5573adfff81835151 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -25,6 +25,7 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute. add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm) @@ -45,7 +46,9 @@ nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc D nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda) nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda) nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) +nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda) nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda) + if(LITE_BUILD_EXTRA) nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda) endif() diff --git a/lite/kernels/cuda/sequence_arithmetic_compute.cu 
b/lite/kernels/cuda/sequence_arithmetic_compute.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5ca12267f983c85f7a71866bd6d761994a8efe81
--- /dev/null
+++ b/lite/kernels/cuda/sequence_arithmetic_compute.cu
@@ -0,0 +1,250 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+#include "lite/core/op_registry.h"
+#include "lite/core/target_wrapper.h"
+#include "lite/kernels/cuda/sequence_arithmetic_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+// Threads per block used by every launch in this file.
+const int CUDA_NUM_THREADS = 512;
+
+// Loops `i` from this thread's global index up to `n`, stepping by the total
+// number of launched threads, so any grid size covers all `n` elements.
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+// Number of CUDA_NUM_THREADS-wide blocks needed to cover N elements.
+inline int CUDA_GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+// Maps flat element `tid` of X to the flat index of its partner element in Y.
+//
+// `tid` is split into a word (row) index and a position inside the row. The
+// word's sequence is looked up in `word_id_to_seq_id`, and the word's position
+// inside that sequence is taken relative to `offset_0` (X offsets). When the
+// matching Y sequence (described by `offset_1`) is too short to contain that
+// position, -1 is returned and the caller passes the X value through.
+__device__ __forceinline__ int matching_y_index(const int tid,
+                                                const int* offset_0,
+                                                const int* offset_1,
+                                                const int* word_id_to_seq_id,
+                                                const int inner_size) {
+  const int emb_id = tid % inner_size;
+  const int word_id = tid / inner_size;
+  const int seq_id = word_id_to_seq_id[word_id];
+  const int word_id_in_cur_seq = word_id - offset_0[seq_id];
+  const int seq_len_1 = offset_1[seq_id + 1] - offset_1[seq_id];
+  if (word_id_in_cur_seq >= seq_len_1) {
+    return -1;
+  }
+  return (offset_1[seq_id] + word_id_in_cur_seq) * inner_size + emb_id;
+}
+
+// The three kernels below share one contract:
+//   out_data / in_data_0 : `count` elements laid out as rows of `inner_size`.
+//   in_data_1            : second operand, addressed through `offset_1`.
+//   offset_0 / offset_1  : per-sequence row offsets of the two operands.
+//   word_id_to_seq_id    : sequence index of every row of in_data_0.
+//   seq_num              : number of sequences (kept for interface
+//                          compatibility; not read by the kernels).
+// Elements without a partner in in_data_1 are copied from in_data_0.
+
+// out = in_data_0 + in_data_1 on the overlapping part of each sequence.
+template <typename Dtype>
+__global__ void ker_arithmetic_sum(Dtype* out_data,
+                                   const Dtype* in_data_0,
+                                   const Dtype* in_data_1,
+                                   const int* offset_0,
+                                   const int* offset_1,
+                                   const int* word_id_to_seq_id,
+                                   const int seq_num,
+                                   const int inner_size,
+                                   const int count) {
+  CUDA_KERNEL_LOOP(tid, count) {
+    const int y_idx = matching_y_index(
+        tid, offset_0, offset_1, word_id_to_seq_id, inner_size);
+    out_data[tid] =
+        y_idx >= 0 ? in_data_0[tid] + in_data_1[y_idx] : in_data_0[tid];
+  }
+}
+
+// out = in_data_0 - in_data_1 on the overlapping part of each sequence.
+template <typename Dtype>
+__global__ void ker_arithmetic_sub(Dtype* out_data,
+                                   const Dtype* in_data_0,
+                                   const Dtype* in_data_1,
+                                   const int* offset_0,
+                                   const int* offset_1,
+                                   const int* word_id_to_seq_id,
+                                   const int seq_num,
+                                   const int inner_size,
+                                   const int count) {
+  CUDA_KERNEL_LOOP(tid, count) {
+    const int y_idx = matching_y_index(
+        tid, offset_0, offset_1, word_id_to_seq_id, inner_size);
+    out_data[tid] =
+        y_idx >= 0 ? in_data_0[tid] - in_data_1[y_idx] : in_data_0[tid];
+  }
+}
+
+// out = in_data_0 * in_data_1 on the overlapping part of each sequence.
+template <typename Dtype>
+__global__ void ker_arithmetic_mul(Dtype* out_data,
+                                   const Dtype* in_data_0,
+                                   const Dtype* in_data_1,
+                                   const int* offset_0,
+                                   const int* offset_1,
+                                   const int* word_id_to_seq_id,
+                                   const int seq_num,
+                                   const int inner_size,
+                                   const int count) {
+  CUDA_KERNEL_LOOP(tid, count) {
+    const int y_idx = matching_y_index(
+        tid, offset_0, offset_1, word_id_to_seq_id, inner_size);
+    out_data[tid] =
+        y_idx >= 0 ? in_data_0[tid] * in_data_1[y_idx] : in_data_0[tid];
+  }
+}
+
+// Uploads the LoD offsets of X and Y plus a row -> sequence map, then launches
+// the kernel selected by param.op_type (1: sum, 2: sub, 3: mul) on the
+// context's execution stream. The call does not synchronize the stream.
+void SequenceArithmeticCompute::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<CUDAContext>();
+  auto stream = ctx.exec_stream();
+
+  // Out always takes X's shape and LoD. Size it here as well so the kernel
+  // never writes into an unsized buffer when InferShape has not been run.
+  param.Out->Resize(param.X->dims());
+  param.Out->set_lod(param.X->lod());
+
+  const int count = param.X->numel();
+  // Nothing to compute; a launch with zero blocks would be rejected anyway.
+  if (count == 0) return;
+
+  CHECK(!param.X->lod().empty()) << "Input X must carry LoD";
+  CHECK(!param.Y->lod().empty()) << "Input Y must carry LoD";
+  const auto& x_lod = param.X->lod()[0];
+  const auto& y_lod = param.Y->lod()[0];
+  CHECK(!x_lod.empty()) << "Level-0 LoD of X is empty";
+  // The kernels read offset_1[seq_id + 1] for every sequence of X.
+  CHECK(y_lod.size() >= x_lod.size())
+      << "Y must provide at least as many sequences as X";
+
+  const float* x_data = param.X->data<float>();
+  // The second operand comes from Y (it was previously read from X, which
+  // made every op_type combine X with itself).
+  const float* y_data = param.Y->data<float>();
+  float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
+
+  // Host-side staging: offsets narrowed to int, as the kernels expect.
+  std::vector<int> offset_x_data_cpu(x_lod.begin(), x_lod.end());
+  std::vector<int> offset_y_data_cpu(y_lod.begin(), y_lod.end());
+
+  // word_seq_map[w] is the index of the X sequence that contains row w.
+  std::vector<int> word_seq_map;
+  const int seq_num = static_cast<int>(offset_x_data_cpu.size()) - 1;
+  for (int seq = 0; seq < seq_num; seq++) {
+    for (int word = offset_x_data_cpu[seq]; word < offset_x_data_cpu[seq + 1];
+         word++) {
+      word_seq_map.push_back(seq);
+    }
+  }
+
+  offset_x.Resize({static_cast<int64_t>(offset_x_data_cpu.size())});
+  auto offset_x_data = offset_x.mutable_data<int>(TARGET(kCUDA));
+
+  offset_y.Resize({static_cast<int64_t>(offset_y_data_cpu.size())});
+  auto offset_y_data = offset_y.mutable_data<int>(TARGET(kCUDA));
+
+  // Sized per element as before, although only one entry per row is written.
+  word_id_to_seq_id.Resize({param.X->numel()});
+  auto word_id_to_seq_id_data =
+      word_id_to_seq_id.mutable_data<int>(TARGET(kCUDA));
+
+  // NOTE(review): the sources are function-local pageable vectors. CUDA
+  // documents pageable host-to-device async copies as staged before the call
+  // returns; confirm TargetWrapperCuda::MemcpyAsync keeps that guarantee.
+  TargetWrapperCuda::MemcpyAsync(offset_x_data,
+                                 offset_x_data_cpu.data(),
+                                 sizeof(int) * offset_x_data_cpu.size(),
+                                 IoDirection::HtoD,
+                                 stream);
+
+  TargetWrapperCuda::MemcpyAsync(offset_y_data,
+                                 offset_y_data_cpu.data(),
+                                 sizeof(int) * offset_y_data_cpu.size(),
+                                 IoDirection::HtoD,
+                                 stream);
+
+  TargetWrapperCuda::MemcpyAsync(word_id_to_seq_id_data,
+                                 word_seq_map.data(),
+                                 sizeof(int) * word_seq_map.size(),
+                                 IoDirection::HtoD,
+                                 stream);
+
+  const int inner_size = param.X->dims()[1];
+  const int blocks = CUDA_GET_BLOCKS(count);
+
+  switch (param.op_type) {
+    case 1:  // sum
+      ker_arithmetic_sum<float><<<blocks, CUDA_NUM_THREADS, 0, stream>>>(
+          out_data,
+          x_data,
+          y_data,
+          offset_x_data,
+          offset_y_data,
+          word_id_to_seq_id_data,
+          seq_num,
+          inner_size,
+          count);
+      break;
+    case 2:  // sub
+      ker_arithmetic_sub<float><<<blocks, CUDA_NUM_THREADS, 0, stream>>>(
+          out_data,
+          x_data,
+          y_data,
+          offset_x_data,
+          offset_y_data,
+          word_id_to_seq_id_data,
+          seq_num,
+          inner_size,
+          count);
+      break;
+    case 3:  // mul
+      ker_arithmetic_mul<float><<<blocks, CUDA_NUM_THREADS, 0, stream>>>(
+          out_data,
+          x_data,
+          y_data,
+          offset_x_data,
+          offset_y_data,
+          word_id_to_seq_id_data,
+          seq_num,
+          inner_size,
+          count);
+      break;
+    default:
+      // No kernel is launched, so Out keeps whatever it held before.
+      LOG(ERROR) << "sequence_arithmetic: unsupported op_type "
+                 << param.op_type;
+      break;
+  }
+
+  // Launch-configuration errors are only visible through cudaGetLastError().
+  cudaError_t error = cudaGetLastError();
+  if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+// Both op names map onto the same float/NCHW CUDA kernel.
+REGISTER_LITE_KERNEL(sequence_arithmetic,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::SequenceArithmeticCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
+REGISTER_LITE_KERNEL(search_seq_arithmetic,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::SequenceArithmeticCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .Finalize();
diff --git a/lite/kernels/cuda/sequence_arithmetic_compute.h b/lite/kernels/cuda/sequence_arithmetic_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..a180c50eaa810511f8d72902e81bcd9abdaca31e
--- /dev/null
+++ b/lite/kernels/cuda/sequence_arithmetic_compute.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "lite/core/kernel.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+// CUDA kernel behind the sequence_arithmetic / search_seq_arithmetic ops.
+// Run() combines X and Y sequence by sequence according to param.op_type and
+// writes the result to Out.
+class SequenceArithmeticCompute
+    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::SequenceArithmeticParam;
+
+  void Run() override;
+  virtual ~SequenceArithmeticCompute() = default;
+
+ private:
+  // Device-side scratch tensors filled by Run(). They are members so that
+  // they outlive a single Run() call.
+  lite::Tensor offset_x;           // level-0 LoD offsets of X, as int
+  lite::Tensor offset_y;           // level-0 LoD offsets of Y, as int
+  lite::Tensor word_id_to_seq_id;  // sequence index of every row of X
+};
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/cuda/sequence_arithmetic_compute_test.cc b/lite/kernels/cuda/sequence_arithmetic_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c0746d375d5c43d68cfad1896e7a3ab6178e2c35
--- /dev/null
+++ b/lite/kernels/cuda/sequence_arithmetic_compute_test.cc
@@ -0,0 +1,131 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/cuda/sequence_arithmetic_compute.h"
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace cuda {
+
+// CPU reference: for every sequence of x, applies op_type (1: sum, 2: sub,
+// 3: mul) element-wise over the part shared with the matching sequence of y
+// and copies the remainder of the x sequence unchanged. `out` takes x's
+// shape and LoD.
+void sequence_arithmetic_compute_ref(const Tensor& x,
+                                     const Tensor& y,
+                                     Tensor* out,
+                                     int op_type) {
+  auto x_data = x.data<float>();
+  auto y_data = y.data<float>();
+  out->Resize(x.dims());
+  out->set_lod(x.lod());
+  auto out_data = out->mutable_data<float>();
+  auto x_seq_offset = x.lod()[0];
+  auto y_seq_offset = y.lod()[0];
+  int seq_num = x_seq_offset.size() - 1;
+  int inner_size = x.numel() / x.dims()[0];
+
+  for (int i = 0; i < seq_num; i++) {
+    int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
+    int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
+    auto input_x = x_data + x_seq_offset[i] * inner_size;
+    auto input_y = y_data + y_seq_offset[i] * inner_size;
+    auto t_out = out_data + x_seq_offset[i] * inner_size;
+    int len = std::min(len_x, len_y);
+    for (int j = 0; j < len; j++) {
+      switch (op_type) {
+        case 1:
+          t_out[j] = input_x[j] + input_y[j];
+          break;
+        case 2:
+          t_out[j] = input_x[j] - input_y[j];
+          break;
+        case 3:
+          t_out[j] = input_x[j] * input_y[j];
+          break;
+        default:
+          break;
+      }
+    }
+    if (len_x > len) {
+      memcpy(t_out + len, input_x + len, sizeof(float) * (len_x - len));
+    }
+  }
+}
+
+// Fills `x` as a [x_lod[0].back(), 3] host tensor whose values are centred
+// around zero and spaced by `scale`. Different scales give different data,
+// so mixing up the two operands becomes visible in the results.
+void prepare_input(Tensor* x, const LoD& x_lod, float scale = 1.1f) {
+  x->Resize({static_cast<int64_t>(x_lod[0].back()), 3});
+  x->set_lod(x_lod);
+  auto x_data = x->mutable_data<float>();
+  const int64_t numel = x->numel();
+  for (int64_t i = 0; i < numel; i++) {
+    x_data[i] = (i - numel / 2) * scale;
+  }
+}
+
+// Runs the CUDA kernel for one op_type / LoD combination and compares every
+// output element with the CPU reference.
+void run_sequence_arithmetic_case(int op_type,
+                                  const LoD& x_lod,
+                                  const LoD& y_lod) {
+  lite::Tensor x, y, x_cpu, y_cpu;
+  lite::Tensor out, out_cpu, out_ref;
+
+  // X and Y deliberately hold different values.
+  prepare_input(&x_cpu, x_lod, 1.1f);
+  prepare_input(&y_cpu, y_lod, -0.7f);
+
+  x.Resize(x_cpu.dims());
+  x.set_lod(x_cpu.lod());
+  auto x_cpu_data = x_cpu.mutable_data<float>();
+  x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
+
+  y.Resize(y_cpu.dims());
+  y.set_lod(y_cpu.lod());
+  auto y_cpu_data = y_cpu.mutable_data<float>();
+  y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
+
+  // No operator runs InferShape here, so size the output explicitly.
+  out.Resize(x_cpu.dims());
+  out.set_lod(x_cpu.lod());
+
+  operators::SequenceArithmeticParam param;
+  param.X = &x;
+  param.Y = &y;
+  param.Out = &out;
+  param.op_type = op_type;
+
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  auto& context = ctx->As<CUDAContext>();
+  cudaStream_t stream;
+  ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess);
+  context.SetExecStream(stream);
+
+  SequenceArithmeticCompute sequence_arithmetic;
+  sequence_arithmetic.SetContext(std::move(ctx));
+  sequence_arithmetic.SetParam(param);
+  sequence_arithmetic.Run();
+  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
+
+  sequence_arithmetic_compute_ref(x_cpu, y_cpu, &out_ref, param.op_type);
+  // Guard against a vacuous comparison loop.
+  ASSERT_EQ(out.numel(), out_ref.numel());
+
+  auto out_data = out.mutable_data<float>(TARGET(kCUDA));
+  out_cpu.Resize(out.dims());
+  auto out_cpu_data = out_cpu.mutable_data<float>();
+  CopySync<TARGET(kCUDA)>(
+      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
+
+  auto out_ref_data = out_ref.data<float>();
+  for (int i = 0; i < out_ref.numel(); i++) {
+    EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-3);
+  }
+
+  cudaStreamDestroy(stream);
+}
+
+TEST(sequence_arithmetic_cuda, run_test) {
+  const lite::LoD x_lod{{0, 2, 5, 9}};
+  const lite::LoD equal_lod{{0, 2, 5, 9}};
+  // Y sequences shorter, longer and shorter than the matching X sequences.
+  const lite::LoD ragged_lod{{0, 1, 5, 8}};
+
+  for (int op_type = 1; op_type <= 3; op_type++) {
+    run_sequence_arithmetic_case(op_type, x_lod, equal_lod);
+    run_sequence_arithmetic_case(op_type, x_lod, ragged_lod);
+  }
+}
+
+}  // namespace cuda
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt index 552e7ff109f7c2cc119882269489791233cfa68a..423fa5355288fcfa0f2cc50f3930cb6e5f013b8d 100644 --- a/lite/kernels/x86/CMakeLists.txt +++ b/lite/kernels/x86/CMakeLists.txt @@ -46,6 +46,7 @@ add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_ add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type) +add_kernel(sequence_arithmetic_compute_x86 X86 basic SRCS sequence_arithmetic_compute.cc DEPS ${lite_kernel_deps}) # for content-dnn specific add_kernel(search_aligned_mat_mul_compute_x86 X86
extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} blas) @@ -87,3 +88,4 @@ endif() lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86) lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86) lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86) +lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86) diff --git a/lite/kernels/x86/sequence_arithmetic_compute.cc b/lite/kernels/x86/sequence_arithmetic_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..95fa27e3d4e7fdbf639a5b275568311907f8344d --- /dev/null +++ b/lite/kernels/x86/sequence_arithmetic_compute.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lite/kernels/x86/sequence_arithmetic_compute.h"
+
+// Both op names map onto the float instantiation of the x86 kernel.
+REGISTER_LITE_KERNEL(
+    sequence_arithmetic,
+    kX86,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::x86::SequenceArithmeticCompute<float>,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
+REGISTER_LITE_KERNEL(
+    search_seq_arithmetic,
+    kX86,
+    kFloat,
+    kNCHW,
+    paddle::lite::kernels::x86::SequenceArithmeticCompute<float>,
+    def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/sequence_arithmetic_compute.h b/lite/kernels/x86/sequence_arithmetic_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..88510b8b1c7a04ab01da9af331f9d1f72765b215
--- /dev/null
+++ b/lite/kernels/x86/sequence_arithmetic_compute.h
@@ -0,0 +1,111 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <cstring>
+#include <vector>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+// x86 kernel behind the sequence_arithmetic / search_seq_arithmetic ops.
+//
+// For every sequence of X, the op selected by param.op_type (1: sum, 2: sub,
+// 3: mul) is applied element-wise over the part shared with the matching
+// sequence of Y; the rest of the X sequence is copied to Out unchanged.
+// Out takes X's shape and LoD.
+template <typename T>
+class SequenceArithmeticCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::SequenceArithmeticParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+    auto x = param.X;
+    auto y = param.Y;
+    auto out = param.Out;
+    const int op_type = param.op_type;
+
+    out->Resize(x->dims());
+    out->set_lod(x->lod());
+
+    // As before, an unknown op_type writes nothing; it is now reported.
+    if (op_type < 1 || op_type > 3) {
+      LOG(ERROR) << "sequence_arithmetic: unsupported op_type " << op_type;
+      return;
+    }
+    // Empty input: nothing to compute, and dims()[0] may be zero below.
+    if (x->numel() == 0) return;
+
+    const T* x_data = x->data<T>();
+    const T* y_data = y->data<T>();
+    T* out_data = out->mutable_data<T>();
+    const auto& x_seq_offset = x->lod()[0];
+    const auto& y_seq_offset = y->lod()[0];
+    const int seq_num = static_cast<int>(x_seq_offset.size()) - 1;
+    const int inner_size = (x->numel()) / (x->dims()[0]);
+
+    for (int i = 0; i < seq_num; i++) {
+      const int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
+      const int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
+      const T* input_x = x_data + x_seq_offset[i] * inner_size;
+      const T* input_y = y_data + y_seq_offset[i] * inner_size;
+      T* t_out = out_data + x_seq_offset[i] * inner_size;
+      // Only the part present in both sequences is combined.
+      const int len = std::min(len_x, len_y);
+
+      switch (op_type) {
+        case 1:  // sum
+          for (int j = 0; j < len; j++) {
+            t_out[j] = input_x[j] + input_y[j];
+          }
+          break;
+        case 2:  // sub
+          for (int j = 0; j < len; j++) {
+            t_out[j] = input_x[j] - input_y[j];
+          }
+          break;
+        default:  // 3: mul (range checked above)
+          for (int j = 0; j < len; j++) {
+            t_out[j] = input_x[j] * input_y[j];
+          }
+          break;
+      }
+
+      // The X sequence is longer than the Y sequence: pass the tail through.
+      if (len_x > len) {
+        memcpy(t_out + len, input_x + len, sizeof(T) * (len_x - len));
+      }
+    }
+  }
+
+  virtual ~SequenceArithmeticCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/sequence_arithmetic_compute_test.cc b/lite/kernels/x86/sequence_arithmetic_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3b41e7d7ce37ebaf6a3f8518bc248ff4ec5c1aec
--- /dev/null
+++ b/lite/kernels/x86/sequence_arithmetic_compute_test.cc
@@ -0,0 +1,125 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/sequence_arithmetic_compute.h"
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+// Independent reference: for every sequence of x, applies op_type (1: sum,
+// 2: sub, 3: mul) element-wise over the part shared with the matching
+// sequence of y and copies the remainder of the x sequence unchanged.
+// `out` takes x's shape and LoD.
+void sequence_arithmetic_compute_ref(const Tensor& x,
+                                     const Tensor& y,
+                                     Tensor* out,
+                                     int op_type) {
+  auto x_data = x.data<float>();
+  auto y_data = y.data<float>();
+  out->Resize(x.dims());
+  out->set_lod(x.lod());
+  auto out_data = out->mutable_data<float>();
+  auto x_seq_offset = x.lod()[0];
+  auto y_seq_offset = y.lod()[0];
+  int seq_num = x_seq_offset.size() - 1;
+  int inner_size = x.numel() / x.dims()[0];
+
+  for (int i = 0; i < seq_num; i++) {
+    int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
+    int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
+    auto input_x = x_data + x_seq_offset[i] * inner_size;
+    auto input_y = y_data + y_seq_offset[i] * inner_size;
+    auto t_out = out_data + x_seq_offset[i] * inner_size;
+    int len = std::min(len_x, len_y);
+    for (int j = 0; j < len; j++) {
+      switch (op_type) {
+        case 1:
+          t_out[j] = input_x[j] + input_y[j];
+          break;
+        case 2:
+          t_out[j] = input_x[j] - input_y[j];
+          break;
+        case 3:
+          t_out[j] = input_x[j] * input_y[j];
+          break;
+        default:
+          break;
+      }
+    }
+    if (len_x > len) {
+      memcpy(t_out + len, input_x + len, sizeof(float) * (len_x - len));
+    }
+  }
+}
+
+// Fills `x` as a [x_lod[0].back(), 3] tensor whose values are centred around
+// zero and spaced by `scale`. Different scales give different data, so
+// mixing up the two operands becomes visible in the results.
+void prepare_input(Tensor* x, const LoD& x_lod, float scale = 1.1f) {
+  x->Resize({static_cast<int64_t>(x_lod[0].back()), 3});
+  x->set_lod(x_lod);
+  auto x_data = x->mutable_data<float>();
+  const int64_t numel = x->numel();
+  for (int64_t i = 0; i < numel; i++) {
+    x_data[i] = (i - numel / 2) * scale;
+  }
+}
+
+// Runs the x86 kernel for one op_type / LoD combination and compares every
+// output element with the reference.
+void run_sequence_arithmetic_case(int op_type,
+                                  const LoD& x_lod,
+                                  const LoD& y_lod) {
+  SequenceArithmeticCompute<float> sequence_arithmetic;
+  std::unique_ptr<KernelContext> ctx(new KernelContext);
+  ctx->As<X86Context>();
+
+  lite::Tensor x, y, out, out_ref;
+  // X and Y deliberately hold different values.
+  prepare_input(&x, x_lod, 1.1f);
+  prepare_input(&y, y_lod, -0.7f);
+
+  operators::SequenceArithmeticParam param;
+  param.X = &x;
+  param.Y = &y;
+  param.Out = &out;
+  param.op_type = op_type;
+
+  sequence_arithmetic.SetContext(std::move(ctx));
+  sequence_arithmetic.SetParam(param);
+  sequence_arithmetic.Run();
+
+  sequence_arithmetic_compute_ref(x, y, &out_ref, param.op_type);
+  // Guard against a vacuous comparison loop.
+  ASSERT_EQ(out.numel(), out_ref.numel());
+
+  auto out_data = out.data<float>();
+  auto out_ref_data = out_ref.data<float>();
+  for (int i = 0; i < out_ref.numel(); i++) {
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-3);
+  }
+}
+
+TEST(sequence_arithmetic_x86, retrive_op) {
+  auto sequence_arithmetic =
+      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
+          "sequence_arithmetic");
+  ASSERT_FALSE(sequence_arithmetic.empty());
+  ASSERT_TRUE(sequence_arithmetic.front());
+}
+
+TEST(sequence_arithmetic_x86, init) {
+  SequenceArithmeticCompute<float> sequence_arithmetic;
+  ASSERT_EQ(sequence_arithmetic.precision(), PRECISION(kFloat));
+  ASSERT_EQ(sequence_arithmetic.target(), TARGET(kX86));
+}
+
+TEST(sequence_arithmetic_x86, run_test) {
+  const lite::LoD x_lod{{0, 2, 5, 9}};
+  const lite::LoD equal_lod{{0, 2, 5, 9}};
+  // Y sequences shorter, longer and shorter than the matching X sequences.
+  const lite::LoD ragged_lod{{0, 1, 5, 8}};
+
+  for (int op_type = 1; op_type <= 3; op_type++) {
+    run_sequence_arithmetic_case(op_type, x_lod, equal_lod);
+    run_sequence_arithmetic_case(op_type, x_lod, ragged_lod);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(sequence_arithmetic, kX86, kFloat, kNCHW, def);
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index fff000eb7aca4dd0f1d19b844db3aa7800ce9283..bc48f853c6f7bc523faf431089ceebe0b1044301 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -85,6 +85,7 @@ add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS}) add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS}) add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS}) +add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index f809700350eeaf9398a975df453007df19240817..a3f9e45048ad0a01c9429c163bf38da876dd5f42 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -769,6 +769,13
@@ struct SequenceConcatParam {
   lite::Tensor* Out{};
 };
 
+struct SequenceArithmeticParam {  // parameters of the sequence_arithmetic op
+  const lite::Tensor* X{};  // first operand; Out takes its dims and LoD
+  const lite::Tensor* Y{};  // second operand, matched to X per sequence
+  int op_type{1};  // kernels handle 1: sum, 2: sub, 3: mul
+  lite::Tensor* Out{};  // result tensor
+};
+
 struct ReduceMaxParam {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
diff --git a/lite/operators/sequence_arithmetic_op.cc b/lite/operators/sequence_arithmetic_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6c4a28f8a8d6d014eda115a5d475ff295c846c3b
--- /dev/null
+++ b/lite/operators/sequence_arithmetic_op.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/sequence_arithmetic_op.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+// Verifies that all three tensors are attached and that X and Y are 2-D.
+bool SequenceArithmeticOp::CheckShape() const {
+  // Null checks come first so the dimension checks never dereference null.
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Y);
+  CHECK_OR_FALSE(param_.Out);
+  CHECK_EQ(param_.X->dims().size(), 2) << "Input X should a 2-D Tensor";
+  CHECK_EQ(param_.Y->dims().size(), 2) << "Input Y should a 2-D Tensor";
+  return true;
+}
+
+// Out has the same dims and LoD as X.
+bool SequenceArithmeticOp::InferShape() const {
+  param_.Out->Resize(param_.X->dims());
+  param_.Out->set_lod(param_.X->lod());
+  return true;
+}
+
+// Binds X, Y and Out from the scope and reads the integer "op_type"
+// attribute from the op description.
+bool SequenceArithmeticOp::AttachImpl(const cpp::OpDesc &opdesc,
+                                      lite::Scope *scope) {
+  param_.X = scope->FindTensor(opdesc.Input("X").front());
+  param_.Y = scope->FindTensor(opdesc.Input("Y").front());
+  // "Out" is an output argument of the op (the kernels bind it with
+  // BindOutput), so it must be looked up among the outputs, not the inputs.
+  param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
+
+  param_.op_type = opdesc.GetAttr<int>("op_type");
+
+  CHECK(param_.X);
+  CHECK(param_.Y);
+  CHECK(param_.Out);
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(sequence_arithmetic,
+                 paddle::lite::operators::SequenceArithmeticOp);
+REGISTER_LITE_OP(search_seq_arithmetic,
+                 paddle::lite::operators::SequenceArithmeticOp);
diff --git a/lite/operators/sequence_arithmetic_op.h b/lite/operators/sequence_arithmetic_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..9f844dfbf429599d829bc786c66ba6d05e40d79d
--- /dev/null
+++ b/lite/operators/sequence_arithmetic_op.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+// Operator wrapper for sequence_arithmetic: owns the parameter struct and
+// hands it to whichever kernel gets attached.
+class SequenceArithmeticOp : public OpLite {
+ public:
+  SequenceArithmeticOp() {}
+  explicit SequenceArithmeticOp(const std::string &op_type) : OpLite(op_type) {}
+
+  // Checks that X, Y and Out are set and that X and Y are 2-D.
+  bool CheckShape() const override;
+
+  // Gives Out the dims and LoD of X.
+  bool InferShape() const override;
+
+  // Resolves X, Y, Out and the "op_type" attribute from the op description.
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "sequence_arithmetic"; }
+
+ private:
+  // mutable: the const InferShape() writes Out through this struct.
+  mutable SequenceArithmeticParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle