未验证 提交 03803f20 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] support list of tensor input (#31801)

* support list of tensor as npu input

* add comment

* fix typo

* fix typo
上级 63505282
......@@ -185,6 +185,21 @@ NpuOpRunner &NpuOpRunner::AddInputs(const std::vector<Tensor> &tensors) {
return *this;
}
// NOTE(zhiqiu): For operators whose input is a list (such as concat, stack),
// It is needed to set the name of each input tensor.
NpuOpRunner &NpuOpRunner::AddInputNames(const std::vector<std::string> &names) {
PADDLE_ENFORCE_EQ(names.size(), input_descs_.size(),
platform::errors::InvalidArgument(
"The size of input names should be "
"equal to the size of input descs, but got the size "
"of input names is %d, the size of input descs is %d.",
names.size(), input_descs_.size()));
for (size_t i = 0; i < names.size(); ++i) {
aclSetTensorDescName(input_descs_[i], names[i].c_str());
}
return *this;
}
NpuOpRunner &NpuOpRunner::AddOutputs(const std::vector<Tensor> &tensors) {
for (auto tensor : tensors) {
// create aclTensorDesc
......
......@@ -53,6 +53,8 @@ class NpuOpRunner {
NpuOpRunner &AddInputs(const std::vector<Tensor> &tensors);
NpuOpRunner &AddInputNames(const std::vector<std::string> &names);
NpuOpRunner &AddOutputs(const std::vector<Tensor> &tensors);
aclTensorDesc *GetInputDesc(size_t index);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册