提交 37b2d22c 编写于 作者: myq406450149's avatar myq406450149 提交者: myq406450149

split op upgrade

上级 34e73be0
......@@ -42,5 +42,9 @@ void SplitCompute::Run() {
REGISTER_LITE_KERNEL(
split, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::SplitCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("SectionsTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -316,6 +316,9 @@ struct DropoutParam {
struct SplitParam {
lite::Tensor* x{};
std::vector<lite::Tensor*> output{};
lite::Tensor* axis_tensor;
std::vector<lite::Tensor>* sections_tensor_list{};
int axis{-1};
int num{0};
std::vector<int> sections;
......@@ -377,6 +380,7 @@ struct MeanGradParam {
struct FillConstantParam {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
std::vector<int64_t> shape{};
float value{0.0f};
// useless for x86, keep it for compatibility
bool force_cpu{false};
......
......@@ -39,8 +39,18 @@ bool SplitOp::InferShape() const {
const int outs_number = outs.size();
std::vector<lite::DDim> outs_dims;
outs_dims.reserve(outs_number);
if (num > 0) {
std::vector<lite::Tensor> *sections_tensor_list_ =
param_.sections_tensor_list;
if (sections.size() > 0 && sections_tensor_list_->size() > 0) {
std::vector<int> vec_sections;
for (size_t i = 0; i < sections_tensor_list_->size(); ++i) {
auto dim = in_dims;
// lite::TensorLite aa = sections_tensor_list_[i];
dim[axis] = (*sections_tensor_list_)[i].data<int>()[0];
// final_axes.push_back(axes_tensor_vct[i].data<int>()[0]);
outs_dims.push_back(dim);
}
} else if (num > 0) {
int out_axis_dim = in_dims[axis] / num;
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
......@@ -55,6 +65,10 @@ bool SplitOp::InferShape() const {
}
}
if (param_.axis_tensor != nullptr) {
axis = param_.axis_tensor->data<int>()[0];
}
for (int j = 0; j < outs_dims.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
......@@ -73,6 +87,16 @@ bool SplitOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
for (auto var : outs) {
param_.output.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
if (opdesc.HasAttr("AxisTensor")) {
auto args = opdesc.Input("AxisTensor");
auto *var = scope->FindVar(args.front());
param_.axis_tensor = var->GetMutable<lite::Tensor>();
}
if (opdesc.HasAttr("SectionsTensorList")) {
auto args = opdesc.Input("SectionsTensorList");
auto *var = scope->FindVar(args.front());
param_.sections_tensor_list = var->GetMutable<std::vector<lite::Tensor>>();
}
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册