From 264706002d91c121413080bbb6a6a39bdbc048f4 Mon Sep 17 00:00:00 2001
From: juncaipeng <52520497+juncaipeng@users.noreply.github.com>
Date: Tue, 12 Nov 2019 11:31:14 +0800
Subject: [PATCH] Upgrade concat and unsqueeze, test=develop (#2378)

* update concat and unsqueeze, test=develop
---
 lite/kernels/arm/concat_compute.cc           |   7 +
 lite/kernels/arm/unsqueeze_compute.cc        |   8 +
 lite/kernels/cuda/concat_compute.cu          |   7 +
 lite/kernels/x86/concat_compute.cc           |   2 +
 lite/kernels/x86/concat_compute.h            |   5 +
 lite/operators/concat_op.cc                  |  26 ++-
 lite/operators/op_params.h                   |   3 +
 lite/operators/unsqueeze_op.cc               |  48 ++++-
 lite/tests/kernels/CMakeLists.txt            |   1 +
 lite/tests/kernels/concat_compute_test.cc    | 177 +++++++++++++++++++
 lite/tests/kernels/unsqueeze_compute_test.cc |  59 +++++--
 11 files changed, 329 insertions(+), 14 deletions(-)
 create mode 100644 lite/tests/kernels/concat_compute_test.cc

diff --git a/lite/kernels/arm/concat_compute.cc b/lite/kernels/arm/concat_compute.cc
index ccae9e0df6..eb59affea3 100644
--- a/lite/kernels/arm/concat_compute.cc
+++ b/lite/kernels/arm/concat_compute.cc
@@ -39,6 +39,11 @@ void ConcatCompute::Run() {
   std::vector<lite::Tensor*> inputs = param.x;
   auto* out = param.output;
   int axis = param.axis;
+  auto* axis_tensor = param.axis_tensor;
+  if (axis_tensor != nullptr) {
+    auto* axis_tensor_data = axis_tensor->data<int>();
+    axis = axis_tensor_data[0];
+  }
   out->mutable_data<float>();

   /// Sometimes direct copies will be faster, this maybe need deeply analysis.
@@ -83,5 +88,7 @@
 REGISTER_LITE_KERNEL(
     concat, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConcatCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("AxisTensor",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/arm/unsqueeze_compute.cc b/lite/kernels/arm/unsqueeze_compute.cc
index 3dc7a274df..e623407c2e 100644
--- a/lite/kernels/arm/unsqueeze_compute.cc
+++ b/lite/kernels/arm/unsqueeze_compute.cc
@@ -55,6 +55,10 @@ REGISTER_LITE_KERNEL(unsqueeze,
                      paddle::lite::kernels::host::UnsqueezeCompute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("AxesTensor",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("AxesTensorList",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();

@@ -65,6 +69,10 @@ REGISTER_LITE_KERNEL(unsqueeze2,
                      paddle::lite::kernels::host::Unsqueeze2Compute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("AxesTensor",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("AxesTensorList",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/cuda/concat_compute.cu b/lite/kernels/cuda/concat_compute.cu
index 89a5be142a..9ec6936672 100644
--- a/lite/kernels/cuda/concat_compute.cu
+++ b/lite/kernels/cuda/concat_compute.cu
@@ -51,6 +51,11 @@ void ConcatCompute::Run() {
   Tensor* output = param.output;
   auto* output_data = output->mutable_data<float>(TARGET(kCUDA));
   int axis = param.axis;
+  auto* axis_tensor = param.axis_tensor;
+  if (axis_tensor != nullptr) {
+    auto* axis_tensor_data = axis_tensor->data<int>();
+    axis = axis_tensor_data[0];
+  }
   int inner_size = 1;
   int outer_size = 1;
   auto input_dims = input[0]->dims();
@@ -97,5 +102,7 @@
 REGISTER_LITE_KERNEL(concat,
                      paddle::lite::kernels::cuda::ConcatCompute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("AxisTensor",
+               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
diff --git a/lite/kernels/x86/concat_compute.cc b/lite/kernels/x86/concat_compute.cc
index a95ec54b4e..36f15a84fe 100644
--- a/lite/kernels/x86/concat_compute.cc
+++ b/lite/kernels/x86/concat_compute.cc
@@ -21,5 +21,7 @@ REGISTER_LITE_KERNEL(concat,
                      paddle::lite::kernels::x86::ConcatCompute<float>,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("AxisTensor",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
diff --git a/lite/kernels/x86/concat_compute.h b/lite/kernels/x86/concat_compute.h
index 3fd1e9f233..2c6419a3c3 100644
--- a/lite/kernels/x86/concat_compute.h
+++ b/lite/kernels/x86/concat_compute.h
@@ -40,6 +40,11 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
   void Run() override {
     auto& param = *param_.get_mutable<param_t>();
     int64_t axis = static_cast<int64_t>(param.axis);
+    auto* axis_tensor = param.axis_tensor;
+    if (axis_tensor != nullptr) {
+      auto* axis_tensor_data = axis_tensor->data<int>();
+      axis = static_cast<int64_t>(axis_tensor_data[0]);
+    }
     auto x_dims = param.x[0]->dims();
     auto out = param.output;
     if (param.x.size() == 1) {
diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc
index dfd95e4658..1941a88bbf 100644
--- a/lite/operators/concat_op.cc
+++ b/lite/operators/concat_op.cc
@@ -31,14 +31,25 @@ bool ConcatOpLite::InferShape() const {
   for (auto p : param_.x) {
     input_dims.push_back(p->dims());
   }
-  size_t axis = static_cast<size_t>(param_.axis);
   const size_t n = input_dims.size();
   CHECK_GT_OR_FALSE(n, 0);
+
+  int axis = 0;
+  if (param_.axis_tensor == nullptr) {
+    axis = param_.axis;
+  } else {
+    auto *axis_tensor_val = param_.axis_tensor->data<int>();
+    axis = axis_tensor_val[0];
+  }
+  if (axis < 0) {
+    axis += input_dims[0].size();
+  }
+
   auto &out_dims = input_dims[0];
   size_t in_zero_dims_size = out_dims.size();
   for (size_t i = 1; i < n; i++) {
     for (size_t j = 0; j < in_zero_dims_size; j++) {
-      if (j == axis) {
+      if (j == static_cast<size_t>(axis)) {
         out_dims[axis] += input_dims[i][j];
       } else {
         CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
@@ -68,6 +79,17 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
   param_.axis = op_desc.GetAttr<int>("axis");
+  std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
+  if (std::find(input_arg_names.begin(), input_arg_names.end(), "AxisTensor") !=
+      input_arg_names.end()) {
+    auto arguments = op_desc.Input("AxisTensor");
+    if (arguments.size() > 0) {
+      auto var = scope->FindVar(arguments.front());
+      if (var != nullptr) {
+        param_.axis_tensor = var->GetMutable<lite::Tensor>();
+      }
+    }
+  }
   return true;
 }
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 7ed7715c30..8609f17888 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -207,6 +207,7 @@ struct ConcatParam {
   std::vector<lite::Tensor*> x{};
   lite::Tensor* output{};
   int axis{0};
+  lite::Tensor* axis_tensor{};
 };

 /// ----------------------- activation operators ----------------------
@@ -854,6 +855,8 @@ struct UnsqueezeParam {
   lite::Tensor* Out{};
   lite::Tensor* XShape{};
   std::vector<int> axes{};
+  const lite::Tensor* axes_tensor{};
+  std::vector<lite::Tensor>* axes_tensor_vct{};
 };

 /// ----------------------- expand operators ----------------------
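
Every concat kernel in this patch resolves the runtime axis the same way: the optional AxisTensor input, when bound, overrides the static axis attribute, and ConcatOpLite::InferShape additionally normalizes a negative axis against the input rank. The standalone C++ sketch below is illustrative only and not part of the diff; IntTensor and ResolveConcatAxis are hypothetical stand-ins for the lite::Tensor plumbing above.

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a lite::Tensor that carries a single int,
// as the AxisTensor input does.
struct IntTensor {
  std::vector<int> data;
};

// Mirrors the kernels above: a bound AxisTensor overrides the axis
// attribute, and a negative axis is normalized against the input rank,
// as ConcatOpLite::InferShape does.
int ResolveConcatAxis(int axis_attr, const IntTensor* axis_tensor, int rank) {
  int axis = axis_attr;
  if (axis_tensor != nullptr) {
    axis = axis_tensor->data[0];  // runtime value wins over the static attr
  }
  if (axis < 0) {
    axis += rank;  // e.g. -1 on a 4-D input becomes 3
  }
  assert(axis >= 0 && axis < rank);
  return axis;
}

int main() {
  IntTensor axis_tensor{{-1}};
  std::printf("%d\n", ResolveConcatAxis(1, &axis_tensor, 4));  // prints 3
  std::printf("%d\n", ResolveConcatAxis(1, nullptr, 4));       // prints 1
  return 0;
}

Feeding the axis through a tensor is what lets another op produce it at execution time; a plain attribute is fixed when the model is saved.
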
diff --git a/lite/operators/unsqueeze_op.cc b/lite/operators/unsqueeze_op.cc
index aca9a9c0e8..8db14d0660 100644
--- a/lite/operators/unsqueeze_op.cc
+++ b/lite/operators/unsqueeze_op.cc
@@ -63,9 +63,30 @@ bool UnsqueezeOp::CheckShape() const {
 }

 bool UnsqueezeOp::InferShape() const {
-  std::vector<int> unsqueeze_dims = param_.axes;
+  std::vector<int> final_axes;
+  auto axes = param_.axes;
+  auto *axes_tensor = param_.axes_tensor;
+  std::vector<lite::Tensor> axes_tensor_vct;
+  if (param_.axes_tensor_vct) {
+    axes_tensor_vct = *(param_.axes_tensor_vct);
+  }
+
+  if (!axes.empty()) {
+    final_axes = axes;
+  } else if (axes_tensor != nullptr) {
+    auto *axes_tensor_data = axes_tensor->data<int>();
+    final_axes = std::vector<int>(axes_tensor_data,
+                                  axes_tensor_data + axes_tensor->numel());
+  } else if (!axes_tensor_vct.empty()) {
+    for (int i = 0; i < axes_tensor_vct.size(); i++) {
+      final_axes.push_back(axes_tensor_vct[i].data<int>()[0]);
+    }
+  } else {
+    LOG(FATAL) << "Input axis error";
+  }
+
   DDim in_dims = param_.X->dims();
-  DDim out_dim = GetOutputShape(unsqueeze_dims, in_dims);
+  DDim out_dim = GetOutputShape(final_axes, in_dims);
   param_.Out->Resize(out_dim);
   return true;
 }
@@ -81,6 +102,29 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   if (opdesc.HasAttr("axes")) {
     param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
   }
+
+  if (opdesc.HasInput("AxesTensor") && opdesc.Input("AxesTensor").size() > 0) {
+    auto var = scope->FindVar(opdesc.Input("AxesTensor").front());
+    if (var != nullptr) {
+      param_.axes_tensor = var->GetMutable<lite::Tensor>();
+      VLOG(5) << "load AxesTensor";
+    }
+  }
+
+  if (opdesc.HasInput("AxesTensorList") &&
+      opdesc.Input("AxesTensorList").size() > 0) {
+    auto args = opdesc.Input("AxesTensorList");
+    /*
+    for (auto arg : args) {
+      auto *var = scope->FindVar(arg);
+      if (var != nullptr) {
+        param_.axes_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
+      }
+    }
+    */
+    auto *var = scope->FindVar(args.front());
+    param_.axes_tensor_vct = var->GetMutable<std::vector<lite::Tensor>>();
+  }
   CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null.";
   CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null.";
   return true;
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 696c278c62..02d40ce6cc 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -22,6 +22,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
     #lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     #lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})

     if(LITE_BUILD_EXTRA)
         lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
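
UnsqueezeOp::InferShape above hands the resolved final_axes to GetOutputShape, which this diff neither shows nor modifies. For context, here is a simplified standalone sketch of that shape rule, under the assumption that axes are applied one at a time and a negative axis counts from the back of the shape as it grows; UnsqueezeShape is a hypothetical name, not the Paddle-Lite function.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative shape rule for unsqueeze: insert a 1 at each requested axis.
// Axes are applied sequentially, so {0, -2} on [1, 3, 1, 5] first yields
// [1, 1, 3, 1, 5] and then [1, 1, 3, 1, 1, 5], matching the cases the
// updated tests below exercise.
std::vector<int64_t> UnsqueezeShape(const std::vector<int>& axes,
                                    std::vector<int64_t> dims) {
  for (int a : axes) {
    // A negative axis counts from the end of this step's output shape.
    int axis = a < 0 ? a + static_cast<int>(dims.size()) + 1 : a;
    assert(axis >= 0 && axis <= static_cast<int>(dims.size()));
    dims.insert(dims.begin() + axis, 1);
  }
  return dims;
}

int main() {
  auto out = UnsqueezeShape({0, -2}, {1, 3, 1, 5});
  for (auto d : out) std::printf("%lld ", static_cast<long long>(d));
  std::printf("\n");  // prints: 1 1 3 1 1 5
  return 0;
}
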
diff --git a/lite/tests/kernels/concat_compute_test.cc b/lite/tests/kernels/concat_compute_test.cc
new file mode 100644
index 0000000000..e0ae4c2828
--- /dev/null
+++ b/lite/tests/kernels/concat_compute_test.cc
@@ -0,0 +1,177 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+
+namespace paddle {
+namespace lite {
+
+DDim infer_shape(const std::vector<const Tensor*>& inputs, int in_axis) {
+  std::vector<DDim> input_dims;
+  for (auto* tensor : inputs) {
+    input_dims.push_back(tensor->dims());
+  }
+  size_t axis = static_cast<size_t>(in_axis);
+
+  DDim out_dims = input_dims[0];
+  for (size_t i = 1; i < input_dims.size(); i++) {
+    for (size_t j = 0; j < input_dims[0].size(); j++) {
+      if (j == axis) {
+        out_dims[axis] += input_dims[i][j];
+      } else {
+        if (out_dims[j] != input_dims[i][j]) {
+          LOG(FATAL) << "infer shape error.";
+        }
+      }
+    }
+  }
+  if (out_dims[axis] < 0) {
+    out_dims[axis] = -1;
+  }
+
+  return out_dims;
+}
+
+class ConcateComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::vector<std::string> x_vct_{};
+  std::string out_ = "out";
+  std::string axis_tensor_ = "axis_tensor";
+  int axis_ = 0;
+  bool is_use_axis_tensor_ = false;
+
+  int x_num_ = 3;
+  DDim x_dims_{{2, 3, 4, 5}};
+
+ public:
+  ConcateComputeTester(const Place& place,
+                       const std::string& alias,
+                       int axis,
+                       bool is_use_axis_tensor)
+      : TestCase(place, alias) {
+    axis_ = axis;
+    is_use_axis_tensor_ = is_use_axis_tensor;
+  }
+
+  void RunBaseline(Scope* scope) override {
+    std::vector<const Tensor*> x_vct;
+    for (std::string& name : x_vct_) {
+      x_vct.push_back(scope->FindTensor(name));
+    }
+
+    auto* out = scope->NewTensor(out_);
+    DDim output_dims = infer_shape(x_vct, axis_);
+    out->Resize(output_dims);
+    auto* output_data = out->mutable_data<float>();
+
+    int num = x_vct.size();
+    int rows = 1;
+    auto dim_0 = x_vct[0]->dims();
+    for (int i = 0; i < axis_; ++i) {
+      rows *= dim_0[i];
+    }
+    int out_rows = rows, out_cols = 0;
+
+    std::vector<int> input_cols(x_vct.size());
+    for (int i = 0; i < num; ++i) {
+      int input_i_numel = x_vct[i]->dims().size() == 0 ? 0 : 1;
+      for (int didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
+        input_i_numel *= x_vct[i]->dims()[didx];
+      }
+      int t_cols = input_i_numel / rows;
+      out_cols += t_cols;
+      input_cols[i] = t_cols;
+    }
+
+    // computation
+    int col_idx = 0;
+    for (int j = 0; j < num; ++j) {
+      int col_len = input_cols[j];
+      auto input_data = x_vct[j]->data<float>();
+      for (int k = 0; k < out_rows; ++k) {
+        memcpy(output_data + k * out_cols + col_idx,
+               input_data + k * col_len,
+               sizeof(float) * col_len);
+      }
+      col_idx += col_len;
+    }
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType("concat");
+    op_desc->SetInput("X", x_vct_);
+    op_desc->SetAttr("axis", axis_);
+    if (is_use_axis_tensor_) {
+      op_desc->SetInput("AxisTensor", {axis_tensor_});
+    }
+    op_desc->SetOutput("Out", {out_});
+  }
+
+  void PrepareData() override {
+    for (int n = 0; n < x_num_; n++) {
+      std::vector<float> x_data(x_dims_.production());
+      for (int i = 0; i < x_dims_.production(); i++) {
+        x_data[i] = static_cast<float>(i + n);
+      }
+      const std::string x_name = "x_tensor_" + std::to_string(n);
+      x_vct_.push_back(x_name);
+      SetCommonTensor(x_name, x_dims_, x_data.data());
+    }
+
+    if (is_use_axis_tensor_) {
+      SetCommonTensor(axis_tensor_, DDim({1}), &axis_);
+      LOG(INFO) << "set axis tensor";
+    }
+  }
+};
+
+TEST(Concat, precision) {
+  LOG(INFO) << "test concat op, kARM";
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+  for (int axis : {1, 2}) {
+    for (bool is_use_axis_tensor : {false, true}) {
+      LOG(INFO) << "axis:" << axis
+                << ", is_use_axis_tensor:" << is_use_axis_tensor;
+      std::unique_ptr<arena::TestCase> tester(
+          new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
+      arena::Arena arena(std::move(tester), place, 2e-5);
+      arena.TestPrecision();
+    }
+  }
+#endif
+
+#ifdef LITE_WITH_X86
+  Place place(TARGET(kX86));
+  LOG(INFO) << "test concate op, x86";
+  for (int axis : {1, 2}) {
+    for (bool is_use_axis_tensor : {false, true}) {
+      LOG(INFO) << "axis:" << axis
+                << ", is_use_axis_tensor:" << is_use_axis_tensor;
+      std::unique_ptr<arena::TestCase> tester(
+          new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
+      arena::Arena arena(std::move(tester), place, 2e-5);
+      arena.TestPrecision();
+    }
+  }
+
+#endif
+}
+
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc
index f6f35c615e..9bbf39b70d 100644
--- a/lite/tests/kernels/unsqueeze_compute_test.cc
+++ b/lite/tests/kernels/unsqueeze_compute_test.cc
@@ -13,10 +13,10 @@
 // limitations under the License.

 #include <gtest/gtest.h>
+#include <string>
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
-
 namespace paddle {
 namespace lite {

@@ -25,15 +25,24 @@ class UnsqueezeComputeTester : public arena::TestCase {
   // common attributes for this op.
   std::string x_ = "X";
   std::string out_ = "Out";
+  std::string axes_tensor_ = "AxesTensor";
+  std::vector<std::string> axes_tensor_list_;
   std::vector<int> axes_;
   DDim dims_;
+  // input_axes_flag_: 1 for axes, 2 for axes_tensor, 3 for axes_tensor_list
+  int input_axes_flag_ = 1;

  public:
   UnsqueezeComputeTester(const Place& place,
                          const std::string& alias,
                          const std::vector<int>& axes,
-                         DDim dims)
-      : TestCase(place, alias), axes_(axes), dims_(dims) {}
+                         DDim dims,
+                         int input_axes_flag)
+      : TestCase(place, alias), dims_(dims), input_axes_flag_(input_axes_flag) {
+    for (int v : axes) {
+      axes_.push_back(v);
+    }
+  }

   void RunBaseline(Scope* scope) override {
     const auto* input = scope->FindTensor(x_);
@@ -86,7 +95,15 @@ class UnsqueezeComputeTester : public arena::TestCase {
     op_desc->SetType("unsqueeze");
     op_desc->SetInput("X", {x_});
     op_desc->SetOutput("Out", {out_});
-    op_desc->SetAttr("axes", axes_);
+    if (input_axes_flag_ == 1) {
+      op_desc->SetAttr("axes", axes_);
+    } else if (input_axes_flag_ == 2) {
+      op_desc->SetInput("AxesTensor", {axes_tensor_});
+    } else if (input_axes_flag_ == 3) {
+      op_desc->SetInput("AxesTensorList", axes_tensor_list_);
+    } else {
+      LOG(FATAL) << "input input_axes_flag_ error. " << input_axes_flag_;
+    }
   }

   void PrepareData() override {
@@ -95,6 +112,23 @@ class UnsqueezeComputeTester : public arena::TestCase {
       in_data[i] = i;
     }
     SetCommonTensor(x_, dims_, in_data.data());
+
+    if (input_axes_flag_ == 2) {
+      DDim axes_tensor_dim{{static_cast<int>(axes_.size())}};
+      std::vector<int> axes_tensor_data(axes_.size());
+      for (int i = 0; i < axes_tensor_dim.production(); i++) {
+        axes_tensor_data[i] = axes_[i];
+      }
+      SetCommonTensor(axes_tensor_, axes_tensor_dim, axes_tensor_data.data());
+    } else if (input_axes_flag_ == 3) {
+      std::string name = "axes_tensor_";
+      for (size_t i = 0; i < axes_.size(); i++) {
+        name = name + std::to_string(i);
+        axes_tensor_list_.push_back(name);
+        std::vector<int> in_data = {axes_[i]};
+        SetCommonTensor(name, DDim({1}), in_data.data());
+      }
+    }
   }
 };

@@ -189,17 +223,22 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
 };

 void test_unsqueeze(Place place) {
-  for (std::vector<int> axes : {std::vector<int>({}),
+  for (std::vector<int> axes : {std::vector<int>({1}),
                                 std::vector<int>({0, 2}),
                                 std::vector<int>({0, -2})}) {
     for (int N : {1}) {
       for (int C : {3}) {
         for (int H : {1}) {
           for (int W : {5}) {
-            std::unique_ptr<arena::TestCase> tester(new UnsqueezeComputeTester(
-                place, "def", axes, DDim({N, C, H, W})));
-            arena::Arena arena(std::move(tester), place, 2e-5);
-            arena.TestPrecision();
+            for (int input_axes_flag : {1, 2}) {
+              LOG(INFO) << N << " " << C << " " << H << " " << W << " "
+                        << input_axes_flag;
+              std::unique_ptr<arena::TestCase> tester(
+                  new UnsqueezeComputeTester(
+                      place, "def", axes, DDim({N, C, H, W}), input_axes_flag));
+              arena::Arena arena(std::move(tester), place, 2e-5);
+              arena.TestPrecision();
+            }
           }
         }
       }
@@ -208,7 +247,7 @@ void test_unsqueeze(Place place) {
 }

 void test_unsqueeze2(Place place) {
-  for (std::vector<int> axes : {std::vector<int>({}),
+  for (std::vector<int> axes : {std::vector<int>({0}),
                                 std::vector<int>({0, 2}),
                                 std::vector<int>({0, -2})}) {
     for (int N : {1}) {
-- 
GitLab