未验证 提交 77a674fc 编写于 作者: 周周周 提交者: GitHub

commit (#54189)

上级 934d25e7
......@@ -68,6 +68,12 @@ class SliceOpConverter : public OpConverter {
starts_tensor[axes[i]] = GetEleTensorOfShape(
engine_->GetITensor(op_desc.Input("StartsTensor")[0]), i);
}
} else if (slice_inputs.find("StartsTensorList") != slice_inputs.end() &&
op_desc.Input("StartsTensorList").size()) {
for (size_t i = 0; i < axes.size(); ++i) {
starts_tensor[axes[i]] =
engine_->GetITensor(op_desc.Input("StartsTensorList")[i]);
}
} else {
PADDLE_ENFORCE_EQ(starts.size(),
axes.size(),
......@@ -97,6 +103,12 @@ class SliceOpConverter : public OpConverter {
ends_tensor[axes[i]] = GetEleTensorOfShape(
engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i);
}
} else if (slice_inputs.find("EndsTensorList") != slice_inputs.end() &&
op_desc.Input("EndsTensorList").size()) {
for (size_t i = 0; i < axes.size(); ++i) {
ends_tensor[axes[i]] =
engine_->GetITensor(op_desc.Input("EndsTensorList")[i]);
}
} else {
PADDLE_ENFORCE_EQ(ends.size(),
axes.size(),
......
......@@ -1381,14 +1381,10 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if (slice_inputs.find("StartsTensorList") != slice_inputs.end()) {
if (desc.Input("StartsTensorList").size()) {
return false;
}
VLOG(3) << "The Slice has StartsTensorList input.";
}
if (slice_inputs.find("EndsTensorList") != slice_inputs.end()) {
if (desc.Input("EndsTensorList").size()) {
return false;
}
VLOG(3) << "The Slice has EndsTensorList input.";
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册