← 返回深度学习目录
← 返回学习笔记首页
图神经网络(GNN)
深度学习专题 · 在图结构数据上的深度学习
专题: 深度学习系统学习
关键词: 深度学习, GNN, 图神经网络, GCN, GAT, 消息传递, PyTorch Geometric, 节点分类, 图分类
一、图神经网络概述
图神经网络(Graph Neural Network, GNN)是一类专门处理图结构数据的深度学习模型。与传统的神经网络不同,GNN能够直接操作非欧几里得空间中的数据,捕捉节点之间复杂的拓扑关系和依赖信息。近年来,GNN在社交网络分析、分子性质预测、知识图谱推理、推荐系统、物理模拟等众多领域取得了突破性进展。
现实世界中存在大量图结构数据:社交网络中的用户关注关系、分子结构中的原子连接、知识图谱中的实体关联、交通网络中的道路连接、以及引文网络中的论文引用关系。这些数据天然以图的形式存在,传统卷积神经网络(CNN)和循环神经网络(RNN)难以有效处理。GNN的核心思想是通过迭代地聚合邻居节点的信息来更新每个节点的表示,从而同时捕捉图的结构特征和节点特征。
GNN的发展经历了从谱域方法到空间域方法的重要转变。早期工作主要基于图信号处理和图傅里叶变换的谱域图卷积,代表性工作包括Bruna等人提出的谱图CNN、Defferrard等人提出的ChebNet、以及Kipf和Welling提出的GCN。随后,空间域方法因其灵活性和计算效率逐渐成为主流,包括GraphSAGE、GAT(图注意力网络)、GIN(图同构网络)等。消息传递框架(Message Passing)的出现统一了各种GNN模型的设计范式,使得研究者可以在统一的框架下理解和设计新的GNN架构。
核心思想: GNN的核心是"邻居聚合"——每个节点的表示通过聚合其邻居节点的特征信息来迭代更新,使得节点能够感知整个图的结构信息。这与CNN在图像上的局部感受野理念一脉相承。
GNN的基本工作流程
初始化: 为每个节点赋予初始特征向量(可能是输入特征或通过学习得到的嵌入)
消息传递: 每个节点收集其邻居节点的信息(消息),通过聚合函数组合这些信息
更新: 将聚合后的消息与节点自身的表示相结合,通过非线性变换更新节点表示
读出: 经过多轮消息传递后,通过读出函数(ReadOut)将节点表示汇总为整图表示(用于图级别任务)
前置知识: 理解GNN需要具备基础图论知识(节点、边、邻接矩阵)、线性代数(矩阵运算、特征分解)、以及深度学习基础(前馈网络、梯度下降、反向传播)。本章后续将先介绍图论基础,再逐步深入各类GNN模型。
二、图论基础
在深入GNN之前,我们需要建立必要的图论基础知识。图是一种用于描述对象之间关系的数学结构,是GNN操作的核心数据对象。
2.1 图的基本定义
一个图 G = (V, E) 由节点集合 V 和边集合 E 组成。设图有 N 个节点,每个节点 v_i 可以关联一个 d 维的特征向量 x_i \in \mathbb{R}^d ,所有节点的特征矩阵为 X \in \mathbb{R}^{N \times d} 。边可以是有向的或无向的,可以带有权重,也可以具有多种类型。
图的分类: 根据边是否有方向可分为有向图和无向图;根据边是否带有权重可分为加权图和无权图;根据是否包含多种节点或边类型可分为异构图和同构图;根据节点和边的时序信息可分为动态图和静态图。
2.2 邻接矩阵
邻接矩阵 A \in \mathbb{R}^{N \times N} 是描述节点之间连接关系的核心数据结构。对于无权无向图,若节点 v_i 和 v_j 之间有边相连,则 A_{ij} = 1 ,否则 A_{ij} = 0 。无向图的邻接矩阵是对称矩阵。对于加权图,A_{ij} 存储边的权重值。邻接矩阵是GNN中消息传递的基础,它决定了哪些节点之间可以互相传递信息。
import numpy as np
# 创建一个简单的无向图邻接矩阵 (4个节点,链状结构)
# 0 -- 1 -- 2 -- 3
A = np.array([
[0, 1, 0, 0],
[1, 0, 1, 0],
[0, 1, 0, 1],
[0, 0, 1, 0]
])
print ("邻接矩阵 A:\n" , A)
print ("节点 1 的邻居:" , np.where (A[1] > 0)[0])
2.3 度矩阵
度矩阵 D \in \mathbb{R}^{N \times N} 是对角矩阵,其中 D_{ii} = \sum_j A_{ij} ,即节点 v_i 的度(邻居数量)。度矩阵反映了每个节点在图中的连接稠密程度,在GCN的归一化操作中扮演关键角色。
# 计算度矩阵
degrees = np.sum (A, axis=1) # 每个节点的度
D = np.diag (degrees)
print ("度矩阵 D:\n" , D)
print ("每个节点的度:" , degrees)
2.4 拉普拉斯矩阵
拉普拉斯矩阵 L = D - A 是图论中的核心算子,在图信号处理和谱图卷积中具有重要地位。其归一化形式为对称归一化拉普拉斯矩阵 L_{sym} = D^{-\frac{1}{2}} L D^{-\frac{1}{2}} = I - D^{-\frac{1}{2}} A D^{-\frac{1}{2}} 。拉普拉斯矩阵的特征值和特征向量蕴含了图的丰富结构信息,包括图的连通性、聚类结构、以及谱散布特性。谱域GCN正是利用拉普拉斯矩阵的特征分解来定义图上的卷积操作。
# 计算拉普拉斯矩阵
L = D - A
print ("拉普拉斯矩阵 L:\n" , L)
# 计算对称归一化拉普拉斯矩阵 L_sym = I - D^(-1/2) * A * D^(-1/2)
D_inv_sqrt = np.diag (1.0 / np.sqrt (degrees + 1e-8))
L_sym = np.eye (4) - D_inv_sqrt @ A @ D_inv_sqrt
print ("对称归一化拉普拉斯 L_sym:\n" , L_sym)
2.5 图信号
图信号(Graph Signal)是指定义在图的节点上的函数 f: V \rightarrow \mathbb{R} ,将每个节点映射到一个实数值。对于一个有 N 个节点的图,图信号可以表示为一个 N 维向量 x \in \mathbb{R}^N 。图傅里叶变换将图信号投影到拉普拉斯矩阵的特征向量基上,从而将图信号从空间域转换到谱域。图卷积定理指出:两个图信号在空间域的卷积等价于它们在谱域的逐点乘积。这一理论构成了谱域GCN的数学基础。
图论与GNN的关系: 邻接矩阵决定了消息传递的拓扑结构;度矩阵用于归一化以控制不同度数节点的影响力;拉普拉斯矩阵是谱域图卷积的理论基石;图信号对应于节点特征,是GNN处理的对象。理解这些基础概念是掌握GNN的前提。
三、GCN图卷积网络
图卷积网络(Graph Convolutional Network, GCN)是最具代表性和影响力的图神经网络模型之一,由Thomas Kipf和Max Welling于2017年提出。GCN的核心思想是将传统CNN的卷积操作推广到图结构数据上,通过谱域图卷积理论和巧妙的局部近似,实现了高效且强大的节点表示学习。
3.1 谱域图卷积理论
谱域图卷积的出发点是利用图傅里叶变换将卷积操作定义在谱域中。令 L = U \Lambda U^T 为拉普拉斯矩阵的特征分解,其中 U 是特征向量矩阵,\Lambda 是特征值对角阵。图信号 x 和卷积核 g_\theta 的谱域图卷积定义为:
g_\theta \star x = U g_\theta(\Lambda) U^T x
然而,直接进行特征分解的计算复杂度为 O(N^3) ,对于大规模图不可行。ChebNet通过使用切比雪夫多项式近似解决了这一问题。
3.2 ChebNet(切比雪夫网络)
ChebNet由Defferrard等人提出,使用 K 阶切比雪夫多项式来近似谱域卷积核:
g_\theta \star x \approx \sum_{k=0}^{K} \theta_k T_k(\tilde{L}) x
其中 \tilde{L} = \frac{2}{\lambda_{max}} L - I 是对拉普拉斯矩阵的缩放,T_k 是切比雪夫多项式的递归定义:T_0(x)=1, T_1(x)=x, T_k(x)=2xT_{k-1}(x)-T_{k-2}(x) 。通过 K 阶近似,每个节点的感受野扩展到其 K 阶邻居,计算复杂度降低到 O(K|E|) 。
3.3 GCN逐层传播规则
GCN在ChebNet的基础上进一步简化。通过设定 K=1 (只考虑一阶邻居)、\lambda_{max} \approx 2 ,并引入重规范化技巧,得到简洁高效的逐层传播规则:
H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)})
其中 \tilde{A} = A + I 是添加自环后的邻接矩阵,\tilde{D}_{ii} = \sum_j \tilde{A}_{ij} 是相应的度矩阵,H^{(l)} 是第 l 层的节点表示,W^{(l)} 是可学习的权重矩阵,\sigma 是非线性激活函数(通常为ReLU)。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GCNLayer (nn.Module):
"""单层GCN实现
H^{(l+1)} = σ( D̃^{-1/2} Ã D̃^{-1/2} H^{(l)} W^{(l)} )
"""
def __init__ (self, in_dim, out_dim):
super ().__init__ ()
self.W = nn.Parameter(torch.randn (in_dim, out_dim))
nn.init.xavier_uniform_(self.W)
def forward (self, X, A_norm):
# X: [N, in_dim], A_norm: [N, N] 归一化邻接矩阵
support = X @ self.W # [N, out_dim]
out = A_norm @ support # 邻居聚合
return F.relu (out)
3.4 对称归一化的含义
对称归一化 \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} 的直观含义是:在聚合邻居信息时,同时考虑目标节点和邻居节点的度数。高度数节点("网红"节点)的特征传播到其众多邻居时会被稀释(除以 \sqrt{D_{ii}} ),而低度数节点聚合邻居信息时,高度数邻居的影响力会被降低(除以 \sqrt{D_{jj}} )。这种双向归一化避免了模型偏向高度数节点,使得GCN在各种度分布的图上都能稳定工作。
def gcn_normalize (A):
"""对邻接矩阵进行GCN风格的对称归一化"""
A_tilde = A + torch.eye (A.shape [0]) # 添加自环
D_tilde = torch.diag (A_tilde.sum (axis=1))
D_inv_sqrt = torch.diag (torch.pow (D_tilde.diag (), -0.5))
D_inv_sqrt[torch.isinf (D_inv_sqrt)] = 0
A_norm = D_inv_sqrt @ A_tilde @ D_inv_sqrt
return A_norm
3.5 多层GCN与半监督节点分类
堆叠多个GCN层可以扩大感受野,使每个节点感知到更远距离的邻居信息。一个典型的2层GCN用于半监督节点分类的架构如下:
class GCN (nn.Module):
"""2层GCN用于半监督节点分类"""
def __init__ (self, in_dim, hidden_dim, num_classes, dropout=0.5):
super ().__init__ ()
self.gcn1 = GCNLayer (in_dim, hidden_dim)
self.gcn2 = GCNLayer (hidden_dim, num_classes)
self.dropout = nn.Dropout (dropout)
def forward (self, X, A_norm):
h = self.gcn1.forward (X, A_norm)
h = self.dropout(h)
h = self.gcn2.forward (h, A_norm)
return F.log_softmax (h, dim=1)
# 在Cora引文网络数据集上的使用示例
model = GCN (in_dim=1433, hidden_dim=16, num_classes=7)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 半监督训练:只使用少量有标签节点的损失进行训练
for epoch in range (200):
model.train ()
logits = model.forward (X, A_norm)
loss = F.nll_loss (logits[train_mask], y[train_mask])
optimizer.zero_grad ()
loss.backward ()
optimizer.step ()
3.6 GCN的核心优势与局限
优势
计算高效:线性复杂度 O(|E|) ,适合大规模图
参数共享:所有节点共享相同的权重矩阵,参数量与图规模无关
归纳能力:可以泛化到训练时未见过的节点
无需特征分解:避免了昂贵的矩阵分解运算
局限
同质性假设:假设相邻节点具有相似的特征和标签
过度平滑问题:层数增加时节点表示趋于一致
静态聚合权重:所有邻居的权重相同,缺乏区分能力
转导学习:标准GCN在全图邻接矩阵上进行操作
过度平滑问题: 当GCN堆叠过多层时(如超过5-7层),所有节点的表示会收敛到同一个子空间,节点之间的区分度消失。这是因为反复的拉普拉斯平滑操作使得节点表示趋向于整个图的全局平均。解决方案包括使用残差连接、Jumping Knowledge网络、或DROPEDGE等正则化技术。
四、GAT图注意力网络
图注意力网络(Graph Attention Network, GAT)由Petar Velickovic等人于2018年提出,通过引入注意力机制解决了GCN中所有邻居权重相同的问题。在GAT中,每个节点可以学习为其不同邻居分配不同的重要性权重,使得模型能够聚焦于更相关的邻居信息,同时忽略噪声邻居的干扰。
4.1 注意力系数的计算
对于节点 v_i 和其邻居节点 v_j ,GAT首先计算注意力系数 e_{ij} ,表示节点 j 对节点 i 的重要性:
e_{ij} = \text{LeakyReLU}(a^T [W h_i \; \Vert \; W h_j])
其中 W 是共享的权重矩阵,a 是可学习的注意力向量,\Vert 表示向量拼接操作。然后通过softmax对邻居间的注意力系数进行归一化:
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}
归一化后的注意力系数 \alpha_{ij} 表示在聚合节点 i 的邻居信息时,节点 j 所占的权重比例。
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer (nn.Module):
"""单头图注意力层"""
def __init__ (self, in_dim, out_dim, negative_slope=0.2):
super ().__init__ ()
self.W = nn.Parameter(torch.randn (in_dim, out_dim))
self.a = nn.Parameter(torch.randn (2 * out_dim, 1))
self.negative_slope = negative_slope
nn.init.xavier_uniform_(self.W)
nn.init.xavier_uniform_(self.a)
def forward (self, X, adj):
# X: [N, in_dim], adj: [N, N] 邻接矩阵
h = X @ self.W # [N, out_dim]
N = h.shape [0]
h_i = h.unsqueeze (1).repeat (1, N, 1) # [N, N, out_dim]
h_j = h.unsqueeze (0).repeat (N, 1, 1) # [N, N, out_dim]
e = torch.cat ([h_i, h_j], dim=-1) # [N, N, 2*out_dim]
e = e @ self.a # [N, N, 1]
e = e.squeeze (-1) # [N, N]
e = F.leaky_relu (e, self.negative_slope)
# 掩码:只对邻居节点计算注意力
mask = (adj == 0)
e = e.masked_fill (mask, float('-inf' ))
alpha = F.softmax (e, dim=-1) # [N, N] 注意力系数
# 加权聚合邻居特征
out = alpha @ h # [N, out_dim]
return out
4.2 多头注意力机制
为提升模型的表达能力和训练稳定性,GAT使用多头注意力机制(Multi-Head Attention)。具体来说,同时使用 K 个独立的注意力头计算节点表示,然后将它们拼接(中间层)或求平均(输出层):
h_i' = \Vert_{k=1}^K \; \sigma(\sum_{j \in \mathcal{N}_i} \alpha_{ij}^{(k)} W^{(k)} h_j) (中间层)
h_i' = \sigma(\frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}_i} \alpha_{ij}^{(k)} W^{(k)} h_j) (输出层)
class MultiHeadGATLayer (nn.Module):
"""多头注意力GAT层"""
def __init__ (self, in_dim, out_dim, num_heads, concat=True):
super ().__init__ ()
self.heads = nn.ModuleList([
GATLayer (in_dim, out_dim) for _ in range (num_heads)
])
self.concat = concat
def forward (self, X, adj):
head_outputs = [head.forward (X, adj) for head in self.heads]
if self.concat:
return torch.cat (head_outputs, dim=-1) # [N, K*out_dim]
else :
return torch.stack (head_outputs).mean (dim=0) # [N, out_dim]
4.3 稀疏注意力与动态聚合
GAT的一个关键特性是注意力权重的动态计算。在GCN中,聚合权重完全由图结构决定(由归一化邻接矩阵固定),而在GAT中,注意力系数 \alpha_{ij} 依赖于节点 i 和节点 j 的当前特征表示。这意味着同一对节点在不同层或不同训练阶段可能有不同的注意力权重,极大地增强了模型的表达能力。此外,由于注意力权重仅在邻居节点之间计算(通过掩码操作使非邻居对的 e_{ij} = -\infty ),GAT天然具有稀疏性,计算复杂度可以控制在 O(|E|) 。
GAT vs GCN 的核心区别: GCN使用基于度的固定归一化权重聚合邻居,相当于"平均池化";GAT通过学习得到动态注意力权重,相当于"加权池化"或"软性选择"。这使得GAT在异构图和包含噪声连接的数据上通常优于GCN。
4.4 完整的2层GAT模型
class GAT (nn.Module):
"""2层GAT用于节点分类"""
def __init__ (self, in_dim, hidden_dim, num_classes,
num_heads=8, dropout=0.6):
super ().__init__ ()
self.dropout = nn.Dropout (dropout)
# 第一层: 8头注意力, 输出 hidden_dim * 8
self.layer1 = MultiHeadGATLayer (
in_dim, hidden_dim, num_heads, concat=True)
# 第二层: 单头注意力, 输出 num_classes
self.layer2 = MultiHeadGATLayer (
hidden_dim * num_heads, num_classes, 1, concat=False)
def forward (self, X, adj):
h = self.dropout(X)
h = self.layer1.forward (h, adj)
h = F.elu (h)
h = self.dropout(h)
h = self.layer2.forward (h, adj)
return F.log_softmax (h, dim=1)
五、消息传递框架(Message Passing)
消息传递神经网络(Message Passing Neural Network, MPNN)由Gilmer等人于2017年提出,为各类GNN模型提供了一个统一的通用框架。MPNN框架将GNN的前向传播过程分为两个核心阶段:消息传递阶段(Message Passing Phase)和读出阶段(Readout Phase)。
5.1 消息传递的数学定义
在MPNN框架中,每个节点 v_i 在第 t 轮的消息传递过程可以表示为:
m_i^{(t)} = \text{AGGREGATE}^{(t)}(\{M^{(t)}(h_i^{(t-1)}, h_j^{(t-1)}, e_{ij}) \; | \; j \in \mathcal{N}_i\})
h_i^{(t)} = \text{UPDATE}^{(t)}(h_i^{(t-1)}, m_i^{(t)})
其中 M^{(t)} 是消息函数,用于生成从邻居 j 发送到节点 i 的消息;\text{AGGREGATE}^{(t)} 是聚合函数,用于将所有邻居的消息合并为一个固定维度的向量;\text{UPDATE}^{(t)} 是更新函数,用于将节点自身的表示与聚合消息相结合。整个消息传递过程重复 T 轮,使得每个节点最终能够感知到其 T 阶邻居的信息。
5.2 聚合函数详解
聚合函数是消息传递框架的关键组件,决定了如何组合多个邻居的信息。常见的聚合函数包括:
聚合函数 计算公式 特点 适用场景
求和(Sum)
\sum_{j \in \mathcal{N}_i} m_{ij}
区分度高,可以感知邻居的数量
图分类任务(对结构敏感)
均值(Mean)
\frac{1}{|\mathcal{N}_i|} \sum_{j \in \mathcal{N}_i} m_{ij}
平滑稳定,对度数不敏感
节点分类任务(特征平滑)
最大值(Max)
\max_{j \in \mathcal{N}_i} m_{ij}
捕捉最显著特征,对噪声鲁棒
分子性质预测(关注重要原子)
注意力(Attention)
\sum_{j \in \mathcal{N}_i} \alpha_{ij} m_{ij}
可学习的重要性加权
异构图、含噪声图
import torch
import torch.nn as nn
import torch.nn.functional as F
class MessagePassingLayer (nn.Module):
"""通用消息传递层实现"""
def __init__ (self, in_dim, out_dim, aggr='sum' ):
super ().__init__ ()
self.msg_fn = nn.Linear (in_dim * 2, out_dim) # 消息函数 M
self.update_fn = nn.Linear (in_dim + out_dim, out_dim) # 更新函数
self.aggr = aggr
def message (self, h_i, h_j):
"""从邻居j到节点i的消息"""
return self.msg_fn(torch.cat ([h_i, h_j], dim=-1))
def aggregate (self, messages, aggr_type):
"""聚合邻居消息"""
if aggr_type == 'sum' :
return messages.sum (dim=1)
elif aggr_type == 'mean' :
return messages.mean (dim=1)
elif aggr_type == 'max' :
return messages.max (dim=1).values
def forward (self, h, adj):
# h: [N, in_dim], adj: [N, N]
N = h.shape [0]
messages = []
for i in range (N):
neighbors = torch.where (adj[i] > 0)[0]
if len (neighbors) == 0:
messages.append (torch.zeros (self.msg_fn.out_features))
continue
h_i = h[i].unsqueeze (0).repeat (len (neighbors), 1)
h_j = h[neighbors]
msg = self.message (h_i, h_j)
messages.append (self.aggregate (msg.unsqueeze (0), self.aggr).squeeze (0))
m = torch.stack (messages) # [N, out_dim]
return F.relu (self.update_fn(torch.cat ([h, m], dim=-1)))
5.3 读出函数(ReadOut)
读出函数的作用是将经过多轮消息传递后的所有节点表示聚合为整图表示,用于图级别的任务(如图分类、图回归)。常见的读出策略包括:对所有节点表示取求和、均值、最大值,或者使用更复杂的Set2Set、注意力池化等方法。
class ReadOut (nn.Module):
"""图级别读出函数"""
def __init__ (self, in_dim, out_dim, pool_type='sum' ):
super ().__init__ ()
self.pool_type = pool_type
self.proj = nn.Linear (in_dim, out_dim)
def forward (self, h, batch_mask=None ):
# h: [N, in_dim]
if self.pool_type == 'sum' :
h_graph = h.sum (dim=0)
elif self.pool_type == 'mean' :
h_graph = h.mean (dim=0)
elif self.pool_type == 'max' :
h_graph = h.max (dim=0).values
return self.proj(h_graph)
5.4 GIN:图同构网络
图同构网络(Graph Isomorphism Network, GIN)由Xu等人提出,从理论和实践上证明了求和(Sum)聚合函数比均值(Mean)和最大值(Max)具有更强的表达能力。具体来说,当使用求和聚合结合多层感知机(MLP)作为更新函数时,GIN能够达到与Weisfeiler-Lehman图同构测试相当的区分能力,这在理论上是最优的。GIN的更新规则为:
h_i^{(t)} = \text{MLP}^{(t)}\left((1 + \epsilon^{(t)}) \cdot h_i^{(t-1)} + \sum_{j \in \mathcal{N}_i} h_j^{(t-1)}\right)
聚合函数的选择: 理论上,Sum聚合可以区分不同的多重集(multiset),表达能力强于Mean和Max。实践中,Sum在分子性质预测等需要精确结构信息的任务中表现最佳,而Mean在节点分类等特征平滑任务中更稳定。建议根据具体任务通过交叉验证选择合适的聚合函数。
六、图分类与池化
图分类(Graph Classification)是GNN的重要应用场景,目标是为整个图分配一个类别标签(例如分子是否为某一类化合物、文档属于哪个主题)。与节点分类不同,图分类需要从一组节点表示中提取出全局的图级别表示。这通常通过图池化(Graph Pooling)操作来实现。
6.1 全局池化方法
最简单的图池化方法是对所有节点表示进行全局池化,包括全局求和池化(Global Add Pooling)、全局均值池化(Global Mean Pooling)和全局最大池化(Global Max Pooling)。这些方法计算简单、参数量为零,但在图规模差异较大时可能丢失重要信息。
class GlobalPooling (nn.Module):
"""图级别全局池化"""
def __init__ (self, pool_type='add' ):
super ().__init__ ()
self.pool_type = pool_type
def forward (self, x, batch):
# x: [total_nodes, feat_dim] batch: [total_nodes] 每个节点所属图的索引
if self.pool_type == 'add' :
return torch.scatter_add (x, batch, dim=0)
elif self.pool_type == 'mean' :
return torch.scatter_add (x, batch, dim=0) / torch.bincount (batch).unsqueeze (1)
elif self.pool_type == 'max' :
return torch.scatter_reduce (x, 0, batch, reduce='max' )
6.2 层次化池化方法
层次化池化通过在GNN层之间插入池化操作来逐步降低图的规模并提取层次化的特征表示。常用的层次化池化方法包括:
DiffPool(可微池化): 学习一个软分配矩阵 S \in \mathbb{R}^{N \times K} ,将 N 个节点软分配到 K 个簇中,生成一个更小的"粗化图"。粗化图的节点特征和邻接矩阵通过 X' = S^T X 和 A' = S^T A S 计算。DiffPool是可微的,可以端到端训练,但空间复杂度较高。
TopKPooling: 根据节点的重要性分数选择保留最重要的 K 个节点,丢弃其余节点。节点重要性通过一个可学习的投影向量 p 计算:score_i = x_i^T p / \|p\| 。TopKPooling简单高效,但丢弃节点的信息完全丢失。
SAGPool(自注意力图池化): 使用GCN计算节点的重要性分数,同时考虑节点的特征和拓扑结构:score = \text{sigmoid}(\text{GCN}(X, A)) 。SAGPool比TopKPooling更能感知图结构,通常获得更好的效果。
class SAGPool (nn.Module):
"""自注意力图池化层 (Self-Attention Graph Pooling)
使用GCN计算节点重要性分数,选择top-k节点保留
"""
def __init__ (self, in_dim, ratio=0.5):
super ().__init__ ()
self.gcn_score = GCNLayer (in_dim, 1) # 输出单个分数
self.ratio = ratio
def forward (self, X, A_norm):
score = self.gcn_score.forward (X, A_norm).squeeze (-1) # [N]
score = torch.sigmoid (score)
k = max(1, int(X.shape [0] * self.ratio))
topk_idx = torch.topk (score, k).indices
# 筛选节点和边
X_pool = X[topk_idx] * score[topk_idx].unsqueeze (-1)
A_pool = A_norm[topk_idx][:, topk_idx]
return X_pool, A_pool, topk_idx
# 使用SAGPool构建层次化图分类模型
class HierarchicalGCN (nn.Module):
def __init__ (self, in_dim, hidden_dim, num_classes):
super ().__init__ ()
self.conv1 = GCNLayer (in_dim, hidden_dim)
self.pool1 = SAGPool (hidden_dim, ratio=0.5)
self.conv2 = GCNLayer (hidden_dim, hidden_dim)
self.pool2 = SAGPool (hidden_dim, ratio=0.5)
self.readout = ReadOut (hidden_dim, num_classes, pool_type='mean' )
def forward (self, X, A_norm):
h = self.conv1.forward (X, A_norm)
h, A_norm, _ = self.pool1.forward (h, A_norm)
h = self.conv2.forward (h, A_norm)
h, A_norm, _ = self.pool2.forward (h, A_norm)
h_graph = self.readout.forward (h)
return F.log_softmax (h_graph, dim=-1)
七、GNN的典型应用
GNN因其强大的图结构数据处理能力,在众多领域取得了广泛的应用。以下介绍几个最具代表性的应用方向。
7.1 节点分类
节点分类是GNN最经典的任务,目标是为图中的每个节点预测一个类别标签。典型应用包括:在引文网络中为论文分配研究领域标签、在社交网络中识别用户所属的社区或兴趣群组、在知识图谱中推断实体的类型。GCN和GAT在半监督节点分类场景中表现尤为突出,仅需少量有标签样本即可利用图结构信息进行全局传播。
7.2 链接预测
链接预测旨在预测图中可能存在的缺失边或未来可能出现的边。其核心思路是:利用GNN学习节点表示后,通过评分函数(如内积、MLP)计算两个节点之间存在边的概率。链接预测在推荐系统(预测用户-物品交互)、社交网络(推荐好友)、知识图谱补全和药物相互作用预测中具有重要应用。
class LinkPredictor (nn.Module):
"""基于GNN的链接预测模型"""
def __init__ (self, node_feat_dim, hidden_dim):
super ().__init__ ()
self.gnn = GCN (node_feat_dim, hidden_dim, hidden_dim)
self.score_fn = nn.Sequential (
nn.Linear (hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear (hidden_dim, 1)
)
def forward (self, X, A_norm, edge_pairs):
# 先获取所有节点表示
h = self.gnn.forward (X, A_norm)
# 对每个候选边对计算连接概率
h_src = h[edge_pairs[:, 0]]
h_dst = h[edge_pairs[:, 1]]
score = self.score_fn(torch.cat ([h_src, h_dst], dim=-1))
return torch.sigmoid (score).squeeze ()
7.3 图分类与分子性质预测
在计算化学和药物发现领域,分子天然可以表示为图(原子为节点,化学键为边)。GNN通过学习分子图的结构和原子特征来预测分子的各种性质,如毒性、溶解度、药物活性等。Google的DeepMind和MIT等机构的研究表明,GNN在分子性质预测任务上显著优于传统的分子指纹方法和前馈神经网络。代表性工作包括MoleculeNet基准测试、以及用于COVID-19药物筛选的GNN应用。
7.4 推荐系统
GNN在推荐系统中的应用近年来发展迅速。用户和物品可以构建为一个二部图(Bipartite Graph),用户-物品交互(点击、购买、评分)构成图的边,用户属性特征和物品属性特征构成初始节点特征。GNN通过在图上的消息传递来捕获用户-物品之间高阶的协同信号。PinSage(Pinterest的GNN推荐系统)和LightGCN是这一方向的代表性工作。
7.5 物理模拟与科学计算
在物理模拟领域,GNN被用于建模粒子系统(如流体力学、分子动力学)和刚体动力学。GraphNetwork-based Simulator(GNS)等方法将物理系统中的实体作为节点,相互作用作为边,通过学习物理规律来预测系统的未来状态。GNN在天气预报(如GraphCast)、交通流预测、以及材料科学中的晶体性质预测等任务中也取得了优异成果。
GNN的核心价值: GNN之所以在如此多的领域取得成功,根本原因在于它提供了一种能够灵活处理非欧几里得数据、同时捕捉节点特征和拓扑结构信息的深度学习范式。在许多任务中,图结构本身携带了丰富的信息,而GNN是迄今为止最有效的利用这些结构信息的方法。
7.6 其他新兴应用
组合优化: 使用GNN学习图上的启发式算法用于求解TSP、最大团等NP难问题
代码分析: 将代码抽象语法树作为图输入GNN,用于bug检测和代码补全
点云处理: 在3D点云的K近邻图上使用GNN进行语义分割和分类
异常检测: 在金融交易图和网络流量图上检测欺诈和入侵行为
知识图谱推理: 使用GNN进行多跳推理和关系预测
八、PyTorch Geometric基础
PyTorch Geometric(PyG)是当前最流行的图神经网络框架之一,基于PyTorch构建,提供了丰富的GNN层实现、数据集加载工具和高效的消息传递运算支持。它极大地简化了GNN的开发流程,使得研究者可以快速原型化和实验不同的GNN模型。
8.1 环境安装与核心概念
# 安装 PyTorch Geometric (CPU版本示例)
# pip install torch torchvision torchaudio
# pip install torch_geometric
# pip install torch_scatter torch_sparse torch_cluster torch_spline_conv
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_add_pool
from torch_geometric.datasets import Planetoid
from torch_geometric.data import Data
PyG的核心数据结构是 \texttt{Data} 对象,包含以下关键属性:
x: 节点特征矩阵 [N, F]
edge_index: 边索引 [2, E] ,COO格式(节省内存)
edge_attr: 边特征矩阵 [E, D] (可选)
y: 标签(节点级别或图级别)
train_mask / val_mask / test_mask: 训练/验证/测试集掩码
batch: 批量处理时用于区分不同图的索引
# 创建自定义图数据
# 3个节点, 边: 0-1, 1-2
edge_index = torch.tensor ([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor ([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float)
y = torch.tensor ([0, 1, 0], dtype=torch.long)
data = Data (x=x, edge_index=edge_index, y=y)
print (data)
# Data(x=[3, 2], edge_index=[2, 4], y=[3])
8.2 使用PyG构建GCN
from torch_geometric.nn import GCNConv
class PyG_GCN (nn.Module):
"""使用PyG实现的2层GCN"""
def __init__ (self, in_dim, hidden_dim, num_classes, dropout=0.5):
super ().__init__ ()
self.conv1 = GCNConv (in_dim, hidden_dim)
self.conv2 = GCNConv (hidden_dim, num_classes)
self.dropout = nn.Dropout (dropout)
def forward (self, x, edge_index):
h = self.conv1(x, edge_index)
h = F.relu (h)
h = self.dropout(h)
h = self.conv2(h, edge_index)
return F.log_softmax (h, dim=1)
# 加载Cora数据集
dataset = Planetoid (root='/tmp/Cora' , name='Cora' )
data = dataset[0]
print (f"数据集: {dataset}" )
print (f"节点数: {data.num_nodes}, 边数: {data.num_edges}" )
print (f"特征维度: {data.num_features}, 类别数: {dataset.num_classes}" )
# 训练
device = torch.device('cuda' if torch.cuda.is_available () else 'cpu' )
model = PyG_GCN (dataset.num_features, 16, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train ()
for epoch in range (200):
optimizer.zero_grad ()
out = model.forward (data.x, data.edge_index)
loss = F.nll_loss (out[data.train_mask], data.y[data.train_mask])
loss.backward ()
optimizer.step ()
if epoch % 20 == 0:
print (f"Epoch {epoch}, Loss: {loss.item():.4f}" )
8.3 使用PyG构建GAT
from torch_geometric.nn import GATConv
class PyG_GAT (nn.Module):
"""使用PyG实现2层GAT"""
def __init__ (self, in_dim, hidden_dim, num_classes,
heads=8, dropout=0.6):
super ().__init__ ()
self.conv1 = GATConv (in_dim, hidden_dim, heads=heads, dropout=dropout)
self.conv2 = GATConv (hidden_dim * heads, num_classes,
heads=1, concat=False, dropout=dropout)
self.dropout = nn.Dropout (dropout)
def forward (self, x, edge_index):
h = self.dropout(x)
h = self.conv1(h, edge_index)
h = F.elu (h)
h = self.dropout(h)
h = self.conv2(h, edge_index)
return F.log_softmax (h, dim=1)
8.4 使用PyG进行图分类
from torch_geometric.nn import global_mean_pool
from torch_geometric.datasets import TUDataset
class PyG_GraphClassifier (nn.Module):
"""图分类模型,包含GCN卷积 + 全局池化"""
def __init__ (self, in_dim, hidden_dim, num_classes):
super ().__init__ ()
self.conv1 = GCNConv (in_dim, hidden_dim)
self.conv2 = GCNConv (hidden_dim, hidden_dim)
self.conv3 = GCNConv (hidden_dim, hidden_dim)
self.classifier = nn.Linear (hidden_dim, num_classes)
def forward (self, x, edge_index, batch):
h = self.conv1(x, edge_index).relu ()
h = self.conv2(h, edge_index).relu ()
h = self.conv3(h, edge_index).relu ()
# 全局均值池化: 将每个图中的所有节点表示取平均
h_graph = global_mean_pool(h, batch)
return self.classifier(h_graph)
# 加载MUTAG数据集(分子图分类)
dataset = TUDataset (root='/tmp/TUDataset' , name='MUTAG' )
print (f"数据集大小: {len(dataset)}" )
print (f"特征维度: {dataset.num_features}" )
8.5 PyG的高级功能
PyG还提供了许多高级功能,包括但不限于:
Mini-batch处理: 通过 \texttt{DataLoader} 自动将多个图组织成批次,使用 \texttt{batch} 向量跟踪每个节点所属的图
异构图支持: 通过 \texttt{HeteroData} 支持包含多种节点类型和边类型的图
图Transformer: 内置 \texttt{TransformerConv} 等先进架构
时序图: 支持动态图的 \texttt{ToUndirected} 和时序邻接采样
邻接采样: 通过 \texttt{NeighborSampler} 支持大规模图的邻居采样训练
丰富的图基准: 提供OGB(Open Graph Benchmark)等标准数据集的加载接口
# 使用NeighborSampler进行大规模图训练
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader (
data,
num_neighbors=[10, 10], # 每层采样的邻居数量
batch_size=256,
input_nodes=data.train_mask,
)
for batch in train_loader:
out = model.forward (batch.x, batch.edge_index)
loss = F.nll_loss (out[:batch.batch_size], batch.y[:batch.batch_size])
九、总结与展望
核心要点总结:
图神经网络是处理图结构数据的强大深度学习范式,核心思想是通过消息传递机制聚合邻居信息
图论基础(邻接矩阵、度矩阵、拉普拉斯矩阵)是理解GNN的前提,为消息传递提供了数学框架
GCN通过谱域图卷积的1阶近似实现了高效的邻居聚合,每个邻居权重固定为归一化系数
GAT引入注意力机制使得模型能够为不同邻居学习不同权重,增强了表达能力和灵活性
消息传递框架(MPNN)统一了各种GNN模型的设计范式,聚合函数的选择显著影响模型性能
图池化操作(全局池化/层次化池化)是图分类任务的关键组件
PyTorch Geometric提供了高效易用的GNN开发工具链,极大降低了GNN应用的开发门槛
GNN在节点分类、链接预测、图分类、分子发现、推荐系统、物理模拟等领域具有广泛的应用前景
GNN的演进趋势
GNN领域近年来发展迅猛,以下几个方向值得重点关注:
大规模GNN: 针对亿级节点和边的大规模图,研究高效的邻域采样策略和分布式训练方法
深层GNN: 通过残差连接、归一化技术和降噪策略解决过度平滑问题,构建更深层的GNN网络
Graph Transformer: 将Transformer架构的自注意力机制引入图数据,实现全局级别的节点交互
时间动态图: 处理节点和边随时间动态变化的场景,如社交网络演化、交通流量变化
异构图与多模态: 融合文本、图像、知识图谱等多源异构信息的图神经网络
可解释GNN: 发展GNN的可解释性方法(如GNNExplainer),理解模型做出预测的依据
预训练GNN: 借鉴NLP中的预训练范式,在大规模无标注图上进行自监督预训练后迁移到下游任务
推荐学习路径: 1)掌握图论基础和PyTorch入门知识;2)阅读GCN和GAT的原始论文并复现核心代码;3)熟悉PyTorch Geometric框架并通过开源项目实战练习;4)阅读斯坦福CS224W课程《Machine Learning with Graphs》的公开材料;5)关注NeurIPS、ICML、ICLR等顶会的最新GNN论文。
进一步阅读资源:
Kipf & Welling, "Semi-Supervised Classification with Graph Convolutional Networks", ICLR 2017
Velickovic et al., "Graph Attention Networks", ICLR 2018
Gilmer et al., "Neural Message Passing for Quantum Chemistry", ICML 2017
Xu et al., "How Powerful are Graph Neural Networks?", ICLR 2019
Wu et al., "A Comprehensive Survey on Graph Neural Networks", IEEE TNNLS 2021
PyTorch Geometric官方文档: https://pytorch-geometric.readthedocs.io/