未验证 提交 c543a41b 编写于 作者: O Olatunji Ruwase 提交者: GitHub

Use correct input size for splits (#1284)

* Use correct input size for splits

* Use smarter partitioning
上级 b1b41754
......@@ -14,6 +14,7 @@ def split_tensor_along_last_dim(tensor, partitions, contiguous_split_chunks=Fals
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
# Split.
tensor_list = torch.split(tensor, partitions, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
......@@ -123,8 +124,9 @@ class TiledLinear(torch.nn.Module):
def forward(self, input_):
if self.in_splits > 1 and not self.input_is_already_split:
input_parts = partition(input_.shape[-1], self.in_splits)
split_sizes = [
self.in_parts[p + 1] - self.in_parts[p] for p in range(self.in_splits)
input_parts[p + 1] - input_parts[p] for p in range(self.in_splits)
]
inputs = self._split_global_input(input_, split_sizes)
elif self.in_splits > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册