提交 3c83e6f3 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

split op support param can be tensor or tensorlist,test=develop (#2474)

* split op upgrade
上级 d1a3c141
......@@ -42,5 +42,9 @@ void SplitCompute::Run() {
REGISTER_LITE_KERNEL(
split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -344,6 +344,9 @@ struct DropoutParam {
struct SplitParam {
lite::Tensor* x{};
std::vector<lite::Tensor*> output{};
lite::Tensor* axis_tensor;
std::vector<lite::Tensor*> sections_tensor_list{};
int axis{-1};
int num{0};
std::vector<int> sections;
......
......@@ -39,8 +39,16 @@ bool SplitOp::InferShape() const {
const int outs_number = outs.size();
std::vector<lite::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
std::vector<lite::Tensor *> sections_tensor_list_ =
param_.sections_tensor_list;
if (sections.size() > 0 && sections_tensor_list_.size() > 0) {
std::vector<int> vec_sections;
for (size_t i = 0; i < sections_tensor_list_.size(); ++i) {
auto dim = in_dims;
dim[axis] = sections_tensor_list_[i]->data<int>()[0];
outs_dims.push_back(dim);
}
} else if (num > 0) {
int out_axis_dim = in_dims[axis] / num;
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
......@@ -55,6 +63,10 @@ bool SplitOp::InferShape() const {
}
}
if (param_.axis_tensor != nullptr) {
axis = param_.axis_tensor->data<int>()[0];
}
for (int j = 0; j < outs_dims.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
......@@ -73,6 +85,21 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
std::vector<std::string> input_arg_names = opdesc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "AxisTensor") !=
input_arg_names.end()) {
auto args = opdesc.Input("AxisTensor");
auto *var = scope->FindVar(args.front());
param_.axis_tensor = var->GetMutable<lite::Tensor>();
}
if (std::find(input_arg_names.begin(),
input_arg_names.end(),
"SectionsTensorList") != input_arg_names.end()) {
auto args = opdesc.Input("SectionsTensorList");
auto *var = scope->FindVar(args.front());
param_.sections_tensor_list =
*(var->GetMutable<std::vector<lite::Tensor *>>());
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册