# torch.testing > [`pytorch.org/docs/stable/testing.html`](https://pytorch.org/docs/stable/testing.html) ```py torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_layout=True, check_stride=False, msg=None) ``` 断言`actual`和`expected`是接近的。 如果`actual`和`expected`是分步、非量化、实值且有限的,则它们被视为接近,如果 $\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$∣actual−expected∣≤atol+rtol⋅∣expected∣ 非有限值(`-inf`和`inf`)仅在它们相等时才被视为接近。只有当`equal_nan`为`True`时,`NaN`才被视为相等。 此外,只有当它们相同时才被视为接近 + `device`(如果`check_device`为`True`), + `dtype`(如果`check_dtype`为`True`), + `layout`(如果`check_layout`为`True`),和 + 步幅(如果`check_stride`为`True`)。 如果`actual`或`expected`是元张量,则仅执行属性检查。 如果`actual`和`expected`是稀疏的(具有 COO、CSR、CSC、BSR 或 BSC 布局),它们的分步成员将被单独检查。索引,即 COO 的`indices`,CSR 和 BSR 的`crow_indices`和`col_indices`,或 CSC 和 BSC 布局的`ccol_indices`和`row_indices`,始终被检查是否相等,而值根据上述定义被视为接近。 如果`actual`和`expected`是量化的,则它们被视为接近,如果它们具有相同的`qscheme()`并且`dequantize()`的结果根据上述定义接近。 `actual`和`expected`可以是`Tensor`或任何张量或标量,可以使用`torch.as_tensor()`构造`torch.Tensor`。除了 Python 标量外,输入类型必须直接相关。此外,`actual`和`expected`可以是[`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence "(在 Python v3.12 中)")或[`Mapping`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping "(在 Python v3.12 中)"),在这种情况下,如果它们的结构匹配并且所有元素根据上述定义被视为接近,则它们被视为接近。 注意 Python 标量是类型关系要求的一个例外,因为它们的`type()`,即[`int`](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)"),[`float`](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)")和[`complex`](https://docs.python.org/3/library/functions.html#complex "(在 Python v3.12 中)"),等同于张量类的`dtype`。因此,不同类型的 Python 标量可以被检查,但需要`check_dtype=False`。 参数 + **actual**(*任意*)- 实际输入。 + **expected**(*任意*)- 预期输入。 + **allow_subclasses**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)"))- 如果为`True`(默认)并且除了 Python 标量之外,直接相关类型的输入是允许的。否则需要类型相等。 + **rtol**(*可选**[*[*float*](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)")*]*) - 相对容差。如果指定了`atol`,必须同时指定。如果省略,默认值基于`dtype`从下表中选择。 + **atol**(*可选**[*[*float*](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)")*]*) - 绝对容差。如果指定了`rtol`,必须同时指定。如果省略,默认值基于`dtype`从下表中选择。 + **equal_nan** (*Union**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")*,* [*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]*) – 如果为`True`,则认为两个`NaN`值相等。 + **check_device** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`(默认),则断言相应的张量位于相同的`device`上。如果禁用此检查,则位于不同`device`上的张量在比较之前将移动到 CPU。 + **check_dtype** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`(默认),则断言相应的张量具有相同的`dtype`。如果禁用此检查,则在比较之前将具有不同`dtype`的张量提升为公共`dtype`(根据`torch.promote_types()`)。 + **check_layout** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`(默认),则断言相应的张量具有相同的`layout`。如果禁用此检查,则在比较之前将具有不同`layout`的张量转换为分步张量。 + **check_stride** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)")) – 如果为`True`且相应的张量是分步的,则断言它们具有相同的步幅。 + **msg** (*Optional**[**Union**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Callable**[**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]**,* [*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]**]**]*) – 在比较过程中发生失败时使用的可选错误消息。也可以作为可调用对象传递,此时将使用生成的消息并应返回新消息。 引发 + [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(in Python v3.12)") – 如果无法从输入构造`torch.Tensor`。 + [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(in Python v3.12)") – 如果只指定了`rtol`或`atol`。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果相应的输入不是 Python 标量且不直接相关。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果`allow_subclasses`为`False`,但相应的输入不是 Python 标量并且类型不同。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果输入是[`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence "(in Python v3.12)"),但它们的长度不匹配。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果输入是[`Mapping`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping "(in Python v3.12)"),但它们的键集不匹配。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果相应的张量的`shape`不相同。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(in Python v3.12)") – 如果`check_layout`为`True`,但相应的张量的`layout`不相同。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果仅有一组相应的张量被量化。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果相应的张量被量化,但具有不同的`qscheme()`。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果`check_device`为`True`,但相应的张量不在相同的`device`上。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果`check_dtype`为`True`,但相应的张量的`dtype`不相同。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果`check_stride`为`True`,但相应的步进张量的步幅不相同。 + [**AssertionError**](https://docs.python.org/3/library/exceptions.html#AssertionError "(在 Python v3.12 中)") – 如果相应张量的值根据上述定义不接近。 以下表显示了不同`dtype`的默认`rtol`和`atol`。在`dtype`不匹配的情况下,使用两个容差中的最大值。 | `dtype` | `rtol` | `atol` | | --- | --- | --- | | `float16` | `1e-3` | `1e-5` | | `bfloat16` | `1.6e-2` | `1e-5` | | `float32` | `1.3e-6` | `1e-5` | | `float64` | `1e-7` | `1e-7` | | `complex32` | `1e-3` | `1e-5` | | `complex64` | `1.3e-6` | `1e-5` | | `complex128` | `1e-7` | `1e-7` | | `quint8` | `1.3e-6` | `1e-5` | | `quint2x4` | `1.3e-6` | `1e-5` | | `quint4x2` | `1.3e-6` | `1e-5` | | `qint8` | `1.3e-6` | `1e-5` | | `qint32` | `1.3e-6` | `1e-5` | | 其他 | `0.0` | `0.0` | 注意 `assert_close()` 具有严格的默认设置,可以高度配置。鼓励用户使用[`partial()`](https://docs.python.org/3/library/functools.html#functools.partial "(在 Python v3.12 中)") 来适应其用例。例如,如果需要进行相等性检查,可以定义一个`assert_equal`,默认情况下对每个`dtype`使用零容差: ```py >>> import functools >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0) >>> assert_equal(1e-9, 1e-10) Traceback (most recent call last): ... AssertionError: Scalars are not equal! Expected 1e-10 but got 1e-09. Absolute difference: 9.000000000000001e-10 Relative difference: 9.0 ``` 示例 ```py >>> # tensor to tensor comparison >>> expected = torch.tensor([1e0, 1e-1, 1e-2]) >>> actual = torch.acos(torch.cos(expected)) >>> torch.testing.assert_close(actual, expected) ``` ```py >>> # scalar to scalar comparison >>> import math >>> expected = math.sqrt(2.0) >>> actual = 2.0 / math.sqrt(2.0) >>> torch.testing.assert_close(actual, expected) ``` ```py >>> # numpy array to numpy array comparison >>> import numpy as np >>> expected = np.array([1e0, 1e-1, 1e-2]) >>> actual = np.arccos(np.cos(expected)) >>> torch.testing.assert_close(actual, expected) ``` ```py >>> # sequence to sequence comparison >>> import numpy as np >>> # The types of the sequences do not have to match. They only have to have the same >>> # length and their elements have to match. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)] >>> actual = tuple(expected) >>> torch.testing.assert_close(actual, expected) ``` ```py >>> # mapping to mapping comparison >>> from collections import OrderedDict >>> import numpy as np >>> foo = torch.tensor(1.0) >>> bar = 2.0 >>> baz = np.array(3.0) >>> # The types and a possible ordering of mappings do not have to match. They only >>> # have to have the same set of keys and their elements have to match. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)]) >>> actual = {"baz": baz, "bar": bar, "foo": foo} >>> torch.testing.assert_close(actual, expected) ``` ```py >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = expected.clone() >>> # By default, directly related instances can be compared >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected) >>> # This check can be made more strict with allow_subclasses=False >>> torch.testing.assert_close( ... torch.nn.Parameter(actual), expected, allow_subclasses=False ... ) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type and . >>> # If the inputs are not directly related, they are never considered close >>> torch.testing.assert_close(actual.numpy(), expected) Traceback (most recent call last): ... TypeError: No comparison pair was able to handle inputs of type and . >>> # Exceptions to these rules are Python scalars. They can be checked regardless of >>> # their type if check_dtype=False. >>> torch.testing.assert_close(1.0, 1, check_dtype=False) ``` ```py >>> # NaN != NaN by default. >>> expected = torch.tensor(float("Nan")) >>> actual = expected.clone() >>> torch.testing.assert_close(actual, expected) Traceback (most recent call last): ... AssertionError: Scalars are not close! Expected nan but got nan. Absolute difference: nan (up to 1e-05 allowed) Relative difference: nan (up to 1.3e-06 allowed) >>> torch.testing.assert_close(actual, expected, equal_nan=True) ``` ```py >>> expected = torch.tensor([1.0, 2.0, 3.0]) >>> actual = torch.tensor([1.0, 4.0, 5.0]) >>> # The default error message can be overwritten. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!") Traceback (most recent call last): ... AssertionError: Argh, the tensors are not close! >>> # If msg is a callable, it can be used to augment the generated message with >>> # extra information >>> torch.testing.assert_close( ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter" ... ) Traceback (most recent call last): ... AssertionError: Header Tensor-likes are not close! Mismatched elements: 2 / 3 (66.7%) Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed) Footer ``` ```py torch.testing.make_tensor(*shape, dtype, device, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False, memory_format=None) ``` 创建具有给定`shape`、`device`和`dtype`的张量,并用从`[low, high)`均匀抽取的值填充。 如果指定了`low`或`high`,并且超出了`dtype`的可表示有限值范围,则它们将被夹紧到最低或最高可表示有限值,分别。如果为`None`,则以下表格描述了`low`和`high`的默认值,这取决于`dtype`。 | `dtype` | `low` | `high` | | --- | --- | --- | | 布尔类型 | `0` | `2` | | 无符号整数类型 | `0` | `10` | | 有符号整数类型 | `-9` | `10` | | 浮点类型 | `-9` | `9` | | 复数类型 | `-9` | `9` | 参数 + **shape** (*元组**[*[*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")*,* *...**]*) – 定义输出张量形状的单个整数或整数序列。 + **dtype** (`torch.dtype`) – 返回张量的数据类型。 + **device** (*Union**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *torch.device**]*) – 返回张量的设备。 + **low** (*可选**[**数字**]*) – 设置给定范围的下限(包括)。如果提供了一个数字,它将被夹紧到给定 dtype 的最小可表示有限值。当为`None`(默认)时,此值根据`dtype`(见上表)确定。默认值:`None`。 + **high** (*可选**[**数字**]*) – 设置给定范围的上限(不包括)。如果提供了一个数字,则它将被夹紧到给定 dtype 的最大可表示有限值。当为 `None`(默认)时,此值基于 `dtype` 决定(参见上表)。默认值:`None`。 自版本 2.1 起已弃用:对于浮点或复数类型,将 `low==high` 传递给 `make_tensor()` 自 2.1 版本起已弃用,并将在 2.3 版本中移除。请改用 `torch.full()`。 + **requires_grad**(*可选**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*]*) – 如果 autograd 应记录返回的张量上的操作。默认值:`False`。 + **noncontiguous**(*可选**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*]*) – 如果为 True,则返回的张量将是非连续的。如果构造的张量少于两个元素,则忽略此参数。与 `memory_format` 互斥。 + **exclude_zero**(*可选**[*[*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*]*) – 如果为 `True`,则零将被替换为依赖于 `dtype` 的小正值。对于布尔和整数类型,零将被替换为一。对于浮点类型,它将被替换为 `dtype` 的最小正常数(`dtype` 的 `finfo()` 对象的“微小”值),对于复数类型,它将被替换为一个实部和虚部都是复数类型可表示的最小正常数的复数。默认为 `False`。 + **memory_format**(*可选***[*torch.memory_format**]*) – 返回张量的内存格式。与 `noncontiguous` 互斥。 引发 + [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果为整数 dtype 传递了 `requires_grad=True`。 + [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果 `low >= high`。 + [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果 `low` 或 `high` 中有一个为 `nan`。 + [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError "(在 Python v3.12 中)") – 如果同时传递了 `noncontiguous` 和 `memory_format`。 + [**TypeError**](https://docs.python.org/3/library/exceptions.html#TypeError "(在 Python v3.12 中)") – 如果 `dtype` 不受此函数支持。 返回类型 *Tensor* 示例 ```py >>> from torch.testing import make_tensor >>> # Creates a float tensor with values in [-1, 1) >>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1) tensor([ 0.1205, 0.2282, -0.6380]) >>> # Creates a bool tensor on CUDA >>> make_tensor((2, 2), device='cuda', dtype=torch.bool) tensor([[False, False], [False, True]], device='cuda:0') ``` ```py torch.testing.assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True, msg='') ``` 警告 `torch.testing.assert_allclose()` 自 `1.12` 版本起已被弃用,并将在将来的版本中移除。请改用 `torch.testing.assert_close()`。您可以在[此处](https://github.com/pytorch/pytorch/issues/61844)找到详细的升级说明。