From 0775140a37fa9d592547d1d15bb09f19699d132c Mon Sep 17 00:00:00 2001
From: mapingshuo
Date: Fri, 28 Feb 2020 15:31:40 +0800
Subject: [PATCH] add sequence_conv op and arm kernel (#3016)

* add sequence_conv op and arm kernel
* add test, test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* test=develop
* modify code style. test=develop
* fix ut, test=develop
* delete unused code, test=develop
---
 lite/backends/arm/math/conv_impl.cc        |  16 ++
 lite/backends/arm/math/conv_impl.h         |  18 +++
 lite/kernels/arm/CMakeLists.txt            |   1 +
 lite/kernels/arm/sequence_conv_compute.cc  | 150 ++++++++++++++++++
 lite/kernels/arm/sequence_conv_compute.h   |  39 +++++
 lite/operators/CMakeLists.txt              |   1 +
 lite/operators/op_params.h                 |   9 ++
 lite/operators/sequence_conv_op.cc         |  94 +++++++++++
 lite/operators/sequence_conv_op.h          |  43 +++++
 lite/tests/kernels/CMakeLists.txt          |   1 +
 .../kernels/sequence_conv_compute_test.cc  | 149 +++++++++++++++++
 11 files changed, 521 insertions(+)
 create mode 100644 lite/kernels/arm/sequence_conv_compute.cc
 create mode 100644 lite/kernels/arm/sequence_conv_compute.h
 create mode 100644 lite/operators/sequence_conv_op.cc
 create mode 100644 lite/operators/sequence_conv_op.h
 create mode 100644 lite/tests/kernels/sequence_conv_compute_test.cc

diff --git a/lite/backends/arm/math/conv_impl.cc b/lite/backends/arm/math/conv_impl.cc
index 96d0893bc0..9412fc43f1 100644
--- a/lite/backends/arm/math/conv_impl.cc
+++ b/lite/backends/arm/math/conv_impl.cc
@@ -573,6 +573,22 @@ template void conv_im2col_gemm_int8(const int8_t* i_data,
                                     ARMContext* ctx,
                                     const float* scale);
 
+template void im2col<float>(const float* data_im,
+                            int channels,
+                            int height,
+                            int width,
+                            int kernel_h,
+                            int kernel_w,
+                            int pad_top,
+                            int pad_bottom,
+                            int pad_left,
+                            int pad_right,
+                            int stride_h,
+                            int stride_w,
+                            int dilation_h,
+                            int dilation_w,
+                            float* data_col);
+
 void conv_depthwise_3x3_fp32(const void* din,
                              void* dout,
                              int num,
diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h
index 60f74b7fee..28a2fb7e2a 100644
--- a/lite/backends/arm/math/conv_impl.h
+++ b/lite/backends/arm/math/conv_impl.h
@@ -359,6 +359,24 @@ void conv_compute_2x2_3x3_small(const float* input,
                                 const float* bias,
                                 const operators::ConvParam& param,
                                 ARMContext* ctx);
+
+template <typename Dtype>
+void im2col(const Dtype* data_im,
+            int channels,
+            int height,
+            int width,
+            int kernel_h,
+            int kernel_w,
+            int pad_top,
+            int pad_bottom,
+            int pad_left,
+            int pad_right,
+            int stride_h,
+            int stride_w,
+            int dilation_h,
+            int dilation_w,
+            Dtype* data_col);
+
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index f16cbe7f66..65eec4575a 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -68,6 +68,7 @@ add_kernel(reduce_max_compute_arm ARM extra SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(sequence_expand_compute_arm ARM extra SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(sequence_pool_compute_arm ARM extra SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(sequence_conv_compute_arm ARM extra SRCS sequence_conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(layer_norm_compute_arm ARM extra SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(gather_compute_arm ARM extra SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
 add_kernel(reduce_prod_compute_arm ARM extra SRCS reduce_prod_compute.cc DEPS ${lite_kernel_deps} math_arm)
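
Note: the im2col declaration above is the whole trick behind this kernel. A
[seq_len, hidden_dim] sequence is treated as a one-channel image of height
seq_len and width hidden_dim, so a [contextLength, hidden_dim] filter window
slides along the time axis only. A minimal standalone sketch of that mapping
(illustrative names, assuming stride 1 and no horizontal padding; this is not
the library routine):

    // Unfold [seq_len, hidden_dim] into a col buffer laid out as
    // [context_length * hidden_dim, out_steps], one column per output step.
    void sequence_im2col(const float* data_im, int seq_len, int hidden_dim,
                         int context_length, int up_pad, int down_pad,
                         float* data_col) {
      const int out_steps = seq_len + up_pad + down_pad - context_length + 1;
      for (int kh = 0; kh < context_length; ++kh) {
        for (int kw = 0; kw < hidden_dim; ++kw) {
          for (int t = 0; t < out_steps; ++t) {
            const int src = t - up_pad + kh;  // input timestep, may be padding
            data_col[(kh * hidden_dim + kw) * out_steps + t] =
                (src >= 0 && src < seq_len) ? data_im[src * hidden_dim + kw]
                                            : 0.f;
          }
        }
      }
    }
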
diff --git a/lite/kernels/arm/sequence_conv_compute.cc b/lite/kernels/arm/sequence_conv_compute.cc
new file mode 100644
index 0000000000..190f8b504c
--- /dev/null
+++ b/lite/kernels/arm/sequence_conv_compute.cc
@@ -0,0 +1,150 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/kernels/arm/sequence_conv_compute.h"
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/backends/arm/math/sgemm.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/core/type_system.h"
+#include "lite/operators/op_params.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+// transpose a row-major m x n matrix into a row-major n x m matrix
+template <typename Dtype>
+void local_naive_transpose(const Dtype* din, Dtype* dout, int m, int n) {
+  int k = 0;
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < m; ++j) {
+      dout[k++] = din[j * n + i];
+    }
+  }
+}
+
+void SequenceConvCompute::PrepareForRun() {}
+
+void SequenceConvCompute::Run() {
+  // param.X is in shape: [sequence_len, hidden_dim];
+  // param.Filter is in shape: [kernel_size * hidden_dim, kernel_num];
+  // param.contextLength: kernel_size;
+  // param.contextStart: for the padding index;
+  // param.Out is in shape: [new_sequence_len, kernel_num].
+  auto& param = this->Param<operators::SequenceConvParam>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  const auto* in_data = param.X->data<float>();
+  const auto* filter_data = param.Filter->data<float>();
+  float* out_data = param.Out->mutable_data<float>();
+  int pad_start = param.contextStart;
+  int kernel_size = param.contextLength;
+  int kernel_num = param.Filter->dims()[1];
+  int up_pad = std::max(0, -pad_start);
+  int down_pad = std::max(0, pad_start + kernel_size - 1);
+  auto hidden_dim = static_cast<int64_t>(param.X->dims()[1]);
+  auto sequence_len = static_cast<int64_t>(param.X->dims()[0]);
+  auto lod = param.X->lod();
+
+  // im2col: unfold the input, sequence by sequence, into a col buffer of
+  // shape [sequence_len, kernel_size * hidden_dim]
+  lite::Tensor col;
+  lite::Tensor tmp;
+  col.Resize({sequence_len, kernel_size * hidden_dim});
+  auto* col_data = col.mutable_data<float>();
+  auto lod_level_0 = lod[0];
+  int input_row_begin, input_row_end;
+  for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; i++) {
+    if (lod_level_0[i] == lod_level_0[i + 1]) continue;  // empty sequence
+    input_row_begin = (pad_start > 0)
+                          ? static_cast<int>(lod_level_0[i]) + pad_start
+                          : static_cast<int>(lod_level_0[i]);
+    input_row_end = static_cast<int>(lod_level_0[i + 1]);
+
+    if (input_row_begin < input_row_end) {
+      // do im2col
+      auto* sub_in_data = in_data + input_row_begin * hidden_dim;
+      auto* sub_col_data =
+          col_data + input_row_begin * kernel_size * hidden_dim;
+      tmp.Resize({kernel_size * hidden_dim, input_row_end - input_row_begin});
+      auto* tmp_data = tmp.mutable_data<float>();
+      // image col: [input_channels, filter_height, filter_width,
+      //             output_height, output_width]
+      // sequence col: [1, kernel_size, hidden_dim, sequence_len, 1]
+      paddle::lite::arm::math::im2col<float>(
+          sub_in_data,
+          1,
+          sequence_len,
+          hidden_dim,  // C, H, W -> 1, seq_len, hidden_dim
+          kernel_size,
+          hidden_dim,  // kernel_h, kernel_w
+          up_pad,
+          down_pad,
+          0,
+          0,  // pad_top, pad_bottom, pad_left, pad_right
+          1,
+          1,
+          1,
+          1,  // stride_h, stride_w, dilation_h, dilation_w
+          tmp_data);
+      // im2col emits [kernel_size * hidden_dim, rows]; transpose into the
+      // row-major [rows, kernel_size * hidden_dim] layout GEMM expects
+      local_naive_transpose(tmp_data,
+                            sub_col_data,
+                            kernel_size * hidden_dim,
+                            input_row_end - input_row_begin);
+    }
+  }
+
+  // SGEMM: C := alpha * A * B + beta * C
+  // matmul: col * filter_data
+  // [sequence_len, kernel_size * hidden_dim] x [kernel_size * hidden_dim,
+  // kernel_num] = [sequence_len, kernel_num]
+  paddle::lite::operators::ActivationParam act_param;  // no fused activation
+  paddle::lite::arm::math::sgemm(false,                     // is_transA
+                                 false,                     // is_transB
+                                 sequence_len,              // M
+                                 kernel_num,                // N
+                                 kernel_size * hidden_dim,  // K
+                                 1.0f,                      // alpha
+                                 col_data,                  // A
+                                 kernel_size * hidden_dim,  // lda: k
+                                 filter_data,               // B
+                                 kernel_num,                // ldb: n
+                                 0.f,                       // beta
+                                 out_data,                  // C
+                                 kernel_num,                // ldc: n
+                                 NULL,                      // bias
+                                 false,                     // is_bias
+                                 act_param,                 // act_param
+                                 &ctx);                     // ctx
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(sequence_conv,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::SequenceConvCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
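
Note: the im2col + SGEMM pipeline in Run() is equivalent to the direct
computation below, under the same assumptions (contextStride == 1, no
PaddingData, zero padding outside each sequence). sequence_conv_ref is an
illustrative name, not part of the patch:

    #include <vector>

    // out[t][k] = sum over c in [0, ctx_len) and h in [0, hidden) of
    //             x[t + ctx_start + c][h] * w[c * hidden + h][k]
    std::vector<float> sequence_conv_ref(
        const std::vector<float>& x,  // [len, hidden], row-major
        const std::vector<float>& w,  // [ctx_len * hidden, knum], row-major
        int len, int hidden, int ctx_start, int ctx_len, int knum) {
      std::vector<float> out(len * knum, 0.f);
      for (int t = 0; t < len; ++t) {
        for (int k = 0; k < knum; ++k) {
          float acc = 0.f;
          for (int c = 0; c < ctx_len; ++c) {
            const int src = t + ctx_start + c;    // timestep in the sequence
            if (src < 0 || src >= len) continue;  // zero padding
            for (int h = 0; h < hidden; ++h) {
              acc += x[src * hidden + h] * w[(c * hidden + h) * knum + k];
            }
          }
          out[t * knum + k] = acc;
        }
      }
      return out;
    }
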
diff --git a/lite/kernels/arm/sequence_conv_compute.h b/lite/kernels/arm/sequence_conv_compute.h
new file mode 100644
index 0000000000..d63b72b006
--- /dev/null
+++ b/lite/kernels/arm/sequence_conv_compute.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+class SequenceConvCompute
+    : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+ public:
+  void PrepareForRun() override;
+
+  void Run() override;
+
+  virtual ~SequenceConvCompute() = default;
+};
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index eb39924e71..8ff355d11f 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -90,6 +90,7 @@ add_operator(reduce_prod_op_lite extra SRCS reduce_prod_op.cc DEPS ${op_DEPS})
 add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
 add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
 add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS})
+add_operator(sequence_conv extra SRCS sequence_conv_op.cc DEPS ${op_DEPS})
 add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS})
 add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
 add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 8ad99c99af..4ea1b27080 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -773,6 +773,15 @@ struct SequencePoolParam {
 #endif
 };
 
+struct SequenceConvParam {
+  const lite::Tensor* X{};
+  const lite::Tensor* Filter{};
+  lite::Tensor* Out{};
+  int contextStart{0};
+  int contextStride{1};
+  int contextLength;
+};
+
 struct SequencePoolConcatParam {
   std::vector<lite::Tensor*> X{};
   lite::Tensor* Out{};
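
Note: contextStart and contextLength fully determine the temporal padding the
ARM kernel derives in Run(). A small standalone example mirroring that
up_pad/down_pad computation (values in the comments are for contextStart = -1,
contextLength = 3):

    #include <algorithm>
    #include <cstdio>

    int main() {
      int context_start = -1, context_length = 3;
      int up_pad = std::max(0, -context_start);                        // 1
      int down_pad = std::max(0, context_start + context_length - 1);  // 1
      // Output step t reads input steps [t + context_start,
      // t + context_start + context_length - 1], here [t - 1, t + 1];
      // steps outside the sequence are zero padding.
      std::printf("up_pad=%d down_pad=%d\n", up_pad, down_pad);
      return 0;
    }
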
diff --git a/lite/operators/sequence_conv_op.cc b/lite/operators/sequence_conv_op.cc
new file mode 100644
index 0000000000..89596a22c6
--- /dev/null
+++ b/lite/operators/sequence_conv_op.cc
@@ -0,0 +1,94 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/sequence_conv_op.h"
+#include <algorithm>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool SequenceConvOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Filter);
+  CHECK_OR_FALSE(param_.Out);
+
+  // currently only contextStride == 1 is supported
+  int context_length = param_.contextLength;
+  int context_start = param_.contextStart;
+  CHECK_EQ_OR_FALSE(param_.contextStride, 1);
+  CHECK_GT_OR_FALSE(context_start, -context_length);
+  CHECK_GE_OR_FALSE(0, context_start);
+
+  const auto *filter = param_.Filter;
+  auto lod = param_.X->lod();
+  auto filter_dims = filter->dims();
+  auto in_dims = param_.X->dims();
+  CHECK_EQ_OR_FALSE(in_dims.size(), 2UL);
+  CHECK_EQ_OR_FALSE(filter_dims.size(), 2UL);
+  CHECK_EQ_OR_FALSE(lod.size(), 1UL);
+  CHECK_EQ_OR_FALSE(filter_dims[0], context_length * in_dims[1]);
+  CHECK_GE_OR_FALSE(in_dims[0], static_cast<int64_t>(lod[0].size()) - 1);
+  return true;
+}
+
+bool SequenceConvOp::InferShape() const {
+  const auto *input = param_.X;
+  const auto *filter = param_.Filter;
+  auto in_dims = input->dims();
+  auto filter_dims = filter->dims();
+  auto out_dims = in_dims;
+  out_dims[1] = filter_dims[1];
+  param_.Out->Resize(out_dims);
+  param_.Out->set_lod(param_.X->lod());
+  return true;
+}
+
+bool SequenceConvOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  // required params
+  param_.X = const_cast<lite::Tensor *>(
+      &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
+  param_.Filter = const_cast<lite::Tensor *>(
+      &scope->FindVar(opdesc.Input("Filter").front())->Get<lite::Tensor>());
+  param_.Out =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+  param_.contextStart = opdesc.GetAttr<int>("contextStart");
+  param_.contextStride = opdesc.GetAttr<int>("contextStride");
+  param_.contextLength = opdesc.GetAttr<int>("contextLength");
+
+  // PaddingData is not supported for now
+  std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
+  if (std::find(input_arg_names.begin(),
+                input_arg_names.end(),
+                "PaddingData") != input_arg_names.end()) {
+    auto padding_data_arguments = opdesc.Input("PaddingData");
+    CHECK_EQ_OR_FALSE(padding_data_arguments.size(), 0UL);
+  }
+
+  // paddingTrainable == true is not supported for now
+  if (opdesc.HasAttr("paddingTrainable")) {
+    CHECK_OR_FALSE(!opdesc.GetAttr<bool>("paddingTrainable"));
+  }
+  CHECK(param_.X);
+  CHECK(param_.Filter);
+  CHECK(param_.Out);
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(sequence_conv, paddle::lite::operators::SequenceConvOp);
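
A worked example with the shapes the unit test below uses: X is [4, 5] with
LoD {{0, 4}} and contextLength = 3, so CheckShape requires Filter to be
[contextLength * hidden_dim, kernel_num] = [15, 3], and InferShape produces
Out = [4, 3], carrying X's LoD over to the output.
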
diff --git a/lite/operators/sequence_conv_op.h b/lite/operators/sequence_conv_op.h
new file mode 100644
index 0000000000..34d65d3cc9
--- /dev/null
+++ b/lite/operators/sequence_conv_op.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class SequenceConvOp : public OpLite {
+ public:
+  SequenceConvOp() {}
+  explicit SequenceConvOp(const std::string &op_type) : OpLite(op_type) {}
+  bool CheckShape() const override;
+  bool InferShape() const override;
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "sequence_conv"; }
+
+ private:
+  mutable SequenceConvParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index c55f62c029..9683a56fc8 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -38,6 +38,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
     if(LITE_BUILD_EXTRA)
         lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
         lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+        lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
         lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
         lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
         lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
diff --git a/lite/tests/kernels/sequence_conv_compute_test.cc b/lite/tests/kernels/sequence_conv_compute_test.cc
new file mode 100644
index 0000000000..798eb909fb
--- /dev/null
+++ b/lite/tests/kernels/sequence_conv_compute_test.cc
@@ -0,0 +1,149 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <cstdio>
+#include <memory>
+#include <vector>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+
+namespace paddle {
+namespace lite {
+
+class SequenceConvComputeTester : public arena::TestCase {
+ public:
+  SequenceConvComputeTester(const Place& place,
+                            const std::string& alias,
+                            LoD lod,
+                            DDim dims,
+                            int contextStart,
+                            int contextStride,
+                            int contextLength,
+                            int kernel_num)
+      : TestCase(place, alias),
+        lod_(lod),
+        dims_(dims),
+        contextStart_(contextStart),
+        contextStride_(contextStride),
+        contextLength_(contextLength),
+        kernel_num_(kernel_num) {}
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("sequence_conv");
+    op_desc->SetInput("X", {input_name_});
+    op_desc->SetInput("Filter", {filter_name_});
+    op_desc->SetOutput("Out", {output_name_});
+    op_desc->SetAttr("contextStart", contextStart_);
+    op_desc->SetAttr("contextStride", contextStride_);
+    op_desc->SetAttr("contextLength", contextLength_);
+  }
+
+  void PrepareData() override {
+    DDim filter_dims(
+        std::vector<int64_t>{contextLength_ * dims_[1], kernel_num_});
+
+    std::vector<float> din(dims_.production());
+    for (int i = 0; i < dims_[0]; i++) {
+      for (int j = 0; j < dims_[1]; j++) {
+        din[i * dims_[1] + j] =
+            (2.0 * i + 3.0 * j) / (2.0 * dims_[0] + 3.0 * dims_[1]) - 0.5;
+      }
+    }
+    SetCommonTensor(input_name_, dims_, din.data(), lod_);
+
+    std::vector<float> dfilter(filter_dims.production());
+    for (int i = 0; i < filter_dims[0]; i++) {
+      for (int j = 0; j < filter_dims[1]; j++) {
+        dfilter[i * filter_dims[1] + j] =
+            (1.5 * i + 2.0 * j) /
+                (1.5 * filter_dims[0] + 2.0 * filter_dims[1]) -
+            0.5;
+      }
+    }
+    SetCommonTensor(filter_name_, filter_dims, dfilter.data(), lod_);
+  }
+
+  void RunBaseline(Scope* scope) override {
+    // fill in the expected output in this scope,
+    // to be compared with the Paddle-Lite result
+    auto* output = scope->NewTensor(output_name_);
+    CHECK(output);
+    std::vector<int64_t> output_shape({4, 3});
+    output->Resize(DDim(output_shape));
+    auto output_data = output->mutable_data<float>();
+    std::vector<std::vector<float>> res;
+    if (contextStart_ == -2) {
+      res = {{-0.08867277, -0.17257819, -0.2564836},
+             {0.194508, 0.05720823, -0.08009153},
+             {0.73512584, 0.5749428, 0.41475973},
+             {0.5635012, 0.49485126, 0.42620137}};
+    } else if (contextStart_ == -1) {
+      res = {{0.194508, 0.05720823, -0.08009153},
+             {0.73512584, 0.5749428, 0.41475973},
+             {0.5635012, 0.49485126, 0.42620137},
+             {0.2517162, 0.23646072, 0.22120519}};
+    } else if (contextStart_ == 0) {
+      res = {{0.73512584, 0.5749428, 0.41475973},
+             {0.5635012, 0.49485126, 0.42620137},
+             {0.2517162, 0.23646072, 0.22120519},
+             {0.02574372, 0.03337148, 0.04099924}};
+    } else {
+      fprintf(stderr, "unsupported contextStart_\n");
+      exit(-1);
+    }
+    for (int i = 0; i < output_shape[0]; i++) {
+      for (int j = 0; j < output_shape[1]; j++) {
+        output_data[i * output_shape[1] + j] = res[i][j];
+      }
+    }
+    (output->mutable_lod())->push_back(lod_[0]);
+  }
+
+ protected:
+  std::string input_name_ = "x";
+  std::string filter_name_ = "filter";
+  std::string output_name_ = "out";
+  LoD lod_;
+  DDim dims_;
+  int contextStart_;
+  int contextStride_;
+  int contextLength_;
+  int kernel_num_;
+};
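
Note: the hard-coded res tables are the expected outputs for the deterministic
din/dfilter patterns built in PrepareData(); since the test LoD {{0, 4}} holds
a single sequence, the whole-batch convolution equals the per-sequence one.
One way the tables could be regenerated offline, assuming the
sequence_conv_ref sketch shown after the kernel diff above (make_expected is
illustrative, not part of the test):

    std::vector<float> make_expected(int ctx_start) {
      const int len = 4, hidden = 5, ctx_len = 3, knum = 3;
      std::vector<float> din(len * hidden), dfilter(ctx_len * hidden * knum);
      for (int i = 0; i < len; i++)
        for (int j = 0; j < hidden; j++)
          din[i * hidden + j] =
              (2.0f * i + 3.0f * j) / (2.0f * len + 3.0f * hidden) - 0.5f;
      for (int i = 0; i < ctx_len * hidden; i++)
        for (int j = 0; j < knum; j++)
          dfilter[i * knum + j] =
              (1.5f * i + 2.0f * j) /
                  (1.5f * ctx_len * hidden + 2.0f * knum) -
              0.5f;
      return sequence_conv_ref(
          din, dfilter, len, hidden, ctx_start, ctx_len, knum);
    }
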
+
+void TestNormalCase(Place place, float abs_error = 2e-5) {
+  std::vector<std::vector<uint64_t>> lod{{0, 4}};
+  std::vector<int64_t> dims{4, 5};
+  std::vector<int> candidate_pad_idx{-2, -1, 0};
+  for (int pad_idx : candidate_pad_idx) {
+    std::unique_ptr<arena::TestCase> tester(new SequenceConvComputeTester(
+        place, "def", lod, DDim(dims), pad_idx, 1, 3, 3));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
+  }
+}
+
+TEST(sequence_conv, precision) {
+#ifdef LITE_WITH_ARM
+  float abs_error = 2e-5;
+  Place place(TARGET(kARM));
+
+  TestNormalCase(place, abs_error);
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
--
GitLab