# torch.fx > 原文:[`pytorch.org/docs/stable/fx.html`](https://pytorch.org/docs/stable/fx.html)> > > 译者:[飞龙](https://github.com/wizardforcel) > > 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) ##概述 FX 是开发人员用来转换`nn.Module`实例的工具包。FX 由三个主要组件组成:**符号跟踪器**,**中间表示**和**Python 代码生成**。这些组件的演示: ```py import torch # Simple module for demonstration class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return self.linear(x + self.param).clamp(min=0.0, max=1.0) module = MyModule() from torch.fx import symbolic_trace # Symbolic tracing frontend - captures the semantics of the module symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) # High-level intermediate representation (IR) - Graph representation print(symbolic_traced.graph) """ graph(): %x : [num_users=1] = placeholder[target=x] %param : [num_users=1] = get_attr[target=param] %add : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %linear : [num_users=1] = call_moduletarget=linear, kwargs = {}) %clamp : [num_users=1] = call_methodtarget=clamp, kwargs = {min: 0.0, max: 1.0}) return clamp """ # Code generation - valid Python code print(symbolic_traced.code) """ def forward(self, x): param = self.param add = x + param; x = param = None linear = self.linear(add); add = None clamp = linear.clamp(min = 0.0, max = 1.0); linear = None return clamp """ ``` **符号跟踪器**执行 Python 代码的“符号执行”。它通过代码传递称为代理的虚拟值。对这些代理的操作被记录下来。有关符号跟踪的更多信息,请参阅`symbolic_trace()`和`Tracer`文档。 **中间表示**是在符号跟踪期间记录的操作的容器。它包含一系列代表函数输入、调用点(到函数、方法或`torch.nn.Module`实例)和返回值的节点。有关 IR 的更多信息,请参阅`Graph`的文档。IR 是应用变换的格式。 **Python 代码生成**是使 FX 成为 Python 到 Python(或模块到模块)转换工具包的关键。对于每个 Graph IR,我们可以创建与 Graph 语义匹配的有效 Python 代码。这个功能被封装在`GraphModule`中,它是一个包含`Graph`以及从 Graph 生成的`forward`方法的`torch.nn.Module`实例。 综合起来,这些组件的流水线(符号跟踪->中间表示->变换->Python 代码生成)构成了 FX 的 Python 到 Python 转换流水线。此外,这些组件也可以单独使用。例如,符号跟踪可以单独用于捕获代码的一种形式以进行分析(而不是转换)目的。代码生成可用于通过配置文件程序生成模型。FX 有许多用途! 在[示例](https://github.com/pytorch/examples/tree/master/fx)存储库中可以找到几个示例变换。##编写变换 什么是 FX 变换?基本上,它是一个看起来像这样的函数。 ```py import torch import torch.fx def transform(m: nn.Module, tracer_class : type = torch.fx.Tracer) -> torch.nn.Module: # Step 1: Acquire a Graph representing the code in `m` # NOTE: torch.fx.symbolic_trace is a wrapper around a call to # fx.Tracer.trace and constructing a GraphModule. We'll # split that out in our transform to allow the caller to # customize tracing behavior. graph : torch.fx.Graph = tracer_class().trace(m) # Step 2: Modify this Graph or create a new one graph = ... # Step 3: Construct a Module to return return torch.fx.GraphModule(m, graph) ``` 您的变换将接受一个`torch.nn.Module`,从中获取一个`Graph`,进行一些修改,然后返回一个新的`torch.nn.Module`。您应该将您的 FX 变换返回的`torch.nn.Module`视为与常规`torch.nn.Module`相同-您可以将其传递给另一个 FX 变换,可以将其传递给 TorchScript,或者可以运行它。确保您的 FX 变换的输入和输出是`torch.nn.Module`将允许组合。 注意 也可以修改现有的`GraphModule`而不是创建一个新的,就像这样: ```py import torch import torch.fx def transform(m : nn.Module) -> nn.Module: gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m) # Modify gm.graph # <...> # Recompile the forward() method of `gm` from its Graph gm.recompile() return gm ``` 请注意,您必须调用`GraphModule.recompile()`来使生成的`forward()`方法与修改后的`Graph`同步。 给定您传入了一个已经被追踪到一个`torch.nn.Module`的`Graph`,现在有两种主要方法可以用来构建一个新的`Graph`。 ### 关于图形的快速入门 关于图形语义的完整处理可以在`Graph`文档中找到,但我们将在这里介绍基础知识。`Graph`是表示`GraphModule`上的方法的数据结构。这需要的信息是: + 方法的输入是什么? + 方法内部运行的操作是什么? + 方法的输出(即返回)值是什么? 这三个概念都用`Node`实例表示。让我们通过一个简短的示例来看看我们的意思: ```py import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk(torch.sum( self.linear(x + self.linear.weight).relu(), dim=-1), 3) m = MyModule() gm = torch.fx.symbolic_trace(m) gm.graph.print_tabular() ``` 在这里,我们为演示目的定义了一个模块`MyModule`,实例化它,对其进行符号追踪,然后调用`Graph.print_tabular()`方法来打印一个表格,显示这个`Graph`的节点: > | opcode | name | target | args | kwargs | > | --- | --- | --- | --- | --- | > | placeholder | x | x | () | {} | > | get_attr | linear_weight | linear.weight | () | {} | > | call_function | add_1 | | (x, linear_weight) | {} | > | call_module | linear_1 | linear | (add_1,) | {} | > | call_method | relu_1 | relu | (linear_1,) | {} | > | call_function | sum_1 | | (relu_1,) | {‘dim’: -1} | > | call_function | topk_1 | | (sum_1, 3) | {} | > | output | output | output | (topk_1,) | {} | 我们可以使用这些信息来回答我们上面提出的问题。 + 方法的输入是什么?在 FX 中,方法的输入是通过特殊的`placeholder`节点指定的。在这种情况下,我们有一个带有`target`为`x`的单个`placeholder`节点,这意味着我们有一个名为 x 的单个(非 self)参数。 + 方法内部的操作是什么?`get_attr`、`call_function`、`call_module`和`call_method`节点代表方法中的操作。关于所有这些操作的语义的完整处理可以在`Node`文档中找到。 + 方法的返回值是什么?在`Graph`中,返回值由一个特殊的`output`节点指定。 既然我们现在知道了 FX 中代码是如何表示的基础知识,我们现在可以探讨如何编辑一个`Graph`。 ### 图形操作 #### 直接图形操作 构建这个新的`Graph`的一种方法是直接操作您的旧`Graph`。为了帮助这一点,我们可以简单地获取从符号追踪中获得的`Graph`并对其进行修改。例如,假设我们希望用`torch.mul()`调用替换`torch.add()`调用。 ```py import torch import torch.fx # Sample module class M(torch.nn.Module): def forward(self, x, y): return torch.add(x, y) def transform(m: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: graph : fx.Graph = tracer_class().trace(m) # FX represents its Graph as an ordered list of # nodes, so we can iterate through them. for node in graph.nodes: # Checks if we're calling a function (i.e: # torch.add) if node.op == 'call_function': # The target attribute is the function # that call_function calls. if node.target == torch.add: node.target = torch.mul graph.lint() # Does some checks to make sure the # Graph is well-formed. return fx.GraphModule(m, graph) ``` 我们还可以进行更复杂的`Graph`重写,比如删除或追加节点。为了帮助这些转换,FX 有用于转换图形的实用函数,可以在`Graph`文档中找到。下面是使用这些 API 追加`torch.relu()`调用的示例。 ```py # Specifies the insertion point. Any nodes added to the # Graph within this scope will be inserted after `node` with traced.graph.inserting_after(node): # Insert a new `call_function` node calling `torch.relu` new_node = traced.graph.call_function( torch.relu, args=(node,)) # We want all places that used the value of `node` to # now use that value after the `relu` call we've added. # We use the `replace_all_uses_with` API to do this. node.replace_all_uses_with(new_node) ``` 对于只包含替换的简单转换,您也可以使用[subgraph rewriter.](https://github.com/pytorch/pytorch/blob/main/torch/fx/subgraph_rewriter.py) #### 使用 replace_pattern()进行子图重写 FX 还提供了另一种自动化级别,即在直接图操作之上。`replace_pattern()` API 本质上是一个用于编辑`Graph`的“查找/替换”工具。它允许您指定一个`pattern`和一个`replacement`函数,然后将跟踪这些函数,查找`pattern`图中操作组的实例,并用`replacement`图的副本替换这些实例。这可以帮助大大自动化繁琐的图操作代码,随着转换变得更复杂,这些代码可能变得难以控制。 #### 图操作示例 + [替换一个操作](https://github.com/pytorch/examples/blob/master/fx/replace_op.py) + [卷积/批量归一化融合](https://github.com/pytorch/pytorch/blob/40cbf342d3c000712da92cfafeaca651b3e0bd3e/torch/fx/experimental/optimization.py#L50) + [replace_pattern:基本用法](https://github.com/pytorch/examples/blob/master/fx/subgraph_rewriter_basic_use.py) + [量化](https://pytorch.org/docs/main/quantization.html#prototype-fx-graph-mode-quantization) + [反转转换](https://github.com/pytorch/examples/blob/master/fx/invert.py) ### 代理/重追踪 另一种操作`Graph`的方法是重用符号跟踪中使用的`Proxy`机制。例如,假设我们想要编写一个将 PyTorch 函数分解为较小操作的转换。它将每个`F.relu(x)`调用转换为`(x > 0) * x`。一种可能性是执行必要的图重写,将比较和乘法插入到`F.relu`之后,然后清理原始的`F.relu`。但是,我们可以通过使用`Proxy`对象自动记录操作并将其附加到`Graph`来自动化此过程。 要使用此方法,我们编写要插入的操作作为常规 PyTorch 代码,并使用`Proxy`对象作为参数调用该代码。这些`Proxy`对象将捕获对它们执行的操作,并将它们附加到`Graph`。 ```py # Note that this decomposition rule can be read as regular Python def relu_decomposition(x): return (x > 0) * x decomposition_rules = {} decomposition_rules[F.relu] = relu_decomposition def decompose(model: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: """ Decompose `model` into smaller constituent operations. Currently,this only supports decomposing ReLU into its mathematical definition: (x > 0) * x """ graph : fx.Graph = tracer_class().trace(model) new_graph = fx.Graph() env = {} tracer = torch.fx.proxy.GraphAppendingTracer(new_graph) for node in graph.nodes: if node.op == 'call_function' and node.target in decomposition_rules: # By wrapping the arguments with proxies, # we can dispatch to the appropriate # decomposition rule and implicitly add it # to the Graph by symbolically tracing it. proxy_args = [ fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args] output_proxy = decomposition_rulesnode.target # Operations on `Proxy` always yield new `Proxy`s, and the # return value of our decomposition rule is no exception. # We need to extract the underlying `Node` from the `Proxy` # to use it in subsequent iterations of this transform. new_node = output_proxy.node env[node.name] = new_node else: # Default case: we don't have a decomposition rule for this # node, so just copy the node over into the new graph. new_node = new_graph.node_copy(node, lambda x: env[x.name]) env[node.name] = new_node return fx.GraphModule(model, new_graph) ``` 除了避免显式图操作外,使用`Proxy`还允许您将重写规则指定为本机 Python 代码。对于需要大量重写规则的转换(如 vmap 或 grad),这通常可以提高规则的可读性和可维护性。请注意,当调用`Proxy`时,我们还传递了一个指向底层变量图的跟踪器。这样做是为了防止图中的操作是 n 元的情况(例如,add 是一个二元运算符),对`Proxy`的调用不会创建多个图跟踪器的实例,这可能导致意外的运行时错误。我们特别推荐在底层运算符不能安全地假定为一元运算符时使用`Proxy`的这种方法。 关于使用`Proxy`进行`Graph`操作的示例可以在[这里](https://github.com/pytorch/examples/blob/master/fx/proxy_based_graph_creation.py)找到。 ### 解释器模式 FX 中一个有用的代码组织模式是循环遍历`Graph`中的所有`Node`并执行它们。这可以用于多种目的,包括通过使用`Proxy`进行回溯来运行时分析流经图中的值或转换代码。例如,假设我们想要运行一个`GraphModule`并在运行时记录节点上的`torch.Tensor`形状和 dtype 属性。可能看起来像这样: ```py import torch import torch.fx from torch.fx.node import Node from typing import Dict class ShapeProp: """ Shape propagation. This class takes a `GraphModule`. Then, its `propagate` method executes the `GraphModule` node-by-node with the given arguments. As each operation executes, the ShapeProp class stores away the shape and element type for the output values of each operation on the `shape` and `dtype` attributes of the operation's `Node`. """ def __init__(self, mod): self.mod = mod self.graph = mod.graph self.modules = dict(self.mod.named_modules()) def propagate(self, *args): args_iter = iter(args) env : Dict[str, Node] = {} def load_arg(a): return torch.fx.graph.map_arg(a, lambda n: env[n.name]) def fetch_attr(target : str): target_atoms = target.split('.') attr_itr = self.mod for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}") attr_itr = getattr(attr_itr, atom) return attr_itr for node in self.graph.nodes: if node.op == 'placeholder': result = next(args_iter) elif node.op == 'get_attr': result = fetch_attr(node.target) elif node.op == 'call_function': result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) elif node.op == 'call_method': self_obj, *args = load_arg(node.args) kwargs = load_arg(node.kwargs) result = getattr(self_obj, node.target)(*args, **kwargs) elif node.op == 'call_module': result = self.modulesnode.target, **load_arg(node.kwargs)) # This is the only code specific to shape propagation. # you can delete this `if` branch and this becomes # a generic GraphModule interpreter. if isinstance(result, torch.Tensor): node.shape = result.shape node.dtype = result.dtype env[node.name] = result return load_arg(self.graph.result) ``` 正如您所看到的,FX 的完整解释器并不复杂,但却非常有用。为了简化使用这种模式,我们提供了`Interpreter`类,它以某种方式包含了上述逻辑,使得解释器执行的某些方面可以通过方法重写来覆盖。 除了执行操作,我们还可以通过将`Proxy`值通过解释器来生成一个新的 Graph。类似地,我们提供了`Transformer`类来包含这种模式。`Transformer`的行为类似于`Interpreter`,但是不是调用`run`方法从模块中获取具体的输出值,而是调用`Transformer.transform()`方法返回一个经过您安装的任何转换规则的新的`GraphModule`。 #### 解释器模式的示例 + [形状传播](https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/shape_prop.py) + [性能分析器](https://github.com/pytorch/tutorials/pull/1319) ## 调试 ### 介绍 在编写转换时,通常我们的代码不会完全正确。在这种情况下,我们可能需要进行一些调试。关键是要向后工作:首先,检查调用生成的模块的结果以证明或证伪正确性。然后,检查和调试生成的代码。然后,调试导致生成的代码的转换过程。 如果您对调试器不熟悉,请参阅辅助部分可用调试器。 ### 转换作者常见的陷阱 + 非确定性的`set`迭代顺序。在 Python 中,`set`数据类型是无序的。例如,使用`set`来包含像`Node`这样的对象集合可能会导致意外的非确定性。一个例子是遍历一组`Node`并将它们插入到`Graph`中。因为`set`数据类型是无序的,输出程序中操作的顺序将是非确定性的,并且可能会在程序调用时发生变化。推荐的替代方案是使用`dict`数据类型,它在 Python 3.7(以及 cPython 3.6)中是[插入有序的](https://mail.python.org/pipermail/python-dev/2017-December/151283.html)。可以通过将要去重的值存储在`dict`的键中来等效地使用`dict`来代替`set`。 ### 检查模块正确性 由于大多数深度学习模块的输出由浮点`torch.Tensor`实例组成,检查两个`torch.nn.Module`的结果是否等价并不像进行简单的相等性检查那样直接。为了激励这一点,让我们举个例子: ```py import torch import torch.fx import torchvision.models as models def transform(m : torch.nn.Module) -> torch.nn.Module: gm = torch.fx.symbolic_trace(m) # Imagine we're doing some transforms here # <...> gm.recompile() return gm resnet18 = models.resnet18() transformed_resnet18 = transform(resnet18) input_image = torch.randn(5, 3, 224, 224) assert resnet18(input_image) == transformed_resnet18(input_image) """ RuntimeError: Boolean value of Tensor with more than one value is ambiguous """ ``` 在这里,我们尝试使用`==`等号操作符来检查两个深度学习模型的值是否相等。然而,这并不是很明确,因为该操作符返回一个张量而不是布尔值,而且由于浮点值的比较应该使用误差边界(或 epsilon)来考虑浮点运算的非交换性(有关更多详细信息,请参见[这里](https://floating-point-gui.de/errors/comparison/))。我们可以使用`torch.allclose()`来进行近似比较,考虑相对和绝对容差阈值: ```py assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image)) ``` 这是我们工具箱中的第一个工具,用来检查转换后的模块是否与参考实现的行为一致。 ### 调试生成的代码 因为 FX 在`GraphModule`上生成`forward()`函数,所以使用传统的调试技术如`print`语句或`pdb`并不那么直接。幸运的是,我们有几种技术可以用来调试生成的代码。 #### 使用`pdb` 调用`pdb`来步入运行中的程序。尽管代表`Graph`的代码不在任何源文件中,但当调用前向传递时,我们仍然可以使用`pdb`手动步入其中。 ```py import torch import torch.fx import torchvision.models as models def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: graph = tracer_class().trace(inp) # Transformation logic here # <...> # Return new Module return fx.GraphModule(inp, graph) my_module = models.resnet18() my_module_transformed = my_pass(my_module) input_value = torch.randn(5, 3, 224, 224) # When this line is executed at runtime, we will be dropped into an # interactive `pdb` prompt. We can use the `step` or `s` command to # step into the execution of the next line import pdb; pdb.set_trace() my_module_transformed(input_value) ``` 打印生成的代码 如果您想多次运行相同的代码,那么使用`pdb`逐步到正确的代码可能有点繁琐。在这种情况下,一个方法是简单地将生成的`forward`传递复制粘贴到您的代码中,并从那里检查它。 ```py # Assume that `traced` is a GraphModule that has undergone some # number of transforms # Copy this code for later print(traced) # Print the code generated from symbolic tracing. This outputs: """ def forward(self, y): x = self.x add_1 = x + y; x = y = None return add_1 """ # Subclass the original Module class SubclassM(M): def __init__(self): super().__init__() # Paste the generated `forward` function (the one we printed and # copied above) here def forward(self, y): x = self.x add_1 = x + y; x = y = None return add_1 # Create an instance of the original, untraced Module. Then, create an # instance of the Module with the copied `forward` function. We can # now compare the output of both the original and the traced version. pre_trace = M() post_trace = SubclassM() ``` #### 使用`GraphModule`中的`to_folder`函数 `GraphModule.to_folder()`是`GraphModule`中的一个方法,允许您将生成的 FX 代码转储到一个文件夹中。尽管将前向传递复制到代码中通常足够,如打印生成的代码,但使用`to_folder`来检查模块和参数可能更容易。 ```py m = symbolic_trace(M()) m.to_folder("foo", "Bar") from foo import Bar y = Bar() ``` 在运行上面的示例之后,我们可以查看`foo/module.py`中的代码,并根据需要进行修改(例如添加`print`语句或使用`pdb`)来调试生成的代码。 ### 调试转换 既然我们已经确定一个转换正在创建不正确的代码,现在是时候调试转换本身了。首先,我们将在文档中检查符号跟踪的限制部分。一旦验证了跟踪工作正常,目标就是弄清楚在我们的`GraphModule`转换过程中出了什么问题。在编写转换中可能会有一个快速答案,但如果没有,有几种方法可以检查我们跟踪的模块: ```py # Sample Module class M(torch.nn.Module): def forward(self, x, y): return x + y # Create an instance of `M` m = M() # Symbolically trace an instance of `M` (returns a GraphModule). In # this example, we'll only be discussing how to inspect a # GraphModule, so we aren't showing any sample transforms for the # sake of brevity. traced = symbolic_trace(m) # Print the code produced by tracing the module. print(traced) # The generated `forward` function is: """ def forward(self, x, y): add = x + y; x = y = None return add """ # Print the internal Graph. print(traced.graph) # This print-out returns: """ graph(): %x : [num_users=1] = placeholder[target=x] %y : [num_users=1] = placeholder[target=y] %add : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) return add """ # Print a tabular representation of the internal Graph. traced.graph.print_tabular() # This gives us: """ opcode name target args kwargs ------------- ------ ----------------------- ------ -------- placeholder x x () {} placeholder y y () {} call_function add (x, y) {} output output output (add,) {} """ ``` 使用上面的实用函数,我们可以比较在应用转换之前和之后的跟踪模块。有时,简单的视觉比较就足以追踪错误。如果仍然不清楚出了什么问题,调试器如`pdb`可能是一个不错的下一步。 根据上面的示例,考虑以下代码: ```py # Sample user-defined function def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module: # Get the Graph from our traced Module g = tracer_class().trace(module) """ Transformations on `g` go here """ return fx.GraphModule(module, g) # Transform the Graph transformed = transform_graph(traced) # Print the new code after our transforms. Check to see if it was # what we expected print(transformed) ``` 使用上面的示例,假设调用`print(traced)`显示我们的转换中存在错误。我们想要使用调试器找出问题所在。我们启动一个`pdb`会话。我们可以通过在`transform_graph(traced)`上设置断点,然后按`s`键“步入”调用`transform_graph(traced)`来查看转换过程中发生了什么。 我们也可以通过编辑`print_tabular`方法来打印图中节点的不同属性来获得好运气。(例如,我们可能想要查看节点的`input_nodes`和`users`。) ### 可用的调试器 最常见的 Python 调试器是[pdb](https://docs.python.org/3/library/pdb.html)。您可以通过在命令行中键入`python -m pdb FILENAME.py`来以`pdb`的“调试模式”启动程序,其中`FILENAME`是您要调试的文件的名称。之后,您可以使用`pdb`[调试器命令](https://docs.python.org/3/library/pdb.html#debugger-commands)逐步浏览正在运行的程序。当您启动`pdb`时设置断点(`b LINE-NUMBER`)是很常见的,然后调用`c`运行程序直到那一点。这样可以避免您必须逐行执行(使用`s`或`n`)以到达您想要检查的代码部分。或者,您可以在要中断的行之前写入`import pdb; pdb.set_trace()`。如果添加了`pdb.set_trace()`,当您运行程序时,它将自动以调试模式启动。 (换句话说,您只需在命令行中键入`python FILENAME.py`,而不是`python -m pdb FILENAME.py`。)运行文件时,您可以通过使用特定命令逐步执行代码并检查程序的内部状态。有许多关于`pdb`的优秀教程在线,包括 RealPython 的[“使用 Pdb 进行 Python 调试”](https://realpython.com/python-debugging-pdb/)。 像 PyCharm 或 VSCode 这样的 IDE 通常内置了调试器。在您的 IDE 中,您可以选择要么 a)通过在 IDE 中打开终端窗口(例如在 VSCode 中选择 View → Terminal)使用`pdb`,要么 b)使用内置调试器(通常是围绕`pdb`的图形包装器)。## 符号跟踪的限制 FX 使用一种**符号跟踪**系统(也称为[符号执行](https://en.wikipedia.org/wiki/Symbolic_execution))来以可转换/可分析的形式捕获程序的语义。该系统是**跟踪**的,因为它执行程序(实际上是一个`torch.nn.Module`或函数)以记录操作。它是**符号**的,因为在执行过程中流经程序的数据不是真实数据,而是符号(FX 术语中的`Proxy`)。 尽管符号跟踪适用于大多数神经网络代码,但它也有一些限制。 ### 动态控制流 符号跟踪的主要限制是它目前不支持*动态控制流*。也就是说,循环或`if`语句的条件可能取决于程序的输入值。 例如,让我们来看下面的程序: ```py def func_to_trace(x): if x.sum() > 0: return torch.relu(x) else: return torch.neg(x) traced = torch.fx.symbolic_trace(func_to_trace) """ <...> File "dyn.py", line 6, in func_to_trace if x.sum() > 0: File "pytorch/torch/fx/proxy.py", line 155, in __bool__ return self.tracer.to_bool(self) File "pytorch/torch/fx/proxy.py", line 85, in to_bool raise TraceError('symbolically traced variables cannot be used as inputs to control flow') torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow """ ``` `if`语句的条件取决于`x.sum()`的值,而`x`的值取决于函数输入。由于`x`可能会改变(即如果您向被追踪的函数传递一个新的输入张量),这就是*动态控制流*。回溯会沿着您的代码向上走,以显示这种情况发生的位置。 #### 静态控制流 另一方面,所谓的*静态控制流*是受支持的。静态控制流是循环或`if`语句,其值在调用之间不会改变。通常,在 PyTorch 程序中,这种控制流是基于超参数对模型架构做出决策的代码。举个具体的例子: ```py import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self, do_activation : bool = False): super().__init__() self.do_activation = do_activation self.linear = torch.nn.Linear(512, 512) def forward(self, x): x = self.linear(x) # This if-statement is so-called static control flow. # Its condition does not depend on any input values if self.do_activation: x = torch.relu(x) return x without_activation = MyModule(do_activation=False) with_activation = MyModule(do_activation=True) traced_without_activation = torch.fx.symbolic_trace(without_activation) print(traced_without_activation.code) """ def forward(self, x): linear_1 = self.linear(x); x = None return linear_1 """ traced_with_activation = torch.fx.symbolic_trace(with_activation) print(traced_with_activation.code) """ import torch def forward(self, x): linear_1 = self.linear(x); x = None relu_1 = torch.relu(linear_1); linear_1 = None return relu_1 """ ``` `if self.do_activation`语句不依赖于任何函数输入,因此它是静态的。`do_activation`可以被视为超参数,具有不同参数值的`MyModule`实例的跟踪代码是不同的。这是一种受符号跟踪支持的有效模式。 许多动态控制流的实例在语义上是静态控制流。通过消除对输入值的数据依赖关系,例如将值移动到`Module`属性或在符号跟踪期间将具体值绑定到参数,可以使这些实例支持符号跟踪: ```py def f(x, flag): if flag: return x else: return x*2 fx.symbolic_trace(f) # Fails! fx.symbolic_trace(f, concrete_args={'flag': True}) ``` 在真正动态控制流的情况下,包含此代码的程序部分可以被追踪为对方法的调用(参见使用 Tracer 类自定义追踪)或函数的调用(参见`wrap()`)而不是通过追踪它们。 ### 非`torch`函数 FX 使用`__torch_function__`作为拦截调用的机制(有关更多信息,请参阅[技术概述](https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#technical-details))。一些函数,如内置的 Python 函数或`math`模块中的函数,不受`__torch_function__`覆盖,但我们仍希望在符号追踪中捕获它们。例如: ```py import torch import torch.fx from math import sqrt def normalize(x): """ Normalize `x` by the size of the batch dimension """ return x / sqrt(len(x)) # It's valid Python code normalize(torch.rand(3, 4)) traced = torch.fx.symbolic_trace(normalize) """ <...> File "sqrt.py", line 9, in normalize return x / sqrt(len(x)) File "pytorch/torch/fx/proxy.py", line 161, in __len__ raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope """ ``` 错误告诉我们内置函数`len`不受支持。我们可以通过使用`wrap()` API 使这样的函数在追踪中记录为直接调用: ```py torch.fx.wrap('len') torch.fx.wrap('sqrt') traced = torch.fx.symbolic_trace(normalize) print(traced.code) """ import math def forward(self, x): len_1 = len(x) sqrt_1 = math.sqrt(len_1); len_1 = None truediv = x / sqrt_1; x = sqrt_1 = None return truediv """ ``` ### 使用`Tracer`类自定义追踪 `Tracer`类是`symbolic_trace`实现的基础类。通过对 Tracer 进行子类化,可以自定义追踪的行为,如下所示: ```py class MyCustomTracer(torch.fx.Tracer): # Inside here you can override various methods # to customize tracing. See the `Tracer` API # reference pass # Let's use this custom tracer to trace through this module class MyModule(torch.nn.Module): def forward(self, x): return torch.relu(x) + torch.ones(3, 4) mod = MyModule() traced_graph = MyCustomTracer().trace(mod) # trace() returns a Graph. Let's wrap it up in a # GraphModule to make it runnable traced = torch.fx.GraphModule(mod, traced_graph) ``` #### 叶子模块 叶子模块是在符号追踪中显示为调用而不是被追踪的模块。默认的叶子模块集是标准`torch.nn`模块实例的集合。例如: ```py class MySpecialSubmodule(torch.nn.Module): def forward(self, x): return torch.neg(x) class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(3, 4) self.submod = MySpecialSubmodule() def forward(self, x): return self.submod(self.linear(x)) traced = torch.fx.symbolic_trace(MyModule()) print(traced.code) # `linear` is preserved as a call, yet `submod` is traced though. # This is because the default set of "Leaf Modules" includes all # standard `torch.nn` modules. """ import torch def forward(self, x): linear_1 = self.linear(x); x = None neg_1 = torch.neg(linear_1); linear_1 = None return neg_1 """ ``` 可以通过覆盖`Tracer.is_leaf_module()`来自定义叶子模块的集合。 ### 杂项 + 张量构造函数(例如`torch.zeros`,`torch.ones`,`torch.rand`,`torch.randn`,`torch.sparse_coo_tensor`)目前不可追踪。 + 确定性构造函数(`zeros`,`ones`)可以使用,并且它们产生的值将嵌入到追踪中作为常量。只有当这些构造函数的参数引用动态输入大小时才会出现问题。在这种情况下,`ones_like`或`zeros_like`可能是一个可行的替代品。 + 非确定性构造函数(`rand`,`randn`)将在追踪中嵌入一个随机值。这可能不是预期的行为。一个解决方法是将`torch.randn`包装在`torch.fx.wrap`函数中,并调用该函数。 > ```py > @torch.fx.wrap > def torch_randn(x, shape): > return torch.randn(shape) > > def f(x): > return x + torch_randn(x, 5) > fx.symbolic_trace(f) > ``` + 这种行为可能在未来的版本中被修复。 + 类型注解 + 支持 Python 3 风格的类型注解(例如`func(x : torch.Tensor, y : int) -> torch.Tensor`),并且将被符号追踪保留。 + Python 2 风格的注释类型注解`# type: (torch.Tensor, int) -> torch.Tensor`目前不受支持。 + 目前不支持函数内部局部名称的注释。 + 关于`training`标志和子模块的注意事项 + 当使用像`torch.nn.functional.dropout`这样的函数时,通常会将训练参数作为`self.training`传入。在 FX 追踪期间,这可能会被固定为一个常量值。 > ```py > import torch > import torch.fx > > class DropoutRepro(torch.nn.Module): > def forward(self, x): > return torch.nn.functional.dropout(x, training=self.training) > > traced = torch.fx.symbolic_trace(DropoutRepro()) > print(traced.code) > """ > def forward(self, x): > dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None > return dropout > """ > > traced.eval() > > x = torch.randn(5, 3) > torch.testing.assert_close(traced(x), x) > """ > AssertionError: Tensor-likes are not close! > > Mismatched elements: 15 / 15 (100.0%) > Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) > Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) > """ > ``` + 然而,当使用标准的`nn.Dropout()`子模块时,训练标志被封装,并且由于`nn.Module`对象模型的保留,可以被更改。 > ```py > class DropoutRepro2(torch.nn.Module): > def __init__(self): > super().__init__() > self.drop = torch.nn.Dropout() > > def forward(self, x): > return self.drop(x) > > traced = torch.fx.symbolic_trace(DropoutRepro2()) > print(traced.code) > """ > def forward(self, x): > drop = self.drop(x); x = None > return drop > """ > > traced.eval() > > x = torch.randn(5, 3) > torch.testing.assert_close(traced(x), x) > ``` > > + 由于这种差异,请考虑将与`training`标志动态交互的模块标记为叶子模块。 ## API 参考 ```py torch.fx.symbolic_trace(root, concrete_args=None) ``` 符号追踪 API 给定一个`nn.Module`或函数实例`root`,此函数将返回一个通过记录在`root`中追踪到的操作而构建的`GraphModule`。 `concrete_args`允许您部分特化您的函数,无论是为了删除控制流还是数据结构。 例如: ```py def f(a, b): if b == True: return a else: return a*2 ``` 由于存在控制流,FX 通常无法追踪到这一点。但是,我们可以使用 concrete_args 来专门针对 b 的值进行追踪: ```py f = fx.symbolic_trace(f, concrete_args={'b': False}) assert f(3, False) == 6 ``` 请注意,尽管您仍然可以传入不同的 b 值,但它们将被忽略。 我们还可以使用 concrete_args 来消除函数中的数据结构处理。这将使用 pytrees 来展平您的输入。为了避免过度特化,传递 fx.PH 以表示不应特化的值。例如: ```py def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) assert f({'a': 1, 'b': 2, 'c': 4}) == 7 ``` 参数 + **root**(*Union***[*torch.nn.Module**,* *Callable**]*) - 要跟踪并转换为图形表示的模块或函数。 + **concrete_args**(*可选**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *any**]**]*) - 要部分特化的输入 返回 从`root`记录的操作创建的模块。 返回类型 GraphModule 注意 此 API 的向后兼容性已得到保证。 ```py torch.fx.wrap(fn_or_name) ``` 这个函数可以在模块级别范围内调用,将 fn_or_name 注册为“叶函数”。 “叶函数”将作为 CallFunction 节点保留在 FX 跟踪中,而不是被跟踪通过: ```py # foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap('my_custom_function') def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y) ``` 这个函数也可以等效地用作装饰器: ```py # foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y ``` 包装函数可以被视为“叶函数”,类似于“叶模块”的概念,即它们是在 FX 跟踪中作为调用保留的函数,而不是被跟踪。 参数 **fn_or_name**(*Union**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *Callable**]*) - 要在调用时插入到图中的全局函数或名称 注意 此 API 的向后兼容性已得到保证。 ```py class torch.fx.GraphModule(*args, **kwargs) ``` GraphModule 是从 fx.Graph 生成的 nn.Module。Graphmodule 具有`graph`属性,以及从该`graph`生成的`code`和`forward`属性。 警告 当重新分配`graph`时,`code`和`forward`将自动生成。但是,如果您编辑`graph`的内容而不重新分配`graph`属性本身,则必须调用`recompile()`以更新生成的代码。 注意 此 API 的向后兼容性已得到保证。 ```py __init__(root, graph, class_name='GraphModule') ``` 构建一个 GraphModule。 参数 + **root**(*Union***[*torch.nn.Module**,* *Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *Any**]*) - `root`可以是 nn.Module 实例或将字符串映射到任何属性类型的字典。如果`root`是一个 Module,那么 Graph 的 Nodes 的`target`字段中对 Module-based 对象(通过限定名称)的任何引用将从`root`的 Module 层次结构中的相应位置复制到 GraphModule 的模块层次结构中。如果`root`是一个字典,那么在 Node 的`target`中找到的限定名称将直接在字典的键中查找。字典映射到的对象将被复制到 GraphModule 的模块层次结构中的适当位置。 + **graph**(*Graph*) - `graph`包含此 GraphModule 应用于代码生成的节点 + **class_name**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")) - `name`表示此 GraphModule 的名称,用于调试目的。如果未设置,所有错误消息将报告为源自`GraphModule`。将此设置为`root`的原始名称或在转换上下文中有意义的名称可能有助于。 注意 此 API 的向后兼容性已得到保证。 ```py add_submodule(target, m) ``` 将给定的子模块添加到`self`。 如果它们是`target`的子路径,将安装空的 Modules。 参数 + **target**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")) - 新子模块的完全限定字符串名称(请参阅`nn.Module.get_submodule`中的示例,了解如何指定完全限定字符串。) + **m**(*Module*) - 子模块本身;我们要安装在当前 Module 中的实际对象 返回 子模块是否可以插入。对于 为了使此方法返回 True,`target`表示的链中的每个对象必须是 a)尚不存在,或 b)引用`nn.Module`(而不是参数或其他属性) 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)") 注意 此 API 的向后兼容性是有保证的。 ```py property code: str ``` 返回从支持此`GraphModule`的`Graph`生成的 Python 代码。 ```py delete_all_unused_submodules() ``` 从`self`中删除所有未使用的子模块。 如果满足以下任一条件,则将模块视为“已使用”:1. 其子模块已被使用 2. 其 forward 直接通过`call_module`节点调用 3. 它具有一个从`get_attr`节点使用的非 Module 属性 可以调用此方法来清理`nn.Module`,而无需手动在每个未使用的子模块上调用`delete_submodule`。 注意 此 API 的向后兼容性是有保证的。 ```py delete_submodule(target) ``` 从`self`中删除给定的子模块。 如果`target`不是有效目标,则不会删除模块。 参数 **target**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")) – 新子模块的完全限定字符串名称(请参阅`nn.Module.get_submodule`中的示例,了解如何指定完全限定字符串。) 返回 无论目标字符串是否引用了 我们要删除的子模块。返回值为`False`表示`target`不是对子模块的有效引用。 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)") 注意 此 API 的向后兼容性是有保证的。 ```py property graph: Graph ``` 返回此`GraphModule`的基础`Graph` ```py print_readable(print_output=True) ``` 返回为当前`GraphModule`及其子`GraphModule`生成的 Python 代码 警告 此 API 是实验性的,*不*向后兼容。 ```py recompile() ``` 重新编译此`GraphModule`从其`graph`属性。在编辑包含的`graph`后应调用此方法,否则此`GraphModule`的生成代码将过时。 注意 此 API 的向后兼容性是有保证的。 返回类型 *PythonCode* ```py to_folder(folder, module_name='FxModule') ``` 将模块转储到带有`module_name`的`folder`中,以便可以 通过`from import `导入 参数: > folder(Union[str, os.PathLike]):要将代码写入的文件夹 > > module_name(str):用于`Module`的顶级名称 > > 写出代码 警告 此 API 是实验性的,*不*向后兼容。 ```py class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None) ``` `Graph`是 FX 中间表示中使用的主要数据结构。它由一系列`Node`组成,每个`Node`代表调用点(或其他语法结构)。一起取出的`Node`列表构成一个有效的 Python 函数。 例如,以下代码 ```py import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) m = MyModule() gm = torch.fx.symbolic_trace(m) ``` 将生成以下图表: ```py print(gm.graph) ``` ```py graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_functiontarget=operator.add, kwargs = {}) %linear_1 : [num_users=1] = call_moduletarget=linear, kwargs = {}) %relu_1 : [num_users=1] = call_methodtarget=relu, kwargs = {}) %sum_1 : [num_users=1] = call_functiontarget=torch.sum, kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_functiontarget=torch.topk, kwargs = {}) return topk_1 ``` 有关在`Graph`中表示的操作的语义,请参见`Node`。 注意 此 API 的向后兼容性是有保证的。 ```py __init__(owning_module=None, tracer_cls=None, tracer_extras=None) ``` 构建一个空的图表。 注意 此 API 的向后兼容性是有保证的。 ```py call_function(the_function, args=None, kwargs=None, type_expr=None) ``` 将`call_function` `Node`插入`Graph`。`call_function`节点表示对由`the_function`指定的 Python 可调用对象的调用。 参数 + **the_function**(*Callable**[**...**,* *Any**]*) – 要调用的函数。可以是任何 PyTorch 运算符、Python 函数或`builtins`或`operator`命名空间的成员。 + **args**(*可选**[**Tuple**[**Argument**,* *...**]**]*) – 要传递给调用函数的位置参数。 + **kwargs**(*可选**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")*,* *Argument**]**]*) – 要传递给调用函数的关键字参数 + **type_expr**(*可选**[**Any**]*) – 一个可选的类型注释,表示此节点的输出将具有的 Python 类型。 返回 新创建并插入的`call_function`节点。 返回类型 *Node* 注意 此方法的插入点和类型表达式规则与`Graph.create_node()`相同。 注意 此 API 的向后兼容性得到保证。 ```py call_method(method_name, args=None, kwargs=None, type_expr=None) ``` 将`call_method` `Node`插入`Graph`中。`call_method`节点表示在`args`的第 0 个元素上调用给定方法。 参数 + **method_name** ([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) – 要应用于 self 参数的方法的名称。例如,如果 args[0]是表示`Tensor`的`Node`,则要在该`Tensor`上调用`relu()`,请将`relu`传递给`method_name`。 + **args** (*Optional**[**Tuple**[**Argument**,* *...**]**]*) – 要传递给调用方法的位置参数。请注意,这*应该*包括一个`self`参数。 + **kwargs** (*Optional**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Argument**]**]*) – 要传递给调用方法的关键字参数 + **type_expr** (*Optional**[**Any**]*) – 一个可选的类型注释,表示此节点的输出将具有的 Python 类型。 返回 新创建并插入的`call_method`节点。 返回类型 *Node* 注意 此方法的插入点和类型表达式规则与`Graph.create_node()`相同。 注意 此 API 的向后兼容性得到保证。 ```py call_module(module_name, args=None, kwargs=None, type_expr=None) ``` 将`call_module` `Node`插入`Graph`中。`call_module`节点表示在`Module`层次结构中的`Module`的 forward()函数的调用。 参数 + **module_name** ([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) – 要调用的`Module`层次结构中的`Module`的限定名称。例如,如果被跟踪的`Module`有一个名为`foo`的子模块,该子模块有一个名为`bar`的子模块,则应将限定名称`foo.bar`作为`module_name`传递以调用该模块。 + **args** (*Optional**[**Tuple**[**Argument**,* *...**]**]*) – 要传递给调用方法的位置参数。请注意,这不应该包括一个`self`参数。 + **kwargs** (*Optional**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Argument**]**]*) – 要传递给调用方法的关键字参数 + **type_expr** (*Optional**[**Any**]*) – 一个可选的类型注释,表示此节点的输出将具有的 Python 类型。 返回 新创建并插入的`call_module`节点。 返回类型 *Node* 注意 此方法的插入点和类型表达式规则与`Graph.create_node()`相同。 注意 此 API 的向后兼容性得到保证。 ```py create_node(op, target, args=None, kwargs=None, name=None, type_expr=None) ``` 创建一个`Node`并将其添加到当前插入点的`Graph`。请注意,当前插入点可以通过`Graph.inserting_before()`和`Graph.inserting_after()`进行设置。 参数 + **op** ([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) – 此节点的操作码。其中之一是‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’或‘output’。这些操作码的语义在`Graph`文档中有描述。 + **args** (*Optional**[**Tuple**[**Argument**,* *...**]**]*) – 这是传递给此节点的参数元组。 + **kwargs** (*Optional**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Argument**]**]*) – 此节点的关键字参数 + **name** (*Optional**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*]*) – `Node`的可选字符串名称。这将影响在 Python 生成的代码中分配给值的名称。 + **type_expr**(*Optional**[**Any**]*) - 表示此节点输出的 Python 类型的可选类型注释。 返回 新创建并插入的节点。 返回类型 *Node* 注意 此 API 的向后兼容性已得到保证。 ```py eliminate_dead_code() ``` 从图中删除所有死代码,基于每个节点的用户数量以及节点是否具有任何副作用。在调用之前,图必须进行拓扑排序。 返回 图是否因此传递而更改。 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)") 示例: 在消除死代码之前,下面的 a = x + 1 没有用户,因此可以从图中消除而不产生影响。 ```py def forward(self, x): a = x + 1 return x + self.attr_1 ``` 消除死代码后,a = x + 1 已被移除,前进的其余部分保留。 ```py def forward(self, x): return x + self.attr_1 ``` 警告 死代码消除具有一些启发式方法,以避免删除具有副作用的节点(请参见 Node.is_impure),但总体覆盖率非常差,因此您应该假设除非您知道您的 FX 图完全由功能操作组成,否则不应调用此方法。 注意 此 API 的向后兼容性已得到保证。 ```py erase_node(to_erase) ``` 从`Graph`中擦除`Node`。如果在`Graph`中仍然有该节点的用户,则会引发异常。 参数 **to_erase**(*Node*) - 要从`Graph`中删除的`Node`。 注意 此 API 的向后兼容性已得到保证。 ```py get_attr(qualified_name, type_expr=None) ``` 将`get_attr`节点插入图中。`get_attr` `Node`表示从`Module`层次结构中获取属性。 参数 + **qualified_name**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) - 要检索的属性的完全限定名称。例如,如果被跟踪的模块有一个名为`foo`的子模块,该子模块有一个名为`bar`的子模块,该子模块有一个名为`baz`的属性,则应将限定名称`foo.bar.baz`作为`qualified_name`传递。 + **type_expr**(*Optional**[**Any**]*) - 表示此节点输出的 Python 类型的可选类型注释。 返回 新创建并插入的`get_attr`节点。 返回类型 *Node* 注意 与`Graph.create_node`方法相同的插入点和类型表达式规则适用于此方法。 注意 此 API 的向后兼容性已得到保证。 ```py graph_copy(g, val_map, return_output_node=False) ``` 将给定图中的所有节点复制到`self`中。 参数 + **g**(*Graph*) - 要将节点从中复制到`self`的源图。 + **val_map**(*Dict***[*Node**,* *Node**]*) - 一个将从`g`中的节点映射到`self`中的节点的映射的字典。请注意,`val_map`可以传入已经具有值的值,以覆盖某些值的复制。 返回 现在在`self`中的值等同于`g`中的输出值,如果`g`有一个`output`节点。否则为`None`。 返回类型 [*Optional*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12 中)")[[*Union*](https://docs.python.org/3/library/typing.html#typing.Union "(在 Python v3.12 中)")[[*Tuple*](https://docs.python.org/3/library/typing.html#typing.Tuple "(在 Python v3.12 中)")[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)"), …], [*List*](https://docs.python.org/3/library/typing.html#typing.List "(在 Python v3.12 中)")[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)")], [*Dict*](https://docs.python.org/3/library/typing.html#typing.Dict "(在 Python v3.12 中)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)"), [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)")], [slice](https://docs.python.org/3/library/functions.html#slice "(在 Python v3.12 中)"), [range](https://docs.python.org/3/library/stdtypes.html#range "(在 Python v3.12 中)"), Node, [str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)"), [int](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)"), [float](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12 中)"), [bool](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)"), [complex](https://docs.python.org/3/library/functions.html#complex "(在 Python v3.12 中)"), *dtype*, *Tensor*, *device*, *memory_format*, *layout*, *OpOverload*]] 注意 此 API 的向后兼容性已得到保证。 ```py inserting_after(n=None) ``` 设置 create_node 和 companion 方法将插入到图中的位置。 在‘with’语句中使用时,这将临时设置插入点,然后在退出 with 语句时恢复它: ```py with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently ``` 参数: > n(可选[节点]):要插入的节点之前。如果为 None,则将在其之后插入 > > 整个图的开始。 返回: 将在`__exit__`上恢复插入点的资源管理器。 注意 此 API 的向后兼容性已得到保证。 ```py inserting_before(n=None) ``` 设置 create_node 和 companion 方法将插入到图中的位置。 在‘with’语句中使用时,这将临时设置插入点,然后在退出 with 语句时恢复它: ```py with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently ``` 参数: > n(可选[节点]):要插入的节点之前。如果为 None,则将在其之前插入 > > 整个图的开始。 返回: 将在`__exit__`上恢复插入点的资源管理器。 注意 此 API 的向后兼容性已得到保证。 ```py lint() ``` 对此图运行各种检查,以确保其形式良好。特别是:-检查节点具有正确的所有权(由此图拥有)-检查节点按拓扑顺序出现-如果此图具有拥有的 GraphModule,则检查目标是否存在于该 GraphModule 中 注意 此 API 的向后兼容性已得到保证。 ```py node_copy(node, arg_transform=>) ``` 将一个图中的节点复制到另一个图中。`arg_transform`需要将节点图中的参数转换为 self 图中的参数。例如: ```py # Copying all the nodes in `g` into `new_graph` g : torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) ``` 参数 + **node**(*Node*)–要复制到`self`中的节点。 + **arg_transform**(*Callable****[*[*Node**]**,* *Argument**]*) – 一个函数,将节点的`args`和`kwargs`中的`Node`参数转换为`self`中的等效参数。在最简单的情况下,这应该从将原始图中的节点映射到`self`的表中检索值。 返回类型 *Node* 注意 此 API 的向后兼容性已得到保证。 ```py property nodes: _node_list ``` 获取构成此图的节点列表。 请注意,此`Node`列表表示是一个双向链表。在迭代期间进行突变(例如删除节点,添加节点)是安全的。 返回 节点的双向链表。请注意,可以在此列表上调用`reversed`以切换迭代顺序。 ```py on_generate_code(make_transformer) ``` 在生成 Python 代码时注册一个转换器函数 > 参数: > > make_transformer(Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): > > 返回一个要注册的代码转换器的函数。此函数由 on_generate_code 调用以获取代码转换器。 > > 此函数还将当前注册的代码转换器(如果没有注册任何内容,则为 None)作为其输入,以防不希望覆盖它。这对于链接代码转换器很有用。 > > 返回: > > 一个上下文管理器,当在 with 语句中使用时,自动恢复先前注册的代码转换器。 > > 例子: > > ```py > gm: fx.GraphModule = ... > > # This is a code transformer we want to register. This code > # transformer prepends a pdb import and trace statement at the very > # beginning of the generated torch.fx code to allow for manual > # debugging with the PDB library. > def insert_pdb(body): > return ["import pdb; pdb.set_trace()\n", *body] > > # Registers `insert_pdb`, and overwrites the current registered > # code transformer (given by `_` to the lambda): > gm.graph.on_generate_code( > lambda _: insert_pdb > ) > > # Or alternatively, registers a code transformer which first > # runs `body` through existing registered transformer, then > # through `insert_pdb`: > gm.graph.on_generate_code( > lambda current_trans: ( > lambda body: insert_pdb( > current_trans(body) if current_trans > else body > ) > ) > ) > > gm.recompile() > gm(*inputs) # drops into pdb > ``` > > 此函数也可以作为上下文管理器使用,有助于自动恢复先前注册的代码转换器: > > ```py > # ... continue from previous example > > with gm.graph.on_generate_code(lambda _: insert_pdb): > # do more stuff with `gm`... > gm.recompile() > gm(*inputs) # drops into pdb > > # now previous code transformer is restored (but `gm`'s code with pdb > # remains - that means you can run `gm` with pdb here too, until you > # run next `recompile()`). > ``` 警告 此 API 是实验性的,*不*向后兼容。 ```py output(result, type_expr=None) ``` 将`output` `Node`插入`Graph`。`output`节点表示 Python 代码中的`return`语句。`result`是应该返回的值。 参数 + **result**(*Argument*) - 要返回的值。 + **type_expr**(*Optional**[**Any**]*) - 表示此节点输出的 Python 类型的可选类型注释。 注意 与`Graph.create_node`方法相同的插入点和类型表达式规则适用于此方法。 注意 此 API 的向后兼容性得到保证。 ```py placeholder(name, type_expr=None, default_value) ``` 将`placeholder`节点插入 Graph。`placeholder`表示函数输入。 参数 + **name**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")) - 输入值的名称。这对应于此`Graph`表示的函数的位置参数的名称。 + **type_expr**(*Optional**[**Any**]*) - 表示此节点输出的 Python 类型的可选类型注释。在某些情况下,这是必要的,以便进行正确的代码生成(例如,当函数随后在 TorchScript 编译中使用时)。 + **default_value**(*Any*) - 此函数参数应该采用的默认值。注意:为了允许 None 作为默认值,应将 inspect.Signature.empty 作为此参数传递,以指定该参数没有默认值。 返回类型 *Node* 注意 与`Graph.create_node`方法相同的插入点和类型表达式规则适用于此方法。 注意 此 API 的向后兼容性得到保证。 ```py print_tabular() ``` 以表格形式打印图的中间表示。请注意,此 API 需要安装`tabulate`模块。 注意 此 API 的向后兼容性得到保证。 ```py process_inputs(*args) ``` 处理参数,以便它们可以传递给 FX 图。 警告 此 API 是实验性的,*不*向后兼容。 ```py process_outputs(out) ``` 警告 此 API 是实验性的,*不*向后兼容。 ```py python_code(root_module, *, verbose=False) ``` 将此`Graph`转换为有效的 Python 代码。 参数 **root_module**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12 中)")) - 要查找限定名称目标的根模块的名称。通常为“self”。 返回 src:表示对象全局名称的 Python 源代码 globals:src 中的全局名称的字典->它们引用的对象。 返回类型 一个 PythonCode 对象,由两个字段组成 注意 此 API 的向后兼容性得到保证。 ```py set_codegen(codegen) ``` 警告 此 API 是实验性的,*不*向后兼容。 ```py class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None) ``` `Node`是表示`Graph`中各个操作的数据结构。在大多数情况下,节点表示对各种实体的调用点,例如运算符、方法和模块(一些例外包括指定函数输入和输出的节点)。每个`Node`都有一个由其`op`属性指定的函数。对于`op`的每个值,`Node`的语义如下: + `placeholder` 代表一个函数输入。`name` 属性指定此值将采用的名称。`target` 同样是参数的名称。`args` 包含:1) 什么都没有,或者 2) 一个表示函数输入的默认参数的单个参数。`kwargs` 不重要。占位符对应于图形打印输出中的函数参数(例如 `x`)。 + `get_attr` 从模块层次结构中检索参数。`name` 同样是结果的名称。`target` 是参数在模块层次结构中的完全限定名称。`args` 和 `kwargs` 不重要。 + `call_function` 将一个自由函数应用到一些值上。`name` 同样是要分配的值的名称。`target` 是要应用的函数。`args` 和 `kwargs` 代表函数的参数,遵循 Python 的调用约定。 + `call_module` 将模块在模块层次结构的 `forward()` 方法中应用到给定的参数上。`name` 与之前相同。`target` 是要调用的模块在模块层次结构中的完全限定名称。`args` 和 `kwargs` 代表要在模块上调用的参数,*不包括 self 参数*。 + `call_method` 在一个值上调用一个方法。`name` 与之前相似。`target` 是要应用于 `self` 参数的方法的字符串名称。`args` 和 `kwargs` 代表要在模块上调用的参数,*包括 self 参数*。 + `output` 包含跟踪函数的输出,在其 `args[0]` 属性中。这对应于图形打印输出中的“return”语句。 注意 此 API 的向后兼容性是有保证的。 ```py property all_input_nodes: List[Node] ``` 返回是此节点的输入的所有节点。这相当于迭代 `args` 和 `kwargs` 并仅收集是节点的值。 返回 在这个 `Node` 的 `args` 和 `kwargs` 中出现的 `Nodes` 列表,按顺序排列。 ```py append(x) ``` 在此节点后插入 `x` 到图中的节点列表中。等同于 `self.next.prepend(x)` 参数 **x** (*Node*) – 要在此节点后放置的节点。必须是同一图中的成员。 注意 此 API 的向后兼容性是有保证的。 ```py property args: Tuple[Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload]], ...] ``` 这个 `Node` 的参数元组。参数的解释取决于节点的操作码。查看 `Node` 文档字符串以获取更多信息。 允许对此属性进行赋值。在赋值时,所有使用和用户的计数都会自动更新。 ```py format_node(placeholder_names=None, maybe_return_typename=None) ``` 返回 `self` 的描述性字符串表示。 此方法可以作为调试工具而无需参数使用。 此函数还在 `Graph` 的 `__str__` 方法中内部使用。`placeholder_names` 和 `maybe_return_typename` 中的字符串一起构成了此 Graph 的周围 GraphModule 中自动生成的 `forward` 函数的签名。否则不应使用 `placeholder_names` 和 `maybe_return_typename`。 参数 + **placeholder_names** ([*可选*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12)")*[*[*List*](https://docs.python.org/3/library/typing.html#typing.List "(在 Python v3.12)")*[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")*]**]*) – 一个列表,将存储表示生成的 `forward` 函数中占位符的格式化字符串。仅供内部使用。 + **maybe_return_typename** ([*可选*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12)")*[*[*List*](https://docs.python.org/3/library/typing.html#typing.List "(在 Python v3.12)")*[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")*]**]*) – 一个单元素列表,将存储表示生成的 `forward` 函数输出的格式化字符串。仅供内部使用。 返回 如果 1) 我们在内部使用 `format_node` 作为辅助函数 在`Graph`的`__str__`方法中,以及 2)`self`是占位符节点,则返回`None`。否则,返回当前 Node 的描述性字符串表示。 返回类型 [str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)") 注意 此 API 的向后兼容性得到保证。 ```py insert_arg(idx, arg) ``` 在给定索引处向参数列表中插入一个位置参数。 参数 + **idx**([*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12)"))-要在`self.args`中插入的元素的索引。 + **arg**(*参数*)-要插入`args`的新参数值 注意 此 API 的向后兼容性得到保证。 ```py is_impure() ``` 返回此操作是否不纯净,即其操作是占位符或输出,或者是不纯净的 call_function 或 call_module。 返回 如果操作是不纯净的。 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12)") 警告 此 API 是实验性的,*不*向后兼容。 ```py property kwargs: Dict[str, Optional[Union[Tuple[Any, ...], List[Any], Dict[str, Any], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload]]] ``` 关键字参数的字典到此`Node`。参数的解释取决于节点的操作码。有关更多信息,请参阅`Node`文档字符串。 允许对此属性进行赋值。所有使用和用户的计算都会在分配时自动更新。 ```py property next: Node ``` 返回链接列表中的下一个`Node`。 返回 链接列表中的下一个`Node`。 ```py normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False) ``` 将参数标准化为 Python 目标的参数。这意味着 args/kwargs 将与模块/功能的签名匹配,并且如果 normalize_to_only_use_kwargs 为 true,则仅以位置顺序返回 kwargs。还填充默认值。不支持仅位置参数或可变参数。 支持模块调用。 可能需要 arg_types 和 kwarg_types 以消除重载。 参数 + **root**(*torch.nn.Module*)-要解析模块目标的模块。 + **arg_types**(*可选**[**Tuple**[**Any**]**]*)-args 的参数类型元组 + **kwarg_types**(*可选**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)")*,* *Any**]**]*)-kwargs 的参数类型字典 + **normalize_to_only_use_kwargs**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12)"))-是否规范化为仅使用 kwargs。 返回 返回命名元组 ArgsKwargsPair,如果不成功则返回 None。 返回类型 [*可选*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12)")[*ArgsKwargsPair*] 警告 此 API 是实验性的,*不*向后兼容。 ```py prepend(x) ``` 在图中的节点列表中的此节点之前插入 x。示例: ```py Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax ``` 参数 **x**(*Node*)-要在此节点之前放置的节点。必须是同一图的成员。 注意 此 API 的向后兼容性得到保证。 ```py property prev: Node ``` 返回链接列表中的上一个`Node`。 返回 链接列表中的上一个`Node`。 ```py replace_all_uses_with(replace_with, delete_user_cb=>, *, propagate_meta=False) ``` 用 Node`replace_with`替换图中所有`self`的使用。 参数 + **replace_with**(*Node*)-要用`self`替换所有使用的节点。 + **delete_user_cb**(*可调用*)-用于确定是否应删除 self 节点的给定用户的回调。 + **propagate_meta**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12)"))-是否复制原始节点的.meta 字段上的所有属性到替换节点。出于安全考虑,只有在替换节点尚未具有现有.meta 字段时才能这样做。 返回 进行此更改的节点列表。 返回类型 [*列表*](https://docs.python.org/3/library/typing.html#typing.List "(在 Python v3.12)")[*Node*] 注意 此 API 的向后兼容性得到保证。 ```py replace_input_with(old_input, new_input) ``` 循环遍历`self`的输入节点,并用`new_input`替换所有`old_input`的实例。 参数 + **old_input**(*Node*)-要替换的旧输入节点。 + **new_input**(*Node*)-要替换`old_input`的新输入节点。 注意 此 API 的向后兼容性已得到保证。 ```py property stack_trace: Optional[str] ``` 返回在跟踪期间记录的 Python 堆栈跟踪,如果有的话。使用 fx.Tracer 跟踪时,此属性通常由 Tracer.create_proxy 填充。为了记录跟踪期间的堆栈跟踪以进行调试,可以在 Tracer 实例上设置 record_stack_traces = True。使用 dynamo 跟踪时,此属性将默认由 OutputGraph.create_proxy 填充。 stack_trace 将在字符串末尾具有最内部的帧。 ```py update_arg(idx, arg) ``` 更新现有的位置参数以包含新值`arg`。调用后,`self.args[idx] == arg`。 参数 + **idx**([*int*](https://docs.python.org/3/library/functions.html#int "(in Python v3.12)"))-要更新的元素在`self.args`中的索引 + **arg**(*参数*)-要写入`args`的新参数值 注意 此 API 的向后兼容性已得到保证。 ```py update_kwarg(key, arg) ``` 更新现有的关键字参数以包含新值`arg`。调用后,`self.kwargs[key] == arg`。 参数 + **key**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"))-要更新的元素在`self.kwargs`中的键 + **arg**(*参数*)-要写入`kwargs`的新参数值 注意 此 API 的向后兼容性已得到保证。 ```py class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=()) ``` > `Tracer`是实现`torch.fx.symbolic_trace`的符号跟踪功能的类。对`symbolic_trace(m)`的调用等效于`Tracer().trace(m)`。 > > 跟踪器可以被子类化以覆盖跟踪过程的各种行为。可以覆盖的不同行为在此类的方法的文档字符串中描述。 注意 此 API 的向后兼容性已得到保证。 ```py call_module(m, forward, args, kwargs) ``` 指定此`Tracer`在遇到对`nn.Module`实例的调用时的行为的方法。 默认情况下,行为是检查所调用的模块是否是叶模块,通过`is_leaf_module`。如果是,则在`Graph`中引用`m`发出`call_module`节点。否则,正常调用`Module`,跟踪其`forward`函数中的操作。 可以重写此方法,例如创建嵌套的跟踪 GraphModules,或者在跨`Module`边界跟踪时希望的任何其他行为。 参数 + **m**(*Module*)-正在发出调用的模块 + **forward**(*可调用*)-要调用的`Module`的 forward()方法 + **args**(*元组*)-模块调用站点的 args + **kwargs**(*Dict*)-模块调用站点的 kwargs 返回 模块调用的返回值。在发出`call_module`节点的情况下,这是一个`Proxy`值。否则,它是从`Module`调用返回的任何值。 返回类型 [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)") 注意 此 API 的向后兼容性已得到保证。 ```py create_arg(a) ``` 一种方法,用于指定在准备将值用作图中节点的参数时跟踪的行为。 默认情况下,行为包括: 1. 遍历集合类型(例如元组、列表、字典)并在元素上递归调用`create_args`。 1. 给定一个代理对象,返回对底层 IR`Node`的引用 1. 给定一个非代理张量对象,发出各种情况的 IR: > + 对于参数,发出引用该参数的`get_attr`节点 > + > + 对于非参数张量,将张量存储在一个特殊属性中,该属性指向该属性。 可以重写此方法以支持更多类型。 参数 **a**(*任意*)-要在`Graph`中作为`Argument`发出的值。 返回 将值`a`转换为适当的`Argument` 返回类型 [*Optional*](https://docs.python.org/3/library/typing.html#typing.Optional "(in Python v3.12)")[[*Union*](https://docs.python.org/3/library/typing.html#typing.Union "(in Python v3.12)")[[*Tuple*](https://docs.python.org/3/library/typing.html#typing.Tuple "(in Python v3.12)")[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)"), …], [*List*](https://docs.python.org/3/library/typing.html#typing.List "(in Python v3.12)")[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)")], [*Dict*](https://docs.python.org/3/library/typing.html#typing.Dict "(in Python v3.12)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)")], [slice](https://docs.python.org/3/library/functions.html#slice "(in Python v3.12)"), [range](https://docs.python.org/3/library/stdtypes.html#range "(in Python v3.12)"), Node, [str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)"), [int](https://docs.python.org/3/library/functions.html#int "(in Python v3.12)"), [float](https://docs.python.org/3/library/functions.html#float "(in Python v3.12)"), [bool](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)"), [complex](https://docs.python.org/3/library/functions.html#complex "(in Python v3.12)"), *dtype*, *Tensor*, *device*, *memory_format*, *layout*, *OpOverload*]] 注意 此 API 的向后兼容性已经得到保证。 ```py create_args_for_root(root_fn, is_module, concrete_args=None) ``` 创建与`root`模块的签名对应的`placeholder`节点。此方法审查 root 的签名并相应地发出这些节点,还支持`*args`和`**kwargs`。 警告 此 API 是实验性的,*不*向后兼容。 ```py create_node(kind, target, args, kwargs, name=None, type_expr=None) ``` 插入一个给定目标、args、kwargs 和名称的图节点。 可以重写此方法以执行额外的检查、验证或修改用于节点创建的值。例如,可能希望禁止记录原地操作。 注意 此 API 的向后兼容性已经得到保证。 返回类型 *Node* ```py create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) ``` 从给定的参数创建一个节点,然后返回包装在代理对象中的节点。 如果 kind = `placeholder`,那么我们正在创建一个代表函数参数的节点。如果我们需要编码默认参数,我们使用`args`元组。对于`placeholder`节点,`args`否则为空。 注意 此 API 的向后兼容性已经得到保证。 ```py getattr(attr, attr_val, parameter_proxy_cache) ``` 指定当我们在调用`nn.Module`实例的`getattr`时,此`Tracer`的行为的方法。 默认情况下,行为是返回属性的代理值。它还将代理值存储在`parameter_proxy_cache`中,以便将来的调用将重用代理而不是创建新的代理。 可以重写此方法,例如,在查询参数时不返回代理。 参数 + **attr** ([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) – 正在查询的属性的名称 + **attr_val** (*Any*) – 属性的值 + **parameter_proxy_cache** (*Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *Any**]*) – 一个属性名称到代理的缓存 返回 从 getattr 调用的返回值。 警告 此 API 是实验性的,*不*向后兼容。 ```py is_leaf_module(m, module_qualified_name) ``` 指定给定的`nn.Module`是否是“叶子”模块的方法。 叶模块是出现在 IR 中的原子单位,由`call_module`调用引用。默认情况下,PyTorch 标准库命名空间(torch.nn)中的模块是叶模块。除非通过此参数另有规定,否则将跟踪所有其他模块,并记录其组成操作。 参数 + **m**(*Module*) - 被查询的模块 + **module_qualified_name**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) - 此模块的根路径。例如,如果您有一个模块层次结构,其中子模块`foo`包含子模块`bar`,后者包含子模块`baz`,那么该模块将以限定名称`foo.bar.baz`显示在此处。 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)") 注意 此 API 的向后兼容性得到保证。 ```py iter(obj) ``` 当代理对象被迭代时调用,例如 在控制流中使用时。通常我们不知道该做什么,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 附加更多信息到图节点,并选择返回一个迭代器。 注意 此 API 的向后兼容性得到保证。 返回类型 [*Iterator*](https://docs.python.org/3/library/typing.html#typing.Iterator "(in Python v3.12)") ```py keys(obj) ``` 当代理对象调用 keys()方法时调用。 这是在代理上调用**时发生的情况。这应该返回一个迭代器,**应该在您的自定义跟踪器中起作用。 注意 此 API 的向后兼容性得到保证。 返回类型 [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(in Python v3.12)") ```py path_of_module(mod) ``` 在`root`的模块层次结构中查找`mod`的限定名称的辅助方法。例如,如果`root`有一个名为`foo`的子模块,其中有一个名为`bar`的子模块,将`bar`传递给此函数将返回字符串“foo.bar”。 参数 **mod**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")) - 要检索限定名称的`Module`。 返回类型 [str](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)") 注意 此 API 的向后兼容性得到保证。 ```py proxy(node) ``` 注意 此 API 的向后兼容性得到保证。 返回类型 *Proxy* ```py to_bool(obj) ``` 当代理对象被转换为布尔值时调用,例如 在控制流中使用时。通常我们不知道该做什么,因为我们不知道代理的值,但是自定义跟踪器可以使用 create_node 附加更多信息到图节点,并选择返回一个值。 注意 此 API 的向后兼容性得到保证。 返回类型 [bool](https://docs.python.org/3/library/functions.html#bool "(in Python v3.12)") ```py trace(root, concrete_args=None) ``` 跟踪`root`并返回相应的 FX`Graph`表示。`root`可以是`nn.Module`实例或 Python 可调用对象。 请注意,在此调用之后,`self.root`可能与此处传入的`root`不同。例如,当将自由函数传递给`trace()`时,我们将创建一个`nn.Module`实例用作根,并添加嵌入常量。 参数 + **root**(*Union***[*Module**,* *Callable**]*) - 要跟踪的`Module`或函数。此参数的向后兼容性得到保证。 + **concrete_args**(*可选**[**Dict**[*[*str*](https://docs.python.org/3/library/stdtypes.html#str "(in Python v3.12)")*,* *any**]**]*) - 不应视为代理的具体参数。此参数是实验性的,不保证向后兼容。 返回 表示传入`root`语义的`Graph`。 返回类型 *Graph* 注意 此 API 的向后兼容性得到保证。 ```py class torch.fx.Proxy(node, tracer=None) ``` `Proxy`对象是`Node`包装器,它们在符号跟踪期间流经程序,并记录它们触及的所有操作(`torch`函数调用,方法调用,运算符)到不断增长的 FX 图中。 如果您正在进行图形转换,可以将自己的`Proxy`方法包装在原始`Node`周围,以便您可以使用重载的运算符向`Graph`添加其他内容。 `Proxy`对象无法迭代。换句话说,如果在循环中使用`Proxy`或作为`*args`/`**kwargs`函数参数,则符号跟踪器将抛出错误。 有两种主要方法可以解决这个问题:1. 将无法跟踪的逻辑提取到顶层函数中,并在其上使用`fx.wrap`。2. 如果控制流是静态的(即循环次数基于某些超参数),则可以将代码保留在原始位置并重构为类似以下内容: ```py for i in range(self.some_hyperparameter): indexed_item = proxied_value[i] ``` 有关 Proxy 内部更详细的描述,请查看 torch/fx/OVERVIEW.md 中的“Proxy”部分 注意 此 API 的向后兼容性是有保证的。 ```py class torch.fx.Interpreter(module, garbage_collect_values=True) ``` Interpreter 逐节点执行 FX 图。这种模式对于许多事情都很有用,包括编写代码转换以及分析传递。 Interpreter 类中的方法可以被重写以自定义执行行为。在调用层次结构方面,可以重写的方法映射如下: ```py run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output() ``` 示例 假设我们想要交换所有`torch.neg`的实例与`torch.sigmoid`以及反之(包括它们的`Tensor`方法等效)。我们可以像这样子类化 Interpreter: ```py class NegSigmSwapInterpreter(Interpreter): def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid()) ``` 参数 + **module**(*GraphModule*)- 要执行的模块 + **garbage_collect_values**([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)"))- 是否在模块执行后删除值。这可以确保执行期间的最佳内存使用。可以禁用此功能,例如,通过查看`Interpreter.env`属性来检查执行中的所有中间值。 注意 此 API 的向后兼容性是有保证的。 ```py boxed_run(args_list) ``` 通过解释运行模块并返回结果。这使用“封装”调用约定,您传递一个参数列表,该列表将被解释器清除。这确保输入张量会被及时释放。 注意 此 API 的向后兼容性是有保证的。 ```py call_function(target, args, kwargs) ``` 执行`call_function`节点并返回结果。 参数 + **target**(*Target*)- 此节点的调用目标。有关语义的详细信息,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args**(*Tuple*)- 该调用的位置参数元组 + **kwargs**(*Dict*)- 该调用的关键字参数字典 返回类型 [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)") 返回 任何:函数调用返回的值 注意 此 API 的向后兼容性是有保证的。 ```py call_method(target, args, kwargs) ``` 执行`call_method`节点并返回结果。 参数 + **target**(*Target*)- 此节点的调用目标。有关语义的详细信息,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args**(*Tuple*)- 该调用的位置参数元组 + **kwargs**(*Dict*)- 该调用的关键字参数字典 返回类型 [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)") 返回 任何:方法调用返回的值 注意 此 API 的向后兼容性是有保证的。 ```py call_module(target, args, kwargs) ``` 执行`call_module`节点并返回结果。 参数 + **target**(*Target*)- 此节点的调用目标。有关语义的详细信息,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args**(*Tuple*)- 该调用的位置参数元组 + **kwargs**(*Dict*)- 该调用的关键字参数字典 返回类型 [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)") 返回 任何:模块调用返回的值 注意 此 API 的向后兼容性已得到保证。 ```py fetch_args_kwargs_from_env(n) ``` 从当前执行环境中获取节点`n`的`args`和`kwargs`的具体值。 参数 **n**(*Node*)- 应该获取`args`和`kwargs`的节点。 返回 具体值为`n`的`args`和`kwargs`。 返回类型 Tuple[Tuple, Dict] 注意 此 API 的向后兼容性已得到保证。 ```py fetch_attr(target) ``` 从`self.module`的`Module`层次结构中提取属性。 参数 **target**([*str*](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)"))- 要获取的属性的完全限定名称 返回 属性的值。 返回类型 任何 注意 此 API 的向后兼容性已得到保证。 ```py get_attr(target, args, kwargs) ``` 执行一个`get_attr`节点。将从`self.module`的`Module`层次结构中检索属性值。 参数 + **target**(*Target*)- 此节点的调用目标。有关语义的详细信息,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args**(*Tuple*)- 用于此调用的位置参数元组 + **kwargs**(*Dict*)- 用于此调用的关键字参数字典 返回 检索到的属性值 返回类型 任何 注意 此 API 的向后兼容性已得到保证。 ```py map_nodes_to_values(args, n) ``` 递归地遍历`args`并查找当前执行环境中每个`Node`的具体值。 参数 + **args**(*Argument*)- 用于查找具体值的数据结构 + **n**(*Node*)- `args`所属的节点。仅用于错误报告。 返回类型 [*Optional*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12)")[[*Union*](https://docs.python.org/3/library/typing.html#typing.Union "(在 Python v3.12)")[[*Tuple*](https://docs.python.org/3/library/typing.html#typing.Tuple "(在 Python v3.12)")[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12)"), …], [*List*](https://docs.python.org/3/library/typing.html#typing.List "(在 Python v3.12)")[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12)")], [*Dict*](https://docs.python.org/3/library/typing.html#typing.Dict "(在 Python v3.12)")[[str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)"), [*Any*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12)")], [slice](https://docs.python.org/3/library/functions.html#slice "(在 Python v3.12)"), [range](https://docs.python.org/3/library/stdtypes.html#range "(在 Python v3.12)"), *Node*, [str](https://docs.python.org/3/library/stdtypes.html#str "(在 Python v3.12)"), [int](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12)"), [float](https://docs.python.org/3/library/functions.html#float "(在 Python v3.12)"), [bool](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12)"), [complex](https://docs.python.org/3/library/functions.html#complex "(在 Python v3.12)"), *dtype*, *Tensor*, *device*, *memory_format*, *layout*, *OpOverload*]] 注意 此 API 的向后兼容性已得到保证。 ```py output(target, args, kwargs) ``` 执行一个`output`节点。这只是检索由`output`节点引用的值并返回它。 参数 + **target**(*Target*)- 此节点的调用目标。有关语义的详细信息,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args**(*Tuple*)- 用于此调用的位置参数元组 + **kwargs**(*Dict*)- 用于此调用的关键字参数字典 返回 由输出节点引用的返回值 返回类型 任何 注意 此 API 的向后兼容性已得到保证。 ```py placeholder(target, args, kwargs) ``` 执行`placeholder`节点。请注意,这是有状态的:`Interpreter`在传递给`run`的参数上维护一个内部迭代器,此方法在该迭代器上返回 next()。 参数 + **target** (*目标*) – 此节点的调用目标。有关语义详情,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args** (*元组*) – 该调用的位置参数元组 + **kwargs** (*字典*) – 该调用的关键字参数字典 返回 检索到的参数值。 返回类型 任何 注意 此 API 的向后兼容性已得到保证。 ```py run(*args, initial_env=None, enable_io_processing=True) ``` 通过解释运行模块并返回结果。 参数 + ***args** – 要运行的模块的参数,按位置顺序排列 + **initial_env** (*可选****字典**[*[*Node**,* *任何**]***) – 执行的可选起始环境。这是一个将节点映射到任何值的字典。例如,可以用来预先填充某些节点的结果,以便在解释器中仅进行部分评估。 + **enable_io_processing** ([*布尔*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")) – 如果为真,则在使用它们之前,我们首先使用图的 process_inputs 和 process_outputs 函数处理输入和输出。 返回 执行模块后返回的值 返回类型 任何 注意 此 API 的向后兼容性已得到保证。 ```py run_node(n) ``` 运行特定节点`n`并返回结果。根据`node.op`调用占位符、get_attr、call_function、call_method、call_module 或输出 参数 **n** (*Node*) – 要执行的节点 返回 执行`n`的结果 返回类型 任何 注意 此 API 的向后兼容性已得到保证。 ```py class torch.fx.Transformer(module) ``` `Transformer`是一种特殊类型的解释器,它生成一个新的`Module`。它公开了一个返回转换后`Module`的`transform()`方法。`Transformer`不需要参数来运行,而`Interpreter`需要。`Transformer`完全以符号方式工作。 示例 假设我们想要将所有`torch.neg`的实例与`torch.sigmoid`互换,反之亦然(包括它们的`Tensor`方法等效)。我们可以像这样子类化`Transformer`: ```py class NegSigmSwapXformer(Transformer): def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: if target == 'neg': call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) ``` 参数 **module** (*GraphModule*) – 要转换的`Module`。 注意 此 API 的向后兼容性已得到保证。 ```py call_function(target, args, kwargs) ``` 注意 此 API 的向后兼容性已得到保证。 返回类型 [*任何*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)") ```py call_module(target, args, kwargs) ``` 注意 此 API 的向后兼容性已得到保证。 返回类型 [*任何*](https://docs.python.org/3/library/typing.html#typing.Any "(在 Python v3.12 中)") ```py get_attr(target, args, kwargs) ``` 执行`get_attr`节点。在`Transformer`中,这被重写为将一个新的`get_attr`节点插入输出图中。 参数 + **target** (*目标*) – 此节点的调用目标。有关语义详情,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args** (*元组*) – 该调用的位置参数元组 + **kwargs** (*字典*) – 该调用的关键字参数字典 返回类型 *代理* 注意 此 API 的向后兼容性已得到保证。 ```py placeholder(target, args, kwargs) ``` 执行`placeholder`节点。在`Transformer`中,这被重写为将一个新的`placeholder`插入输出图中。 参数 + **target** (*目标*) – 此节点的调用目标。有关语义详情,请参见[Node](https://pytorch.org/docs/master/fx.html#torch.fx.Node) + **args** (*元组*) – 该调用的位置参数元组 + **kwargs** (*字典*) – 该调用的关键字参数字典 返回类型 *代理* 注意 此 API 的向后兼容性已得到保证。 ```py transform() ``` 转换`self.module`并返回转换后的`GraphModule`。 注意 此 API 的向后兼容性已得到保证。 返回类型 *GraphModule* ```py torch.fx.replace_pattern(gm, pattern, replacement) ``` 在 GraphModule(`gm`)的图中匹配所有可能的非重叠运算符及其数据依赖关系(`pattern`),然后用另一个子图(`replacement`)替换这些匹配的子图。 参数 + **gm** (*GraphModule*) – 包装要操作的图的 GraphModule + **pattern** ([*Union*](https://docs.python.org/3/library/typing.html#typing.Union "(在 Python v3.12 中)")*[*[*Callable*](https://docs.python.org/3/library/typing.html#typing.Callable "(在 Python v3.12 中)")*,* *GraphModule**]*) – 要在`gm`中匹配以进行替换的子图 + **replacement** ([*Union*](https://docs.python.org/3/library/typing.html#typing.Union "(在 Python v3.12 中)")*[*[*Callable*](https://docs.python.org/3/library/typing.html#typing.Callable "(在 Python v3.12 中)")*,* *GraphModule**]*) – 用于替换`pattern`的子图 返回 一个表示`pattern`在原始图中匹配的位置的`Match`对象列表。如果没有匹配项,则列表为空。`Match`被定义为: ```py class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node] ``` 返回类型 List[Match] 示例: ```py import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]).sum() def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) ``` 上面的代码将首先在`traced_module`的`forward`方法中匹配`pattern`。模式匹配是基于使用-def 关系而不是节点名称的。例如,如果在`pattern`中有`p = torch.cat([a, b])`,则可以在原始`forward`函数中匹配`m = torch.cat([a, b])`,尽管变量名称不同(`p`与`m`)。 `pattern`中的`return`语句仅基于其值匹配;它可能匹配也可能不匹配较大图中的`return`语句。换句话说,模式不必延伸到较大图的末尾。 当匹配模式时,它将从较大函数中移除,并被`replacement`替换。如果在较大函数中有多个`pattern`的匹配项,每个不重叠的匹配项将被替换。在匹配重叠的情况下,将替换在重叠匹配集中找到的第一个匹配项。(这里的“第一个”被定义为节点使用-def 关系的拓扑排序中的第一个。在大多数情况下,第一个节点是直接在`self`之后出现的参数,而最后一个节点是函数返回的内容。) 一个重要的事情要注意的是`pattern` Callable 的参数必须在 Callable 本身中使用,而`replacement` Callable 的参数必须与模式匹配。第一个规则是为什么在上面的代码块中,`forward`函数具有参数`x, w1, w2`,但`pattern`函数只有参数`w1, w2`。`pattern`不使用`x`,所以不应该将`x`指定为参数。作为第二条规则的示例,考虑替换 ```py def pattern(x, y): return torch.neg(x) + torch.relu(y) ``` 与 ```py def replacement(x, y): return torch.relu(x) ``` 在这种情况下,`replacement`需要与`pattern`(`x`和`y`)具有相同数量的参数,即使`replacement`中没有使用参数`y`。 调用`subgraph_rewriter.replace_pattern`后,生成的 Python 代码如下: ```py def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2 ``` 注意 此 API 的向后兼容性是有保证的。