# torch.library > 原文:[`pytorch.org/docs/stable/library.html`](https://pytorch.org/docs/stable/library.html) torch.library 是一组用于扩展 PyTorch 核心运算符库的 API。它包含用于创建新的自定义运算符以及扩展使用 PyTorch 的 C++运算符注册 API(例如 aten 运算符)定义的运算符的实用程序。 有关有效使用这些 API 的详细指南,请参阅[此 gdoc](https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit) 使用`torch.library.define()`来定义新的自定义运算符。使用 impl 方法,例如`torch.library.impl()`和 func:torch.library.impl_abstract,为任何运算符添加实现(它们可以使用`torch.library.define()`创建,或通过 PyTorch 的 C++运算符注册 API 创建)。 ```py torch.library.define(qualname, schema, *, lib=None, tags=()) ``` ```py torch.library.define(lib, schema, alias_analysis='') ``` 定义一个新的运算符。 在 PyTorch 中,定义一个 op(即“运算符”)是一个两步过程:- 我们需要定义 op(提供运算符名称和模式)- 我们需要实现运算符与各种 PyTorch 子系统(如 CPU/CUDA 张量,Autograd 等)交互的行为。 此入口点定义了自定义运算符(第一步),然后您必须通过调用各种`impl_*` API(如`torch.library.impl()`或`torch.library.impl_abstract()`)执行第二步。 参数 + **qualname**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")) - 运算符的限定名称。应该是一个看起来像“namespace::name”的字符串,例如“aten::sin”。PyTorch 中的运算符需要一个命名空间以避免名称冲突;给定的运算符只能创建一次。如果您正在编写 Python 库,我们建议命名空间为顶级模块的名称。 + **schema**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")) - 运算符的模式。例如,对于一个接受一个张量并返回一个张量的 op,“(张量 x)->张量”。它不包含运算符名称(传递给`qualname`)。 + **lib**(*可选***[*Library**])- 如果提供,此运算符的生命周期将与 Library 对象的生命周期绑定。 + **标签**(*标签* *|* *序列***[*标签**])- 一个或多个 torch.Tag,应用于此运算符。对运算符进行标记会改变运算符在各种 PyTorch 子系统下的行为;请在应用之前仔细阅读 torch.Tag 的文档。 示例:: ```py >>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the operator >>> @torch.library.impl("mylibrary::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> # Call the new operator from torch.ops. >>> x = torch.randn(3) >>> y = torch.ops.mylib.sin(x) >>> assert torch.allclose(y, x) ``` ```py torch.library.impl(qualname, types, func=None, *, lib=None) ``` ```py torch.library.impl(lib, name, dispatch_key='') ``` 为此运算符的设备类型注册一个实现。 您可以将“default”传递给`types`,以将此实现注册为所有设备类型的默认实现。只有在实现真正支持所有设备类型时才使用此选项;例如,如果它是内置 PyTorch 运算符的组合,则为真。 一些有效的类型是:“cpu”,“cuda”,“xla”,“mps”,“ipu”,“xpu”。 参数 + **qualname**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")) - 应该是一个看起来像“namespace::operator_name”的字符串。 + **types**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)") *|* *序列**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")*])- 要注册实现的设备类型。 + **lib**(*可选***[*Library**])- 如果提供,此注册的生命周期将与 Library 对象的生命周期绑定。 示例 ```py >>> import torch >>> import numpy as np >>> >>> # Define the operator >>> torch.library.define("mylibrary::sin", "(Tensor x) -> Tensor") >>> >>> # Add implementations for the cpu device >>> @torch.library.impl("mylibrary::sin", "cpu") >>> def f(x): >>> return torch.from_numpy(np.sin(x.numpy())) >>> >>> x = torch.randn(3) >>> y = torch.ops.mylibrary.sin(x) >>> assert torch.allclose(y, x.sin()) ``` ```py torch.library.impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1) ``` 为此运算符注册一个抽象实现。 “抽象实现”指定了在携带无数据的张量上的操作符的行为。给定具有某些属性(大小/步幅/存储偏移/设备)的输入张量,它指定了输出张量的属性是什么。 抽象实现与操作符具有相同的签名。它适用于 FakeTensors 和元张量。要编写抽象实现,请假设操作符的所有张量输入都是常规的 CPU/CUDA/Meta 张量,但它们没有存储,并且您正在尝试返回常规的 CPU/CUDA/Meta 张量作为输出。抽象实现必须仅包含 PyTorch 操作(不能直接访问任何输入或中间张量的存储或数据)。 此 API 可以用作装饰器(请参阅示例)。 有关自定义操作的详细指南,请参阅[`docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit`](https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ/edit) 示例 ```py >>> import torch >>> import numpy as np >>> from torch import Tensor >>> >>> # Example 1: an operator without data-dependent output shape >>> torch.library.define( >>> "mylib::custom_linear", >>> "(Tensor x, Tensor weight, Tensor bias) -> Tensor") >>> >>> @torch.library.impl_abstract("mylib::custom_linear") >>> def custom_linear_abstract(x, weight): >>> assert x.dim() == 2 >>> assert weight.dim() == 2 >>> assert bias.dim() == 1 >>> assert x.shape[1] == weight.shape[1] >>> assert weight.shape[0] == bias.shape[0] >>> assert x.device == weight.device >>> >>> return (x @ weight.t()) + bias >>> >>> # Example 2: an operator with data-dependent output shape >>> torch.library.define("mylib::custom_nonzero", "(Tensor x) -> Tensor") >>> >>> @torch.library.impl_abstract("mylib::custom_nonzero") >>> def custom_nonzero_abstract(x): >>> # Number of nonzero-elements is data-dependent. >>> # Since we cannot peek at the data in an abstract impl, >>> # we use the ctx object to construct a new symint that >>> # represents the data-dependent size. >>> ctx = torch.library.get_ctx() >>> nnz = ctx.new_dynamic_size() >>> shape = [nnz, x.dim()] >>> result = x.new_empty(shape, dtype=torch.int64) >>> return result >>> >>> @torch.library.impl("mylib::custom_nonzero", "cpu") >>> def custom_nonzero_cpu(x): >>> x_np = x.numpy() >>> res = np.stack(np.nonzero(x_np), axis=1) >>> return torch.tensor(res, device=x.device) ``` ```py torch.library.get_ctx() ``` get_ctx()返回当前的 AbstractImplCtx 对象。 只有在抽象实现内部调用`get_ctx()`才有效(有关更多用法细节,请参见`torch.library.impl_abstract()`)。 返回类型 *AbstractImplCtx* ## 低级 API 以下 API 是直接绑定到 PyTorch 的 C++低级操作符注册 API。 警告 低级操作符注册 API 和 PyTorch 调度程序是一个复杂的 PyTorch 概念。我们建议在可能的情况下使用上面的更高级 API(不需要 torch.library.Library 对象)。这篇博文<[`blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/`](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/)>`_ 是了解 PyTorch 调度程序的好起点。 有关如何使用此 API 的一些示例的教程可在[Google Colab](https://colab.research.google.com/drive/1RRhSfk7So3Cn02itzLWE9K4Fam-8U011?usp=sharing)上找到。 ```py class torch.library.Library(ns, kind, dispatch_key='') ``` 一个用于创建库的类,可以用于注册新操作符或从 Python 中覆盖现有库中的操作符。用户可以选择传递调度键名称,如果他们只想注册与特定调度键对应的内核。 要创建一个库以覆盖现有库中的操作符(名称为 ns),将 kind 设置为“IMPL”。要创建一个新库(名称为 ns)以注册新操作符,请将 kind 设置为“DEF”。要创建一个可能存在的库的片段以注册操作符(并绕过只有一个给定命名空间的库的限制),将 kind 设置为“FRAGMENT”。 参数 + **ns** – 库名称 + **kind** – “DEF”, “IMPL”(默认:“IMPL”), “FRAGMENT” + **dispatch_key** – PyTorch 调度键(默认:“”) ```py define(schema, alias_analysis='', *, tags=()) ``` 在 ns 命名空间中定义一个新操作符及其语义。 参数 + **schema** – 定义新操作符的函数模式。 + **alias_analysis** (*可选*) – 指示是否可以从模式推断操作符参数的别名属性(默认行为)或不可以(“CONSERVATIVE”)。 + **tags** (*Tag* *|* *Sequence***[*Tag**]*) – 一个或多个 torch.Tag,应用于此操作符。对操作符进行标记会更改操作符在各种 PyTorch 子系统下的行为;请在应用之前仔细阅读 torch.Tag 的文档。 返回 从模式推断的操作符名称。 示例:: ```py >>> my_lib = Library("foo", "DEF") >>> my_lib.define("sum(Tensor self) -> Tensor") ``` ```py impl(op_name, fn, dispatch_key='') ``` 为库中定义的操作符注册函数实现。 参数 + **op_name** – 操作符名称(连同重载)或 OpOverload 对象。 + **fn** – 作为输入调度键的操作符实现的函数或`fallthrough_kernel()`以注册一个 fallthrough。 + **dispatch_key** - 输入函数应注册的调度键。默认情况下,它使用创建库时使用的调度键。 示例: ```py >>> my_lib = Library("aten", "IMPL") >>> def div_cpu(self, other): >>> return self * (1 / other) >>> my_lib.impl("div.Tensor", div_cpu, "CPU") ``` ```py torch.library.fallthrough_kernel() ``` 一个虚拟函数,传递给`Library.impl`以注册一个默认情况。