diff --git a/lite/kernels/arm/split_compute.cc b/lite/kernels/arm/split_compute.cc
index 27606e2d76dfd13161fffc3f53d614155f62254e..2a0c52e7fc44cdd7c36ac3e8f93b33731f03bd77 100644
--- a/lite/kernels/arm/split_compute.cc
+++ b/lite/kernels/arm/split_compute.cc
@@ -42,5 +42,9 @@ void SplitCompute::Run() {
 REGISTER_LITE_KERNEL(
     split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("AxisTensor",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+    .BindInput("SectionsTensorList",
+               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 5ae22e6039bf55bb57f4e90a49b4eca835b879ea..2381f3f9265291bf7a436877c60997bd0d1f3498 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -316,6 +316,9 @@ struct DropoutParam {
 struct SplitParam {
   lite::Tensor* x{};
   std::vector<lite::Tensor*> output{};
+  lite::Tensor* axis_tensor{nullptr};
+  std::vector<lite::Tensor>* sections_tensor_list{};
+
   int axis{-1};
   int num{0};
   std::vector<int> sections;
@@ -377,6 +380,7 @@ struct MeanGradParam {
 struct FillConstantParam {
   int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
   std::vector<int64_t> shape{};
+  float value{0.0f};
   // useless for x86, keep it for compatibility
   bool force_cpu{false};
diff --git a/lite/operators/split_op.cc b/lite/operators/split_op.cc
index 18280616aa00b734596b620727f6dcfd5beb67d7..36ab7dff9a6026ad5b6423b1395f3d502c9bb4d6 100644
--- a/lite/operators/split_op.cc
+++ b/lite/operators/split_op.cc
@@ -39,8 +39,20 @@ bool SplitOp::InferShape() const {
   const int outs_number = outs.size();
   std::vector<lite::DDim> outs_dims;
   outs_dims.reserve(outs_number);
-
-  if (num > 0) {
+  // Read the runtime axis from AxisTensor before computing the output dims.
+  if (param_.axis_tensor != nullptr) {
+    axis = param_.axis_tensor->data<int>()[0];
+  }
+  std::vector<lite::Tensor> *sections_tensor_list_ =
+      param_.sections_tensor_list;
+  if (sections.size() > 0 && sections_tensor_list_ != nullptr &&
+      sections_tensor_list_->size() > 0) {
+    for (size_t i = 0; i < sections_tensor_list_->size(); ++i) {
+      auto dim = in_dims;
+      dim[axis] = (*sections_tensor_list_)[i].data<int>()[0];
+      outs_dims.push_back(dim);
+    }
+  } else if (num > 0) {
     int out_axis_dim = in_dims[axis] / num;
     for (int i = 0; i < outs_number; ++i) {
       auto dim = in_dims;
@@ -73,6 +85,17 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   for (auto var : outs) {
     param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
   }
+  // AxisTensor and SectionsTensorList are optional inputs, not attributes.
+  if (opdesc.HasInput("AxisTensor")) {
+    auto args = opdesc.Input("AxisTensor");
+    auto *var = scope->FindVar(args.front());
+    param_.axis_tensor = var->GetMutable<lite::Tensor>();
+  }
+  if (opdesc.HasInput("SectionsTensorList")) {
+    auto args = opdesc.Input("SectionsTensorList");
+    auto *var = scope->FindVar(args.front());
+    param_.sections_tensor_list = var->GetMutable<std::vector<lite::Tensor>>();
+  }
   return true;
 }
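
For reference, the shape-inference rule that the split_op.cc hunk implements can be illustrated with a small standalone sketch. This is not PaddleLite code: InferSplitDims is a hypothetical helper, and it assumes the split axis and the per-output section sizes have already been resolved (from the axis attribute or AxisTensor, and from the sections attribute or SectionsTensorList). When explicit sections are given, each output keeps the input shape except along the split axis; otherwise the split axis is divided into num equal parts.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the split shape inference: one output shape
// per part, changing only the dimension at `axis`.
std::vector<std::vector<int64_t>> InferSplitDims(
    const std::vector<int64_t> &in_dims,
    int axis,                                // resolved split axis
    int num,                                 // number of equal parts (0 if unused)
    const std::vector<int64_t> &sections) {  // explicit per-output sizes
  std::vector<std::vector<int64_t>> outs_dims;
  if (!sections.empty()) {
    // Explicit sections: one output per entry.
    for (int64_t s : sections) {
      auto dim = in_dims;
      dim[axis] = s;
      outs_dims.push_back(dim);
    }
  } else if (num > 0) {
    // Even split: divide the axis into `num` equal parts.
    const int64_t out_axis_dim = in_dims[axis] / num;
    for (int i = 0; i < num; ++i) {
      auto dim = in_dims;
      dim[axis] = out_axis_dim;
      outs_dims.push_back(dim);
    }
  }
  return outs_dims;
}

int main() {
  // A [6, 4] input split along axis 0 with sections {2, 4}
  // produces outputs of shape [2, 4] and [4, 4].
  for (const auto &dim : InferSplitDims({6, 4}, /*axis=*/0, /*num=*/0, {2, 4})) {
    std::cout << dim[0] << " x " << dim[1] << "\n";
  }
  return 0;
}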