# 具名张量操作覆盖范围 > 原文:[`pytorch.org/docs/stable/name_inference.html`](https://pytorch.org/docs/stable/name_inference.html)> > > 译者:[飞龙](https://github.com/wizardforcel) > > 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) 请先阅读具名张量以了解具名张量的介绍。 本文档是关于*名称推断*的参考,该过程定义了具名张量如何: 1. 使用名称提供额外的自动运行时正确性检查 1. 从输入张量传播名称到输出张量 以下是所有支持具名张量及其相关名称推断规则的操作列表。 如果您在此处找不到所需的操作,请查看是否已经提交了问题,如果没有,请提交一个。 警告 具名张量 API 是实验性的,可能会发生变化。 支持的操作 | API | 名称推断规则 | | --- | --- | | `Tensor.abs()`, `torch.abs()` | 保留输入名称 | | `Tensor.abs_()` | 保留输入名称 | | `Tensor.acos()`, `torch.acos()` | 保留输入名称 | | `Tensor.acos_()` | 保留输入名称 | | `Tensor.add()`, `torch.add()` | 统一来自输入的名称 | | `Tensor.add_()` | 统一来自输入的名称 | | `Tensor.addmm()`, `torch.addmm()` | 消除维度 | | `Tensor.addmm_()` | 消除维度 | | `Tensor.addmv()`, `torch.addmv()` | 消除维度 | | `Tensor.addmv_()` | 消除维度 | | `Tensor.align_as()` | 查看文档 | | `Tensor.align_to()` | 查看文档 | | `Tensor.all()`, `torch.all()` | 无 | | `Tensor.any()`, `torch.any()` | 无 | | `Tensor.asin()`, `torch.asin()` | 保留输入名称 | | `Tensor.asin_()` | 保留输入名称 | | `Tensor.atan()`, `torch.atan()` | 保留输入名称 | | `Tensor.atan2()`, `torch.atan2()` | 统一来自输入的名称 | | `Tensor.atan2_()` | 统一来自输入的名称 | | `Tensor.atan_()` | 保留输入名称 | | `Tensor.bernoulli()`, `torch.bernoulli()` | 保留输入名称 | | `Tensor.bernoulli_()` | 无 | | `Tensor.bfloat16()` | 保留输入名称 | | `Tensor.bitwise_not()`, `torch.bitwise_not()` | 保留输入名称 | | `Tensor.bitwise_not_()` | 无 | | `Tensor.bmm()`, `torch.bmm()` | 消除维度 | | `Tensor.bool()` | 保留输入名称 | | `Tensor.byte()` | 保留输入名称 | | `torch.cat()` | 统一来自输入的名称 | | `Tensor.cauchy_()` | 无 | | `Tensor.ceil()`, `torch.ceil()` | 保留输入名称 | | `Tensor.ceil_()` | 无 | | `Tensor.char()` | 保留输入名称 | | `Tensor.chunk()`, `torch.chunk()` | 保留输入名称 | | `Tensor.clamp()`, `torch.clamp()` | 保留输入名称 | | `Tensor.clamp_()` | 无 | | `Tensor.copy_()` | 输出函数和原地变体 | | `Tensor.cos()`, `torch.cos()` | 保留输入名称 | | `Tensor.cos_()` | 无 | | `Tensor.cosh()`, `torch.cosh()` | 保留输入名称 | | `Tensor.cosh_()` | 无 | | `Tensor.acosh()`, `torch.acosh()` | 保留输入名称 | | `Tensor.acosh_()` | 无 | | `Tensor.cpu()` | 保留输入名称 | | `Tensor.cuda()` | 保留输入名称 | | `Tensor.cumprod()`, `torch.cumprod()` | 保留输入名称 | | `Tensor.cumsum()`, `torch.cumsum()` | 保留输入名称 | | `Tensor.data_ptr()` | 无 | | `Tensor.deg2rad()`, `torch.deg2rad()` | 保留输入名称 | | `Tensor.deg2rad_()` | 无 | | `Tensor.detach()`, `torch.detach()` | 保留输入名称 | | `Tensor.detach_()` | 无 | | `Tensor.device`, `torch.device()` | 无 | | `Tensor.digamma()`, `torch.digamma()` | 保留输入名称 | | `Tensor.digamma_()` | 无 | | `Tensor.dim()` | 无 | | `Tensor.div()`, `torch.div()` | 统一输入名称 | | `Tensor.div_()` | 统一输入名称 | | `Tensor.dot()`, `torch.dot()` | 无 | | `Tensor.double()` | 保留输入名称 | | `Tensor.element_size()` | 无 | | `torch.empty()` | 工厂函数 | | `torch.empty_like()` | 工厂函数 | | `Tensor.eq()`, `torch.eq()` | 统一输入名称 | | `Tensor.erf()`, `torch.erf()` | 保留输入名称 | | `Tensor.erf_()` | 无 | | `Tensor.erfc()`, `torch.erfc()` | 保留输入名称 | | `Tensor.erfc_()` | 无 | | `Tensor.erfinv()`, `torch.erfinv()` | 保留输入名称 | | `Tensor.erfinv_()` | 无 | | `Tensor.exp()`, `torch.exp()` | 保留输入名称 | | `Tensor.exp_()` | 无 | | `Tensor.expand()` | 保留输入名称 | | `Tensor.expm1()`, `torch.expm1()` | 保留输入名称 | | `Tensor.expm1_()` | 无 | | `Tensor.exponential_()` | 无 | | `Tensor.fill_()` | 无 | | `Tensor.flatten()`, `torch.flatten()` | 查看文档 | | `Tensor.float()` | 保留输入名称 | | `Tensor.floor()`, `torch.floor()` | 保留输入名称 | | `Tensor.floor_()` | 无 | | `Tensor.frac()`, `torch.frac()` | 保留输入名称 | | `Tensor.frac_()` | 无 | | `Tensor.ge()`, `torch.ge()` | 统一输入的名称 | | `Tensor.get_device()`, `torch.get_device()` | 无 | | `Tensor.grad` | 无 | | `Tensor.gt()`, `torch.gt()` | 统一输入的名称 | | `Tensor.half()` | 保留输入名称 | | `Tensor.has_names()` | 查看文档 | | `Tensor.index_fill()`, `torch.index_fill()` | 保留输入名称 | | `Tensor.index_fill_()` | 无 | | `Tensor.int()` | 保留输入名称 | | `Tensor.is_contiguous()` | 无 | | `Tensor.is_cuda` | 无 | | `Tensor.is_floating_point()`, `torch.is_floating_point()` | 无 | | `Tensor.is_leaf` | 无 | | `Tensor.is_pinned()` | 无 | | `Tensor.is_shared()` | 无 | | `Tensor.is_signed()`, `torch.is_signed()` | 无 | | `Tensor.is_sparse` | 无 | | `Tensor.is_sparse_csr` | 无 | | `torch.is_tensor()` | 无 | | `Tensor.item()` | 无 | | `Tensor.itemsize` | 无 | | `Tensor.kthvalue()`, `torch.kthvalue()` | 移除维度 | | `Tensor.le()`, `torch.le()` | 从输入中统一名称 | | `Tensor.log()`, `torch.log()` | 保留输入名称 | | `Tensor.log10()`, `torch.log10()` | 保留输入名称 | | `Tensor.log10_()` | 无 | | `Tensor.log1p()`, `torch.log1p()` | 保留输入名称 | | `Tensor.log1p_()` | 无 | | `Tensor.log2()`, `torch.log2()` | 保留输入名称 | | `Tensor.log2_()` | 无 | | `Tensor.log_()` | 无 | | `Tensor.log_normal_()` | 无 | | `Tensor.logical_not()`, `torch.logical_not()` | 保留输入名称 | | `Tensor.logical_not_()` | 无 | | `Tensor.logsumexp()`, `torch.logsumexp()` | 移除维度 | | `Tensor.long()` | 保留输入名称 | | `Tensor.lt()`, `torch.lt()` | 从输入统一名称 | | `torch.manual_seed()` | 无 | | `Tensor.masked_fill()`, `torch.masked_fill()` | 保留输入名称 | | `Tensor.masked_fill_()` | 无 | | `Tensor.masked_select()`, `torch.masked_select()` | 将掩码与输入对齐,然后统一输入张量的名称 | | `Tensor.matmul()`, `torch.matmul()` | 消除维度 | | `Tensor.mean()`, `torch.mean()` | 移除维度 | | `Tensor.median()`, `torch.median()` | 移除维度 | | `Tensor.nanmedian()`, `torch.nanmedian()` | 移除维度 | | `Tensor.mm()`, `torch.mm()` | 消除维度 | | `Tensor.mode()`, `torch.mode()` | 移除维度 | | `Tensor.mul()`, `torch.mul()` | 从输入统一名称 | | `Tensor.mul_()` | 从输入统一名称 | | `Tensor.mv()`, `torch.mv()` | 消除维度 | | `Tensor.names` | 查看文档 | | `Tensor.narrow()`, `torch.narrow()` | 保留输入名称 | | `Tensor.nbytes` | 无 | | `Tensor.ndim` | 无 | | `Tensor.ndimension()` | 无 | | `Tensor.ne()`, `torch.ne()` | 统一输入名称 | | `Tensor.neg()`, `torch.neg()` | 保留输入名称 | | `Tensor.neg_()` | 无 | | `torch.normal()` | 保留输入名称 | | `Tensor.normal_()` | 无 | | `Tensor.numel()`, `torch.numel()` | 无 | | `torch.ones()` | 工厂函数 | | `Tensor.pow()`, `torch.pow()` | 统一输入名称 | | `Tensor.pow_()` | 无 | | `Tensor.prod()`, `torch.prod()` | 移除维度 | | `Tensor.rad2deg()`, `torch.rad2deg()` | 保留输入名称 | | `Tensor.rad2deg_()` | 无 | | `torch.rand()` | 工厂函数 | | `torch.rand()` | 工厂函数 | | `torch.randn()` | 工厂函数 | | `torch.randn()` | 工厂函数 | | `Tensor.random_()` | 无 | | `Tensor.reciprocal()`, `torch.reciprocal()` | 保留输入名称 | | `Tensor.reciprocal_()` | 无 | | `Tensor.refine_names()` | 查看文档 | | `Tensor.register_hook()` | 无 | | `Tensor.register_post_accumulate_grad_hook()` | 无 | | `Tensor.rename()` | 查看文档 | | `Tensor.rename_()` | 查看文档 | | `Tensor.requires_grad` | 无 | | `Tensor.requires_grad_()` | 无 | | `Tensor.resize_()` | 仅允许不改变形状的调整 | | `Tensor.resize_as_()` | 仅允许不改变形状的调整大小 | | `Tensor.round()`, `torch.round()` | 保留输入名称 | | `Tensor.round_()` | 无 | | `Tensor.rsqrt()`, `torch.rsqrt()` | 保留输入名称 | | `Tensor.rsqrt_()` | 无 | | `Tensor.select()`, `torch.select()` | 移除维度 | | `Tensor.short()` | 保留输入名称 | | `Tensor.sigmoid()`, `torch.sigmoid()` | 保留输入名称 | | `Tensor.sigmoid_()` | 无 | | `Tensor.sign()`, `torch.sign()` | 保留输入名称 | | `Tensor.sign_()` | 无 | | `Tensor.sgn()`, `torch.sgn()` | 保留输入名称 | | `Tensor.sgn_()` | 无 | | `Tensor.sin()`, `torch.sin()` | 保留输入名称 | | `Tensor.sin_()` | 无 | | `Tensor.sinh()`, `torch.sinh()` | 保留输入名称 | | `Tensor.sinh_()` | 无 | | `Tensor.asinh()`, `torch.asinh()` | 保留输入名称 | | `Tensor.asinh_()` | 无 | | `Tensor.size()` | 无 | | `Tensor.softmax()`, `torch.softmax()` | 保留输入名称 | | `Tensor.split()`, `torch.split()` | 保留输入名称 | | `Tensor.sqrt()`, `torch.sqrt()` | 保留输入名称 | | `Tensor.sqrt_()` | 无 | | `Tensor.squeeze()`, `torch.squeeze()` | 移除维度 | | `Tensor.std()`, `torch.std()` | 移除维度 | | `torch.std_mean()` | 移除维度 | | `Tensor.stride()` | 无 | | `Tensor.sub()`, `torch.sub()` | 统一输入的名称 | | `Tensor.sub_()` | 统一输入的名称 | | `Tensor.sum()`, `torch.sum()` | 移除维度 | | `Tensor.tan()`, `torch.tan()` | 保留输入名称 | | `Tensor.tan_()` | 无 | | `Tensor.tanh()`, `torch.tanh()` | 保留输入名称 | | `Tensor.tanh_()` | 无 | | `Tensor.atanh()`, `torch.atanh()` | 保留输入名称 | | `Tensor.atanh_()` | 无 | | `torch.tensor()` | 工厂函数 | | `Tensor.to()` | 保留输入名称 | | `Tensor.topk()`, `torch.topk()` | 移除维度 | | `Tensor.transpose()`, `torch.transpose()` | 重新排列维度 | | `Tensor.trunc()`, `torch.trunc()` | 保留输入名称 | | `Tensor.trunc_()` | 无 | | `Tensor.type()` | 无 | | `Tensor.type_as()` | 保留输入名称 | | `Tensor.unbind()`, `torch.unbind()` | 移除维度 | | `Tensor.unflatten()` | 查看文档 | | `Tensor.uniform_()` | 无 | | `Tensor.var()`, `torch.var()` | 删除维度 | | `torch.var_mean()` | 删除维度 | | `Tensor.zero_()` | None | | `torch.zeros()` | 工厂函数 | ## 保留输入名称 所有逐点一元函数都遵循此规则,以及一些其他一元函数。 + 检查名称:无 + 传播名称:输入张量的名称传播到输出。 ```py >>> x = torch.randn(3, 3, names=('N', 'C')) >>> x.abs().names ('N', 'C') ``` ## 删除维度 所有类似`sum()`的减少操作通过减少所需维度来删除维度。其他操作,如`select()`和`squeeze()`会删除维度。 无论何时可以将整数维度索引传递给运算符,也可以传递维度名称。接受维度索引列表的函数也可以接受维度名称列表。 + 检查名称:如果`dim`或`dims`作为名称列表传入,请检查这些名称是否存在于`self`中。 + 传播名称:如果由`dim`或`dims`指定的输入张量的维度不存在于输出张量中,则这些维度的相应名称不会出现在`output.names`中。 ```py >>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W')) >>> x.squeeze('N').names ('C', 'H', 'W') >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) >>> x.sum(['N', 'C']).names ('H', 'W') # Reduction ops with keepdim=True don't actually remove dimensions. >>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W')) >>> x.sum(['N', 'C'], keepdim=True).names ('N', 'C', 'H', 'W') ``` ## 统一输入的名称 所有二元算术运算都遵循此规则。广播操作仍然从右侧按位置广播,以保持与未命名张量的兼容性。要通过名称执行显式广播,请使用`Tensor.align_as()`。 + 检查名称:所有名称必须从右侧按位置匹配。即,在`tensor + other`中,对于所有`i`,`match(tensor.names[i], other.names[i])`必须为真,其中`i`在`(-min(tensor.dim(), other.dim()) + 1, -1]`中。 + 检查名称:此外,所有命名维度必须从右侧对齐。在匹配过程中,如果我们将命名维度`A`与未命名维度`None`匹配,那么`A`不能出现在具有未命名维度的张量中。 + 传播名称:从两个张量的右侧统一名称对以生成输出名称。 例如, ```py # tensor: Tensor[ N, None] # other: Tensor[None, C] >>> tensor = torch.randn(3, 3, names=('N', None)) >>> other = torch.randn(3, 3, names=(None, 'C')) >>> (tensor + other).names ('N', 'C') ``` 检查名称: + `match(tensor.names[-1], other.names[-1])`为`True` + `match(tensor.names[-2], tensor.names[-2])`为`True` + 因为我们在`tensor`中将`None`与`'C'`匹配,因此请确保`'C'`不存在于`tensor`中(它不存在)。 + 检查确保`'N'`不存在于`other`中(不存在)。 最后,输出名称通过`[unify('N', None), unify(None, 'C')] = ['N', 'C']`计算 更多示例: ```py # Dimensions don't match from the right: # tensor: Tensor[N, C] # other: Tensor[ N] >>> tensor = torch.randn(3, 3, names=('N', 'C')) >>> other = torch.randn(3, names=('N',)) >>> (tensor + other).names RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims ['N']: dim 'C' and dim 'N' are at the same position from the right but do not match. # Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]: # tensor: Tensor[N, None] # other: Tensor[ N] >>> tensor = torch.randn(3, 3, names=('N', None)) >>> other = torch.randn(3, names=('N',)) >>> (tensor + other).names RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and dims ['N', None]: dim 'N' appears in a different position from the right across both lists. ``` 注意 在最后两个示例中,可以通过名称对齐张量,然后执行加法。使用`Tensor.align_as()`按名称对齐张量,或使用`Tensor.align_to()`按自定义维度顺序对齐张量。 ## 重新排列维度 一些操作,比如`Tensor.t()`,会改变维度的顺序。维度名称附加到各个维度上,因此它们也会被重新排列。 如果运算符接受位置索引`dim`,也可以将维度名称作为`dim`。 + 检查名称:如果`dim`作为名称传入,请检查它是否存在于张量中。 + 传播名称:以与被置换的维度相同的方式置换维度名称。 ```py >>> x = torch.randn(3, 3, names=('N', 'C')) >>> x.transpose('N', 'C').names ('C', 'N') ```##收缩维度 矩阵乘法函数遵循某种变体。让我们先看看`torch.mm()`,然后推广批量矩阵乘法的规则。 对于`torch.mm(tensor, other)`: + 检查名称:无 + 传播名称:结果名称为`(tensor.names[-2], other.names[-1])`。 ```py >>> x = torch.randn(3, 3, names=('N', 'D')) >>> y = torch.randn(3, 3, names=('in', 'out')) >>> x.mm(y).names ('N', 'out') ``` 从本质上讲,矩阵乘法在两个维度上执行点积,将它们折叠。当两个张量进行矩阵乘法时,收缩的维度消失,不会出现在输出张量中。 `torch.mv()`,`torch.dot()`的工作方式类似:名称推断不检查输入名称,并移除参与点积的维度: ```py >>> x = torch.randn(3, 3, names=('N', 'D')) >>> y = torch.randn(3, names=('something',)) >>> x.mv(y).names ('N',) ``` 现在,让我们看看`torch.matmul(tensor, other)`。假设`tensor.dim() >= 2`且`other.dim() >= 2`。 + 检查名称:检查输入的批量维度是否对齐且可广播。查看从输入统一名称以了解输入对齐的含义。 + 传播名称:结果名称通过统一批量维度并移除收缩维度来获得:`unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])`。 示例: ```py # Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F']. # 'A', 'B' are batch dimensions. >>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D')) >>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F')) >>> torch.matmul(x, y).names ('A', 'B', 'C', 'F') ``` 最后,许多 matmul 函数有融合的`add`版本。即,`addmm()`和`addmv()`。这些被视为对`mm()`和`add()`进行名称推断的组合。##工厂函数 工厂函数现在接受一个新的`names`参数,将每个维度与一个名称关联起来。 ```py >>> torch.zeros(2, 3, names=('N', 'C')) tensor([[0., 0., 0.], [0., 0., 0.]], names=('N', 'C')) ```##输出函数和就地变体 指定为`out=`张量的张量具有以下行为: + 如果没有命名维度,则从操作中计算的名称会传播到其中。 + 如果有任何命名维度,则从操作中计算的名称必须与现有名称完全相等。否则,操作会出错。 所有就地方法都会修改输入,使其具有与名称推断的计算名称相等。例如: ```py >>> x = torch.randn(3, 3) >>> y = torch.randn(3, 3, names=('N', 'C')) >>> x.names (None, None) >>> x += y >>> x.names ('N', 'C') ```