From 5868a6cac9ef448a4207b4c22714f4ef98f61214 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Tue, 4 Jun 2019 17:09:40 +0800 Subject: [PATCH] add concat op, kernel and unitest test=develop (#17823) * add concat op, kernel and unitest test=develop --- paddle/fluid/lite/api/cxx_api_test.cc | 2 + paddle/fluid/lite/kernels/x86/CMakeLists.txt | 2 + .../fluid/lite/kernels/x86/concat_compute.cc | 102 ++++++++++++++++++ paddle/fluid/lite/operators/CMakeLists.txt | 3 + paddle/fluid/lite/operators/concat_op.cc | 75 +++++++++++++ paddle/fluid/lite/operators/concat_op.h | 46 ++++++++ paddle/fluid/lite/operators/concat_op_test.cc | 59 ++++++++++ paddle/fluid/lite/operators/op_params.h | 7 ++ 8 files changed, 296 insertions(+) create mode 100644 paddle/fluid/lite/kernels/x86/concat_compute.cc create mode 100644 paddle/fluid/lite/operators/concat_op.cc create mode 100644 paddle/fluid/lite/operators/concat_op.h create mode 100644 paddle/fluid/lite/operators/concat_op_test.cc diff --git a/paddle/fluid/lite/api/cxx_api_test.cc b/paddle/fluid/lite/api/cxx_api_test.cc index dbc364e1dde..05630384044 100644 --- a/paddle/fluid/lite/api/cxx_api_test.cc +++ b/paddle/fluid/lite/api/cxx_api_test.cc @@ -129,6 +129,7 @@ USE_LITE_OP(elementwise_sub) USE_LITE_OP(square) USE_LITE_OP(softmax) USE_LITE_OP(dropout) +USE_LITE_OP(concat) USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); @@ -142,6 +143,7 @@ USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def); #endif #ifdef LITE_WITH_CUDA diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt b/paddle/fluid/lite/kernels/x86/CMakeLists.txt index aa9896b0c26..1c2937df5be 100644 --- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt @@ -14,6 +14,7 @@ cc_library(scale_compute_x86 SRCS scale_compute.cc DEPS ${lite_kernel_deps}) cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op) cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} ) +cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} ) set(x86_kernels activation_compute_x86 @@ -26,6 +27,7 @@ set(x86_kernels scale_compute_x86 softmax_compute_x86 dropout_compute_x86 + concat_compute_x86 ) set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") diff --git a/paddle/fluid/lite/kernels/x86/concat_compute.cc b/paddle/fluid/lite/kernels/x86/concat_compute.cc new file mode 100644 index 00000000000..23ae8ca5055 --- /dev/null +++ b/paddle/fluid/lite/kernels/x86/concat_compute.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/lite/core/kernel.h" +#include "paddle/fluid/lite/core/op_registry.h" +#include "paddle/fluid/lite/core/types.h" +#include "paddle/fluid/operators/strided_memcpy.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +class ConcatCompute : public KernelLite { + public: + using param_t = operators::ConcatParam; + + void Run() override { + auto& param = *param_.get_mutable(); + int64_t axis = static_cast(param.axis); + auto out = param.output; + + if (axis == 0 && param.x.size() < 10) { + size_t output_offset = 0; + for (auto* in : param.x) { + if (!in || in->dims().production() == 0UL) { + continue; + } + auto in_stride = framework::stride_numel(in->dims().data()); + auto out_stride = framework::stride_numel(out->dims().data()); + paddle::operators::StridedNumelCopyWithAxis( + platform::CPUDeviceContext(), axis, + out->mutable_data() + output_offset, out_stride, in->data(), + in_stride, in_stride[axis]); + + output_offset += in_stride[axis]; + } + } else { + std::vector inputs; + for (size_t j = 0; j < param.x.size(); ++j) { + if (param.x[j] && param.x[j]->dims().production() > 0) { + inputs.push_back(*param.x[j]); + } else { + continue; + } + } + + int num = inputs.size(); + int rows = 1; + auto dim_0 = inputs[0].dims(); + for (int i = 0; i < axis; ++i) { + rows *= dim_0[i]; + } + int out_rows = rows, out_cols = 0; + + std::vector input_cols(inputs.size()); + for (int i = 0; i < num; ++i) { + int t_cols = inputs[i].dims().production() / rows; + out_cols += t_cols; + input_cols[i] = t_cols; + } + // computation + auto output_data = param.output->template mutable_data(); + int col_idx = 0; + for (int j = 0; j < num; ++j) { + int col_len = input_cols[j]; + auto input_data = inputs[j].data(); + for (int k = 0; k < out_rows; ++k) { + std::memcpy(output_data + k * out_cols + col_idx, + input_data + k * col_len, sizeof(T) * col_len); + } + col_idx += col_len; + } + } + } + + virtual ~ConcatCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW, + paddle::lite::kernels::x86::ConcatCompute, def) + .BindInput("X", {LiteType::GetTensorListTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index 6ed8a410c37..8c855ed4610 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -16,6 +16,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) #cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS}) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) +cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) set(ops_lite fc_op_lite @@ -32,6 +33,7 @@ set(ops_lite fill_constant_op_lite activation_ops_lite dropout_op_lite + concat_op_lite PARENT_SCOPE) lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc @@ -41,3 +43,4 @@ lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite) lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) +lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite) diff --git a/paddle/fluid/lite/operators/concat_op.cc b/paddle/fluid/lite/operators/concat_op.cc new file mode 100644 index 00000000000..e8fd910f9d0 --- /dev/null +++ b/paddle/fluid/lite/operators/concat_op.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/concat_op.h" +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool ConcatOpLite::CheckShape() const { + CHECK_GT_OR_FALSE(param_.x.size(), 1UL); + CHECK_OR_FALSE(param_.output); + return true; +} + +bool ConcatOpLite::InferShape() const { + std::vector input_dims; + for (auto p : param_.x) { + input_dims.push_back(p->dims().data()); + } + size_t axis = static_cast(param_.axis); + const size_t n = input_dims.size(); + CHECK_GT_OR_FALSE(n, 0); + auto &out_dims = input_dims[0]; + size_t in_zero_dims_size = out_dims.size(); + for (size_t i = 1; i < n; i++) { + for (size_t j = 0; j < in_zero_dims_size; j++) { + if (j == axis) { + out_dims[axis] += input_dims[i][j]; + } else { + CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]); + } + } + } + if (out_dims[axis] < 0) { + out_dims[axis] = -1; + } + // Set output dims + param_.output->Resize(lite::DDim(out_dims)); + return true; +} + +// TODO(Superjomn) replace framework::OpDesc with a lite one. +bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) { + auto inputs = op_desc.Input("X"); + auto out = op_desc.Output("Out").front(); + + for (auto var : inputs) { + param_.x.push_back(scope->FindVar(var)->GetMutable()); + } + CHECK(scope->FindVar(out)); + param_.output = scope->FindVar(out)->GetMutable(); + param_.axis = GetAttr(op_desc.GetAttr("axis")); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(concat, paddle::lite::operators::ConcatOpLite); diff --git a/paddle/fluid/lite/operators/concat_op.h b/paddle/fluid/lite/operators/concat_op.h new file mode 100644 index 00000000000..86f58be45f3 --- /dev/null +++ b/paddle/fluid/lite/operators/concat_op.h @@ -0,0 +1,46 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "paddle/fluid/lite/core/op_lite.h" +#include "paddle/fluid/lite/core/scope.h" +#include "paddle/fluid/lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class ConcatOpLite : public OpLite { + public: + ConcatOpLite() {} + explicit ConcatOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "concat"; } + + private: + mutable ConcatParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/concat_op_test.cc b/paddle/fluid/lite/operators/concat_op_test.cc new file mode 100644 index 00000000000..d5a412893ee --- /dev/null +++ b/paddle/fluid/lite/operators/concat_op_test.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/operators/concat_op.h" +#include +#include "paddle/fluid/lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +TEST(concat_op_lite, test) { + // prepare variables + lite::Scope scope; + auto* x0 = scope.Var("x0")->GetMutable(); + auto* x1 = scope.Var("x1")->GetMutable(); + auto* output = scope.Var("output")->GetMutable(); + x0->Resize(lite::DDim(std::vector({10, 20}))); + x1->Resize(lite::DDim(std::vector({10, 20}))); + output->Resize(lite::DDim(std::vector{20, 20})); + + // set data + for (int i = 0; i < 10 * 20; i++) { + x0->mutable_data()[i] = i; + } + for (int i = 0; i < 10 * 20; i++) { + x1->mutable_data()[i] = i; + } + for (int i = 0; i < 10 * 20; i++) { + output->mutable_data()[i] = 0.; + } + + // prepare op desc + lite::OpDesc desc; + desc.SetType("concat"); + desc.SetInput("X", {"x0", "x1"}); + desc.SetOutput("Out", {"output"}); + desc.SetAttr("axis", static_cast(0)); + + ConcatOpLite concat("concat"); + + concat.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}}); + concat.Attach(desc, &scope); +} + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/operators/op_params.h b/paddle/fluid/lite/operators/op_params.h index 166a5ad868e..23b21cb2764 100644 --- a/paddle/fluid/lite/operators/op_params.h +++ b/paddle/fluid/lite/operators/op_params.h @@ -113,6 +113,13 @@ struct ReshapeParam { bool inplace{false}; }; +// For Concat op +struct ConcatParam { + std::vector x{}; + lite::Tensor* output{}; + int axis{0}; +}; + // For Convolution op struct ConvParam { lite::Tensor* x{}; -- GitLab