Below you will find pages that utilize the taxonomy term “Pytorch”
📎Posts
TorchInductor Pattern Matcher
PyTorch FX 图#
PyTorch FX 是用于捕获、分析和转换 PyTorch 计算图。FX 图是一种静态表示,它记录了 PyTorch 代码的执行流程。用户通过将模型表示为FX图,可以更轻松地进行各种转换,例如图优化,量化,算子融合等。
FX 图的核心组件包括:
torch.fx.Graph:计算图的容器torch.fx.Node:图中的节点,表示计算操作,如函数调用、方法调用等torch.fx.GraphModule:由图构建的可执行模块
graph TD
subgraph FX_Graph
A["Placeholder Node"] --> B["CallFunction Node"]
B --> C["CallMethod Node"]
C --> D["Output Node"]
E["Module Node"] --> B
end
subgraph Components
F["torch.fx.Graph"] --> FX_Graph
G["torch.fx.Node"] --> A
G --> B
G --> C
G --> D
H["torch.fx.GraphModule"] --> F
end
style FX_Graph stroke:#333,stroke-width:2px
style Components stroke:#333,stroke-width:2px
FX Symbolic Tracing#
FX 图的生成过程称为"符号追踪"(Symbolic Tracing),主要步骤包括:
- 追踪:使用
torch.fx.symbolic_trace()对 PyTorch 函数或模块进行追踪 - 捕获:捕获函数执行过程中的所有操作,构建计算图
- 表示:将计算图表示为
Graph对象,其中包含一系列Node对象 - 转换:对捕获的图进行分析和转换
- 执行:将转换后的图包装为
GraphModule,可像普通 PyTorch 模块一样执行
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
# 由一个列表组成 代表函数输入、调用点(函数、方法、 或 torch.nn.Module 实例),以及返回值。
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
# 使 FX 成为 Python 到 Python(或 模块到模块)转换工具包。对于每个 Graph IR,我们可以 创建与图语义匹配的有效 Python 代码。
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
FX 图的特点