diff --git a/lite/kernels/arm/split_compute.cc b/lite/kernels/arm/split_compute.cc index 27606e2d76dfd13161fffc3f53d614155f62254e..2a0c52e7fc44cdd7c36ac3e8f93b33731f03bd77 100644 --- a/lite/kernels/arm/split_compute.cc +++ b/lite/kernels/arm/split_compute.cc @@ -42,5 +42,9 @@ void SplitCompute::Run() { REGISTER_LITE_KERNEL( split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("AxisTensor", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + .BindInput("SectionsTensorList", + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 0e9f85e06096b64b8f041cc7fbb58db240ecd78e..d47543961529b0147768ca11f2df70f8b3b66526 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -344,6 +344,9 @@ struct DropoutParam { struct SplitParam { lite::Tensor* x{}; std::vector output{}; + lite::Tensor* axis_tensor; + std::vector sections_tensor_list{}; + int axis{-1}; int num{0}; std::vector sections; diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc index 18280616aa00b734596b620727f6dcfd5beb67d7..ec98a0d6c3ba3b1e5cd1c7992b58e96917d21057 100644 --- a/lite/operators/split_op.cc +++ b/lite/operators/split_op.cc @@ -39,8 +39,16 @@ bool SplitOp::InferShape() const { const int outs_number = outs.size(); std::vector outs_dims; outs_dims.reserve(outs_number); - - if (num > 0) { + std::vector sections_tensor_list_ = + param_.sections_tensor_list; + if (sections.size() > 0 && sections_tensor_list_.size() > 0) { + std::vector vec_sections; + for (size_t i = 0; i < sections_tensor_list_.size(); ++i) { + auto dim = in_dims; + dim[axis] = sections_tensor_list_[i]->data()[0]; + outs_dims.push_back(dim); + } + } else if (num > 0) { int out_axis_dim = in_dims[axis] / num; for (int i = 0; i < outs_number; ++i) { auto dim = in_dims; @@ -55,6 +63,10 @@ bool SplitOp::InferShape() const { } } + if (param_.axis_tensor != nullptr) { + axis = param_.axis_tensor->data()[0]; + } + for (int j = 0; j < outs_dims.size(); ++j) { outs[j]->Resize(outs_dims[j]); } @@ -73,6 +85,21 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { for (auto var : outs) { param_.output.push_back(scope->FindVar(var)->GetMutable()); } + std::vector input_arg_names = opdesc.InputArgumentNames(); + if (std::find(input_arg_names.begin(), input_arg_names.end(), "AxisTensor") != + input_arg_names.end()) { + auto args = opdesc.Input("AxisTensor"); + auto *var = scope->FindVar(args.front()); + param_.axis_tensor = var->GetMutable(); + } + if (std::find(input_arg_names.begin(), + input_arg_names.end(), + "SectionsTensorList") != input_arg_names.end()) { + auto args = opdesc.Input("SectionsTensorList"); + auto *var = scope->FindVar(args.front()); + param_.sections_tensor_list = + *(var->GetMutable>()); + } return true; }