未验证 提交 77a674fc 编写于 作者: 周周周 提交者: GitHub

commit (#54189)

上级 934d25e7
...@@ -68,6 +68,12 @@ class SliceOpConverter : public OpConverter { ...@@ -68,6 +68,12 @@ class SliceOpConverter : public OpConverter {
starts_tensor[axes[i]] = GetEleTensorOfShape( starts_tensor[axes[i]] = GetEleTensorOfShape(
engine_->GetITensor(op_desc.Input("StartsTensor")[0]), i); engine_->GetITensor(op_desc.Input("StartsTensor")[0]), i);
} }
} else if (slice_inputs.find("StartsTensorList") != slice_inputs.end() &&
op_desc.Input("StartsTensorList").size()) {
for (size_t i = 0; i < axes.size(); ++i) {
starts_tensor[axes[i]] =
engine_->GetITensor(op_desc.Input("StartsTensorList")[i]);
}
} else { } else {
PADDLE_ENFORCE_EQ(starts.size(), PADDLE_ENFORCE_EQ(starts.size(),
axes.size(), axes.size(),
...@@ -97,6 +103,12 @@ class SliceOpConverter : public OpConverter { ...@@ -97,6 +103,12 @@ class SliceOpConverter : public OpConverter {
ends_tensor[axes[i]] = GetEleTensorOfShape( ends_tensor[axes[i]] = GetEleTensorOfShape(
engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i); engine_->GetITensor(op_desc.Input("EndsTensor")[0]), i);
} }
} else if (slice_inputs.find("EndsTensorList") != slice_inputs.end() &&
op_desc.Input("EndsTensorList").size()) {
for (size_t i = 0; i < axes.size(); ++i) {
ends_tensor[axes[i]] =
engine_->GetITensor(op_desc.Input("EndsTensorList")[i]);
}
} else { } else {
PADDLE_ENFORCE_EQ(ends.size(), PADDLE_ENFORCE_EQ(ends.size(),
axes.size(), axes.size(),
......
...@@ -1381,14 +1381,10 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1381,14 +1381,10 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
} }
if (slice_inputs.find("StartsTensorList") != slice_inputs.end()) { if (slice_inputs.find("StartsTensorList") != slice_inputs.end()) {
if (desc.Input("StartsTensorList").size()) { VLOG(3) << "The Slice has StartsTensorList input.";
return false;
}
} }
if (slice_inputs.find("EndsTensorList") != slice_inputs.end()) { if (slice_inputs.find("EndsTensorList") != slice_inputs.end()) {
if (desc.Input("EndsTensorList").size()) { VLOG(3) << "The Slice has EndsTensorList input.";
return false;
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册