提交 b7cf0984 编写于 作者: L lijianshe02 提交者: GitHub

add concat op, kernel and unitest test=develop (#17823)

* add concat op, kernel and unitest test=develop
上级 14764fa4
......@@ -129,6 +129,7 @@ USE_LITE_OP(elementwise_sub)
USE_LITE_OP(square)
USE_LITE_OP(softmax)
USE_LITE_OP(dropout)
USE_LITE_OP(concat)
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
......@@ -142,6 +143,7 @@ USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(dropout, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kX86, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_CUDA
......
......@@ -14,6 +14,7 @@ cc_library(scale_compute_x86 SRCS scale_compute.cc DEPS ${lite_kernel_deps})
cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
cc_library(dropout_compute_x86 SRCS dropout_compute.cc DEPS ${lite_kernel_deps} )
cc_library(concat_compute_x86 SRCS concat_compute.cc DEPS ${lite_kernel_deps} )
set(x86_kernels
activation_compute_x86
......@@ -26,6 +27,7 @@ set(x86_kernels
scale_compute_x86
softmax_compute_x86
dropout_compute_x86
concat_compute_x86
)
set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <Eigen/Core>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::ConcatParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
int64_t axis = static_cast<int64_t>(param.axis);
auto out = param.output;
if (axis == 0 && param.x.size() < 10) {
size_t output_offset = 0;
for (auto* in : param.x) {
if (!in || in->dims().production() == 0UL) {
continue;
}
auto in_stride = framework::stride_numel(in->dims().data());
auto out_stride = framework::stride_numel(out->dims().data());
paddle::operators::StridedNumelCopyWithAxis<T>(
platform::CPUDeviceContext(), axis,
out->mutable_data<T>() + output_offset, out_stride, in->data<T>(),
in_stride, in_stride[axis]);
output_offset += in_stride[axis];
}
} else {
std::vector<lite::Tensor> inputs;
for (size_t j = 0; j < param.x.size(); ++j) {
if (param.x[j] && param.x[j]->dims().production() > 0) {
inputs.push_back(*param.x[j]);
} else {
continue;
}
}
int num = inputs.size();
int rows = 1;
auto dim_0 = inputs[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(inputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = inputs[i].dims().production() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
auto output_data = param.output->template mutable_data<T>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = inputs[j].data<float>();
for (int k = 0; k < out_rows; ++k) {
std::memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len);
}
col_idx += col_len;
}
}
}
virtual ~ConcatCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(concat, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ConcatCompute<float>, def)
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -16,6 +16,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
#cc_library(sgd_op_lite SRCS sgd_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
set(ops_lite
fc_op_lite
......@@ -32,6 +33,7 @@ set(ops_lite
fill_constant_op_lite
activation_ops_lite
dropout_op_lite
concat_op_lite
PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
......@@ -41,3 +43,4 @@ lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
lite_cc_test(test_scale_op_lite SRCS scale_op_test.cc DEPS scale_op_lite memory_lite)
lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite memory_lite)
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/concat_op.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ConcatOpLite::CheckShape() const {
CHECK_GT_OR_FALSE(param_.x.size(), 1UL);
CHECK_OR_FALSE(param_.output);
return true;
}
bool ConcatOpLite::InferShape() const {
std::vector<framework::DDim> input_dims;
for (auto p : param_.x) {
input_dims.push_back(p->dims().data());
}
size_t axis = static_cast<size_t>(param_.axis);
const size_t n = input_dims.size();
CHECK_GT_OR_FALSE(n, 0);
auto &out_dims = input_dims[0];
size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) {
out_dims[axis] += input_dims[i][j];
} else {
CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
}
}
}
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
// Set output dims
param_.output->Resize(lite::DDim(out_dims));
return true;
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool ConcatOpLite::AttachImpl(const OpDesc &op_desc, lite::Scope *scope) {
auto inputs = op_desc.Input("X");
auto out = op_desc.Output("Out").front();
for (auto var : inputs) {
param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.axis = GetAttr<int>(op_desc.GetAttr("axis"));
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(concat, paddle::lite::operators::ConcatOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class ConcatOpLite : public OpLite {
public:
ConcatOpLite() {}
explicit ConcatOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "concat"; }
private:
mutable ConcatParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/concat_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(concat_op_lite, test) {
// prepare variables
lite::Scope scope;
auto* x0 = scope.Var("x0")->GetMutable<lite::Tensor>();
auto* x1 = scope.Var("x1")->GetMutable<lite::Tensor>();
auto* output = scope.Var("output")->GetMutable<lite::Tensor>();
x0->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
x1->Resize(lite::DDim(std::vector<int64_t>({10, 20})));
output->Resize(lite::DDim(std::vector<int64_t>{20, 20}));
// set data
for (int i = 0; i < 10 * 20; i++) {
x0->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
x1->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
output->mutable_data<float>()[i] = 0.;
}
// prepare op desc
lite::OpDesc desc;
desc.SetType("concat");
desc.SetInput("X", {"x0", "x1"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("axis", static_cast<int>(0));
ConcatOpLite concat("concat");
concat.SetValidPlaces({Place{TARGET(kX86), PRECISION(kFloat)}});
concat.Attach(desc, &scope);
}
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -113,6 +113,13 @@ struct ReshapeParam {
bool inplace{false};
};
// For Concat op
struct ConcatParam {
std::vector<lite::Tensor*> x{};
lite::Tensor* output{};
int axis{0};
};
// For Convolution op
struct ConvParam {
lite::Tensor* x{};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册