doc22_081.md 17.3 KB
Newer Older
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
1 2
# torch.testing

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
3
> [`pytorch.org/docs/stable/testing.html`](https://pytorch.org/docs/stable/testing.html)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
4 5

```py
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
6
torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
7 8 9 10 11 12 13 14 15 16 17 18
```

断言`actual``expected`是接近的。

如果`actual``expected`是分步、非量化、实值且有限的,则它们被视为接近,如果

$\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$∣actual−expected∣≤atol+rtol⋅∣expected∣

非有限值(`-inf``inf`)仅在它们相等时才被视为接近。只有当`equal_nan``True`时,`NaN`才被视为相等。

此外,只有当它们相同时才被视为接近

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
19
+   `device`(如果`check_device``True`),
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
20 21 22 23 24 25 26 27 28

+   `dtype`(如果`check_dtype``True`),

+   `layout`(如果`check_layout``True`),和

+   步幅(如果`check_stride``True`)。

如果`actual``expected`是元张量,则仅执行属性检查。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
29
如果`actual``expected`是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 布局),它们的分步成员将被单独检查。索引,即 COO 的`indices`,CSR 和 BSR 的`crow_indices``col_indices`,或 CSC 和 BSC 布局的`ccol_indices``row_indices`,始终被检查是否相等,而值根据上述定义被视为接近。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
30

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
31
如果`actual``expected`是量化的,则它们被视为接近,如果它们具有相同的`qscheme()`并且`dequantize()`的结果根据上述定义接近。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
32

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
33
`actual``expected`可以是`Tensor`或任何张量或标量,可以使用`torch.as_tensor()`构造`torch.Tensor`。除了 Python 标量外,输入类型必须直接相关。此外,`actual``expected`可以是[`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence "(在 Python v3.12 中)")[`Mapping`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping "(在 Python v3.12 中)"),在这种情况下,如果它们的结构匹配并且所有元素根据上述定义被视为接近,则它们被视为接近。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
34 35 36

注意

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
37
Python 标量是类型关系要求的一个例外,因为它们的`type()`,即[`int`](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")[`float`](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)")[`complex`](https://docs.python.org/3/library/functions.html#complex "(在 Python v3.12 中)"),等同于张量类的`dtype`。因此,不同类型的 Python 标量可以被检查,但需要`check_dtype=False`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
38 39 40 41 42 43 44

参数

+   **actual***任意*)- 实际输入。

+   **expected***任意*)- 预期输入。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
45
+   **allow_subclasses**[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)"))- 如果为`True`(默认)并且除了 Python 标量之外,直接相关类型的输入是允许的。否则需要类型相等。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
46

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
47
+   **rtol***可选**[*[*float*](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)")*]*) - 相对容差。如果指定了`atol`,必须同时指定。如果省略,默认值基于`dtype`从下表中选择。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
48

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
49
+   **atol***可选**[*[*float*](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)")*]*) - 绝对容差。如果指定了`rtol`,必须同时指定。如果省略,默认值基于`dtype`从下表中选择。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
50 51 52

+   **equal_nan** (*Union**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")*,* [*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]*) – 如果为`True`,则认为两个`NaN`值相等。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
53
+   **check_device** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`(默认),则断言相应的张量位于相同的`device`上。如果禁用此检查,则位于不同`device`上的张量在比较之前将移动到 CPU。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
54

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
55
+   **check_dtype** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`(默认),则断言相应的张量具有相同的`dtype`。如果禁用此检查,则在比较之前将具有不同`dtype`的张量提升为公共`dtype`(根据`torch.promote_types()`)。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
56 57 58 59 60 61 62 63 64

+   **check_layout** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`(默认),则断言相应的张量具有相同的`layout`。如果禁用此检查,则在比较之前将具有不同`layout`的张量转换为分步张量。

+   **check_stride** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`且相应的张量是分步的,则断言它们具有相同的步幅。

+   **msg** (*Optional**[**Union**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Callable**[**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]**,* [*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]**]**]*) – 在比较过程中发生失败时使用的可选错误消息。也可以作为可调用对象传递,此时将使用生成的消息并应返回新消息。

引发

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
65
+   [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(in Python v3.12)") – 如果无法从输入构造`torch.Tensor`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
66 67 68

+   [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(in Python v3.12)") – 如果只指定了`rtol``atol`

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
69
+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果相应的输入不是 Python 标量且不直接相关。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
70

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
71
+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果`allow_subclasses``False`,但相应的输入不是 Python 标量并且类型不同。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
72 73 74 75 76

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果输入是[`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence "(in Python v3.12)"),但它们的长度不匹配。

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果输入是[`Mapping`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping "(in Python v3.12)"),但它们的键集不匹配。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
77
+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果相应的张量的`shape`不相同。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
78 79 80 81 82

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果`check_layout``True`,但相应的张量的`layout`不相同。

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果仅有一组相应的张量被量化。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
83
+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果相应的张量被量化,但具有不同的`qscheme()`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
84

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
85
+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果`check_device``True`,但相应的张量不在相同的`device`上。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果`check_dtype``True`,但相应的张量的`dtype`不相同。

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果`check_stride``True`,但相应的步进张量的步幅不相同。

+   [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果相应张量的值根据上述定义不接近。

以下表显示了不同`dtype`的默认`rtol``atol`。在`dtype`不匹配的情况下,使用两个容差中的最大值。

| `dtype` | `rtol` | `atol` |
| --- | --- | --- |
| `float16` | `1e-3` | `1e-5` |
| `bfloat16` | `1.6e-2` | `1e-5` |
| `float32` | `1.3e-6` | `1e-5` |
| `float64` | `1e-7` | `1e-7` |
| `complex32` | `1e-3` | `1e-5` |
| `complex64` | `1.3e-6` | `1e-5` |
| `complex128` | `1e-7` | `1e-7` |
| `quint8` | `1.3e-6` | `1e-5` |
| `quint2x4` | `1.3e-6` | `1e-5` |
| `quint4x2` | `1.3e-6` | `1e-5` |
| `qint8` | `1.3e-6` | `1e-5` |
| `qint32` | `1.3e-6` | `1e-5` |
| 其他 | `0.0` | `0.0` |

注意

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
113
`assert_close()` 具有严格的默认设置,可以高度配置。鼓励用户使用[`partial()`](https://docs.python.org/3/library/functools.html#functools.partial "(在 Python v3.12 中)") 来适应其用例。例如,如果需要进行相等性检查,可以定义一个`assert_equal`,默认情况下对每个`dtype`使用零容差:
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242

```py
>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Expected 1e-10 but got 1e-09.
Absolute difference: 9.000000000000001e-10
Relative difference: 9.0 
```

示例

```py
>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected) 
```

```py
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected) 
```

```py
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected) 
```

```py
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected) 
```

```py
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected) 
```

```py
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type
<class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
and <class 'torch.Tensor'>.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False) 
```

```py
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Expected nan but got nan.
Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True) 
```

```py
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default error message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # If msg is a callable, it can be used to augment the generated message with
>>> # extra information
>>> torch.testing.assert_close(
...     actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
... )
Traceback (most recent call last):
...
AssertionError: Header

Tensor-likes are not close!

Mismatched elements: 2 / 3 (66.7%)
Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)

Footer 
```

```py
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
243
torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None)
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
```

创建具有给定`shape``device``dtype`的张量,并用从`[low, high)`均匀抽取的值填充。

如果指定了`low``high`,并且超出了`dtype`的可表示有限值范围,则它们将被夹紧到最低或最高可表示有限值,分别。如果为`None`,则以下表格描述了`low``high`的默认值,这取决于`dtype`

| `dtype` | `low` | `high` |
| --- | --- | --- |
| 布尔类型 | `0` | `2` |
| 无符号整数类型 | `0` | `10` |
| 有符号整数类型 | `-9` | `10` |
| 浮点类型 | `-9` | `9` |
| 复数类型 | `-9` | `9` |

参数

+   **shape** (*元组**[*[*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")*,* *...**]*) – 定义输出张量形状的单个整数或整数序列。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
262
+   **dtype** (`torch.dtype`) – 返回张量的数据类型。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
263

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
264
+   **device** (*Union**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *torch.device**]*) – 返回张量的设备。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
265

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
266
+   **low** (*可选**[**数字**]*) – 设置给定范围的下限(包括)。如果提供了一个数字,它将被夹紧到给定 dtype 的最小可表示有限值。当为`None`(默认)时,此值根据`dtype`(见上表)确定。默认值:`None`
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
267 268 269 270 271

+   **high** (*可选**[**数字**]*) –

    设置给定范围的上限(不包括)。如果提供了一个数字,则它将被夹紧到给定 dtype 的最大可表示有限值。当为 `None`(默认)时,此值基于 `dtype` 决定(参见上表)。默认值:`None`。

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
272
    自版本 2.1 起已弃用:对于浮点或复数类型,将 `low==high` 传递给 `make_tensor()` 自 2.1 版本起已弃用,并将在 2.3 版本中移除。请改用 `torch.full()`。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
273 274 275 276 277 278 279

+   **requires_grad***可选**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*]*) – 如果 autograd 应记录返回的张量上的操作。默认值:`False`

+   **noncontiguous***可选**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*]*) – 如果为 True,则返回的张量将是非连续的。如果构造的张量少于两个元素,则忽略此参数。与 `memory_format` 互斥。

+   **exclude_zero***可选**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*]*) – 如果为 `True`,则零将被替换为依赖于 `dtype` 的小正值。对于布尔和整数类型,零将被替换为一。对于浮点类型,它将被替换为 `dtype` 的最小正常数(`dtype``finfo()` 对象的“微小”值),对于复数类型,它将被替换为一个实部和虚部都是复数类型可表示的最小正常数的复数。默认为 `False`

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
280
+   **memory_format***可选***[*torch.memory_format**]*) – 返回张量的内存格式。与 `noncontiguous` 互斥。
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295

引发

+   [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果为整数 dtype 传递了 `requires_grad=True`

+   [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果 `low >= high`

+   [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果 `low``high` 中有一个为 `nan`

+   [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果同时传递了 `noncontiguous``memory_format`

+   [**TypeError**](https://docs.python.org/3/library/exceptions.html#TypeError "(在 Python v3.12 中)") – 如果 `dtype` 不受此函数支持。

返回类型

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
296
*Tensor*
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311

示例

```py
>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
 [False, True]], device='cuda:0') 
```

```py
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
312
torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='')
绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
313 314 315 316
```

警告

绝不原创的飞龙's avatar
绝不原创的飞龙 已提交
317
`torch.testing.assert_allclose()``1.12` 版本起已被弃用,并将在将来的版本中移除。请改用 `torch.testing.assert_close()`。您可以在[此处](https://github.com/pytorch/pytorch/issues/61844)找到详细的升级说明。