提交 05088be6 编写于 作者: M Megvii Engine Team

feat(lite): add get_tensor_at method in TensorBatchCollector

GitOrigin-RevId: 923118803075d5cca93bd81a7f7d635b0772170d
上级 eb3c6473
......@@ -203,6 +203,21 @@ class TensorBatchCollector:
self._free_list.remove(i)
self._free_list.extend(indexes)
def get_tensor_at(self, idx):
"""
get the tensor from the internal big tensor by the idx, make sure the
idx is not freed, return the tensor
Args:
idx: the tensor index in the internal big tensor
"""
assert idx < self.shape[0], "the idx specific the tensor is out of range."
if idx in self._free_list:
warnings.warn(
"tensor with batch id {} has not collected before get it.".format(idx)
)
return self._tensor.slice([idx], [idx + 1])
def get(self):
"""
After finish collection, get the result tensor
......
......@@ -53,6 +53,15 @@ def test_tensor_collect_batch_cpu():
for j in range(64):
assert data[i][j // 8][j % 8] == i + 1
for i in range(4):
t = batch_tensor.get_tensor_at(i)
data = t.to_numpy()
assert data.shape[0] == 1
assert data.shape[1] == 8
assert data.shape[2] == 8
for j in range(64):
assert data[0][j // 8][j % 8] == i + 1
@require_cuda
def test_tensor_collect_batch_by_index():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册