使用PyG (PyTorch Geometric) 实现同质图transductive链路预测任务

简介: 使用PyG (PyTorch Geometric) 实现同质图transductive链路预测任务

1. 数据获取


本文直接调用PyG官方的Cora数据集,如果环境可以直接登外网的话,其实可以直接运行后续模型。如果不能的话,可以参考我之前撰写的博文手动下载对应数据:PyG的Planetoid无法直接下载Cora等数据集的3个解决方式


2. 数据预处理


这里的处理方式是直接在载入数据时,就直接调用PyG的类:


  1. 对节点特征进行行归一化(T.NormalizeFeatures(),文档https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.NormalizeFeatures,源码torch_geometric.transforms.normalize_features — pytorch_geometric documentation):使每一行总和为1、且更稀疏,具体做法是:元素减去最小值,然后除以总值(设置最小值为1)
  2. 将DataSet对象放到GPU上(T.ToDevice(device),文档https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.ToDevice
  3. 对DataSet对象用链路预测的方法进行数据集划分:

T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False)

文档:https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.RandomLinkSplit

训练集中不包含验证集和测试集的边,验证集中不包含测试集的边。注意本代码是transductive的,所以划分得到的3个数据集

返回的DataSet对象中的元素是tuple,每个tuple包含3个元素(train_data/val_data/test_data),每个元素都是Data对象。


import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('pyg_data/Planetoid', name='Cora', transform=transform)
print(type(dataset))
train_data, val_data, test_data = dataset[0]
print(type(train_data))


输出(由于SciPy包版本导致的警告不赘):


<class 'torch_geometric.datasets.planetoid.Planetoid'>
<class 'torch_geometric.data.data.Data'>


3. 建立链路预测模型


  1. encode()函数:GNN节点表征,使用2层GCN,其中用了ReLU激活函数。没有其他trick。
  2. decode()函数在训练时使用,仅计算指定edge_label_index上的边,在代码上用逐元素求和表示点积。
  3. decode_all()函数在测试时使用,计算整张图所有节点对存在边的概率,也是用矩阵乘法来实现点积,结果的概率大于0直接认为节点对之间存在边,返回的是这个被认为存在边的edge list。


import torch
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)
    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


4. 实例化模型,设置优化器、损失函数


链路预测一般被建模为二分类任务(即边是否存在,因此常用

torch.nn.BCEWithLogitsLoss())
model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()


5. 构建训练函数


每个epoch调用一次训练函数。

在训练集上,首先用GNN实现节点表征,然后调用negative_sampling(文档:https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.negative_sampling)抽样负边(与正边数量一样),计算对应的损失函数。


from torch_geometric.utils import negative_sampling
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)
    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


6. 构建每个epoch运行时的测试函数


我个人比较喜欢用with torch.no_grad()

每个epoch调用一次。

计算图数据上正边的概率,直接用其通过Sigmoid激活函数后的结果作为边存在的概率,用以计算ROC AUC值。


@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


7. 训练和测试


训练100个epoch,最后得到测试集上所有模型认为存在的边。


best_val_auc = final_test_auc = 0
for epoch in range(1, 101):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')
print(f'Final Test: {final_test_auc:.4f}')
z = model.encode(test_data.x, test_data.edge_index)
final_edge_index = model.decode_all(z)


输出:


Epoch: 001, Loss: 0.6930, Val: 0.6729, Test: 0.7026
Epoch: 002, Loss: 0.6820, Val: 0.6589, Test: 0.6913
Epoch: 003, Loss: 0.7065, Val: 0.6619, Test: 0.6967
Epoch: 004, Loss: 0.6766, Val: 0.6686, Test: 0.7069
Epoch: 005, Loss: 0.6842, Val: 0.6716, Test: 0.7128
Epoch: 006, Loss: 0.6876, Val: 0.6637, Test: 0.7132
Epoch: 007, Loss: 0.6881, Val: 0.6471, Test: 0.7009
Epoch: 008, Loss: 0.6867, Val: 0.6317, Test: 0.6859
Epoch: 009, Loss: 0.6829, Val: 0.6240, Test: 0.6767
Epoch: 010, Loss: 0.6765, Val: 0.6223, Test: 0.6720
Epoch: 011, Loss: 0.6715, Val: 0.6208, Test: 0.6684
Epoch: 012, Loss: 0.6759, Val: 0.6204, Test: 0.6640
Epoch: 013, Loss: 0.6687, Val: 0.6272, Test: 0.6656
Epoch: 014, Loss: 0.6621, Val: 0.6488, Test: 0.6778
Epoch: 015, Loss: 0.6593, Val: 0.6748, Test: 0.6907
Epoch: 016, Loss: 0.6534, Val: 0.6824, Test: 0.6923
Epoch: 017, Loss: 0.6477, Val: 0.6796, Test: 0.6867
Epoch: 018, Loss: 0.6389, Val: 0.6847, Test: 0.6888
Epoch: 019, Loss: 0.6332, Val: 0.7155, Test: 0.7115
Epoch: 020, Loss: 0.6217, Val: 0.7487, Test: 0.7430
Epoch: 021, Loss: 0.6060, Val: 0.7645, Test: 0.7582
Epoch: 022, Loss: 0.5993, Val: 0.7650, Test: 0.7574
Epoch: 023, Loss: 0.5837, Val: 0.7632, Test: 0.7550
Epoch: 024, Loss: 0.5719, Val: 0.7612, Test: 0.7530
Epoch: 025, Loss: 0.5654, Val: 0.7565, Test: 0.7518
Epoch: 026, Loss: 0.5697, Val: 0.7574, Test: 0.7534
Epoch: 027, Loss: 0.5676, Val: 0.7610, Test: 0.7576
Epoch: 028, Loss: 0.5551, Val: 0.7629, Test: 0.7634
Epoch: 029, Loss: 0.5446, Val: 0.7682, Test: 0.7723
Epoch: 030, Loss: 0.5422, Val: 0.7774, Test: 0.7848
Epoch: 031, Loss: 0.5259, Val: 0.7896, Test: 0.7988
Epoch: 032, Loss: 0.5277, Val: 0.8005, Test: 0.8127
Epoch: 033, Loss: 0.5218, Val: 0.8135, Test: 0.8245
Epoch: 034, Loss: 0.5156, Val: 0.8234, Test: 0.8342
Epoch: 035, Loss: 0.5057, Val: 0.8285, Test: 0.8414
Epoch: 036, Loss: 0.4981, Val: 0.8314, Test: 0.8462
Epoch: 037, Loss: 0.4984, Val: 0.8302, Test: 0.8459
Epoch: 038, Loss: 0.4960, Val: 0.8332, Test: 0.8489
Epoch: 039, Loss: 0.4873, Val: 0.8381, Test: 0.8555
Epoch: 040, Loss: 0.4883, Val: 0.8418, Test: 0.8609
Epoch: 041, Loss: 0.4993, Val: 0.8427, Test: 0.8615
Epoch: 042, Loss: 0.4852, Val: 0.8452, Test: 0.8616
Epoch: 043, Loss: 0.4718, Val: 0.8474, Test: 0.8640
Epoch: 044, Loss: 0.4768, Val: 0.8492, Test: 0.8679
Epoch: 045, Loss: 0.4708, Val: 0.8472, Test: 0.8688
Epoch: 046, Loss: 0.4726, Val: 0.8457, Test: 0.8680
Epoch: 047, Loss: 0.4729, Val: 0.8500, Test: 0.8698
Epoch: 048, Loss: 0.4726, Val: 0.8517, Test: 0.8705
Epoch: 049, Loss: 0.4730, Val: 0.8527, Test: 0.8722
Epoch: 050, Loss: 0.4715, Val: 0.8521, Test: 0.8734
Epoch: 051, Loss: 0.4667, Val: 0.8547, Test: 0.8756
Epoch: 052, Loss: 0.4609, Val: 0.8577, Test: 0.8784
Epoch: 053, Loss: 0.4632, Val: 0.8607, Test: 0.8829
Epoch: 054, Loss: 0.4612, Val: 0.8626, Test: 0.8862
Epoch: 055, Loss: 0.4591, Val: 0.8646, Test: 0.8878
Epoch: 056, Loss: 0.4568, Val: 0.8644, Test: 0.8874
Epoch: 057, Loss: 0.4569, Val: 0.8656, Test: 0.8874
Epoch: 058, Loss: 0.4568, Val: 0.8688, Test: 0.8897
Epoch: 059, Loss: 0.4516, Val: 0.8721, Test: 0.8929
Epoch: 060, Loss: 0.4567, Val: 0.8729, Test: 0.8942
Epoch: 061, Loss: 0.4625, Val: 0.8742, Test: 0.8938
Epoch: 062, Loss: 0.4547, Val: 0.8729, Test: 0.8919
Epoch: 063, Loss: 0.4479, Val: 0.8723, Test: 0.8927
Epoch: 064, Loss: 0.4517, Val: 0.8728, Test: 0.8962
Epoch: 065, Loss: 0.4517, Val: 0.8719, Test: 0.8972
Epoch: 066, Loss: 0.4538, Val: 0.8726, Test: 0.8962
Epoch: 067, Loss: 0.4532, Val: 0.8718, Test: 0.8944
Epoch: 068, Loss: 0.4540, Val: 0.8725, Test: 0.8937
Epoch: 069, Loss: 0.4542, Val: 0.8734, Test: 0.8953
Epoch: 070, Loss: 0.4487, Val: 0.8726, Test: 0.8967
Epoch: 071, Loss: 0.4497, Val: 0.8727, Test: 0.8973
Epoch: 072, Loss: 0.4539, Val: 0.8694, Test: 0.8949
Epoch: 073, Loss: 0.4478, Val: 0.8703, Test: 0.8937
Epoch: 074, Loss: 0.4449, Val: 0.8737, Test: 0.8945
Epoch: 075, Loss: 0.4486, Val: 0.8770, Test: 0.8968
Epoch: 076, Loss: 0.4491, Val: 0.8724, Test: 0.8970
Epoch: 077, Loss: 0.4431, Val: 0.8678, Test: 0.8957
Epoch: 078, Loss: 0.4447, Val: 0.8688, Test: 0.8952
Epoch: 079, Loss: 0.4540, Val: 0.8704, Test: 0.8943
Epoch: 080, Loss: 0.4548, Val: 0.8741, Test: 0.8955
Epoch: 081, Loss: 0.4468, Val: 0.8746, Test: 0.8985
Epoch: 082, Loss: 0.4495, Val: 0.8727, Test: 0.8994
Epoch: 083, Loss: 0.4473, Val: 0.8708, Test: 0.8990
Epoch: 084, Loss: 0.4464, Val: 0.8715, Test: 0.8976
Epoch: 085, Loss: 0.4376, Val: 0.8755, Test: 0.8977
Epoch: 086, Loss: 0.4455, Val: 0.8762, Test: 0.8993
Epoch: 087, Loss: 0.4442, Val: 0.8727, Test: 0.9004
Epoch: 088, Loss: 0.4411, Val: 0.8726, Test: 0.9009
Epoch: 089, Loss: 0.4445, Val: 0.8760, Test: 0.9010
Epoch: 090, Loss: 0.4474, Val: 0.8780, Test: 0.9002
Epoch: 091, Loss: 0.4468, Val: 0.8754, Test: 0.9009
Epoch: 092, Loss: 0.4470, Val: 0.8712, Test: 0.9015
Epoch: 093, Loss: 0.4467, Val: 0.8680, Test: 0.9006
Epoch: 094, Loss: 0.4454, Val: 0.8720, Test: 0.9019
Epoch: 095, Loss: 0.4355, Val: 0.8761, Test: 0.9028
Epoch: 096, Loss: 0.4486, Val: 0.8749, Test: 0.9013
Epoch: 097, Loss: 0.4418, Val: 0.8695, Test: 0.8999
Epoch: 098, Loss: 0.4396, Val: 0.8651, Test: 0.9002
Epoch: 099, Loss: 0.4365, Val: 0.8684, Test: 0.9034
Epoch: 100, Loss: 0.4428, Val: 0.8720, Test: 0.9050
Final Test: 0.9002
torch.Size([2, 3262820])


8. 整体代码


import torch
from sklearn.metrics import roc_auc_score
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('/data/pyg_data/Planetoid', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)
    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)
    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss
@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
best_val_auc = final_test_auc = 0
for epoch in range(1, 101):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')
print(f'Final Test: {final_test_auc:.4f}')
z = model.encode(test_data.x, test_data.edge_index)
final_edge_index = model.decode_all(z)
print(final_edge_index.size())



相关文章
|
8月前
|
自然语言处理 PyTorch 算法框架/工具
自然语言生成任务中的5种采样方法介绍和Pytorch代码实现
在自然语言生成任务(NLG)中,采样方法是指从生成模型中获取文本输出的一种技术。本文将介绍常用的5中方法并用Pytorch进行实现。
295 0
|
6月前
|
资源调度 PyTorch 调度
多任务高斯过程数学原理和Pytorch实现示例
本文探讨了如何使用高斯过程扩展到多任务场景,强调了多任务高斯过程(MTGP)在处理相关输出时的优势。通过独立多任务GP、内在模型(ICM)和线性模型(LMC)的核心区域化方法,MTGP能够捕捉任务间的依赖关系,提高泛化能力。ICM和LMC通过引入核心区域化矩阵来学习任务间的共享结构。在PyTorch中,使用GPyTorch库展示了如何实现ICM模型,包括噪声建模和训练过程。实验比较了MTGP与独立GP,显示了MTGP在预测性能上的提升。
117 7
|
8月前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
139 1
|
机器学习/深度学习 存储 PyTorch
使用Pytorch Geometric 进行链接预测代码示例
PyTorch Geometric (PyG)是构建图神经网络模型和实验各种图卷积的主要工具。在本文中我们将通过链接预测来对其进行介绍。
87 0
|
8月前
|
机器学习/深度学习 PyTorch 测试技术
PyTorch实战:图像分类任务的实现与优化
【4月更文挑战第17天】本文介绍了使用PyTorch实现图像分类任务的步骤,包括数据集准备(如使用CIFAR-10数据集)、构建简单的CNN模型、训练与优化模型以及测试模型性能。在训练过程中,使用了交叉熵损失和SGD优化器。此外,文章还讨论了提升模型性能的策略,如调整模型结构、数据增强、正则化和利用预训练模型。通过本文,读者可掌握基础的PyTorch图像分类实践。
|
8月前
|
PyTorch 算法框架/工具
使用Pytorch Geometric 进行链接预测代码示例
该代码示例使用PyTorch和`torch_geometric`库实现了一个简单的图卷积网络(GCN)模型,处理Cora数据集。模型包含两层GCNConv,每层后跟ReLU激活和dropout。模型在训练集上进行200轮训练,使用Adam优化器和交叉熵损失函数。最后,计算并打印测试集的准确性。
156 6
|
8月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
679 2
|
8月前
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch在NLP任务中的应用:文本分类、序列生成等
【4月更文挑战第18天】PyTorch在NLP中应用于文本分类和序列生成,支持RNN、CNN、Transformer等模型构建。其动态计算图、丰富API及强大社区使其在NLP研究中备受欢迎。预训练模型和多模态学习的发展将进一步拓宽PyTorch在NLP的应用前景。
|
8月前
|
机器学习/深度学习 算法 PyTorch
PyTorch中的动态计算图与静态计算图
【4月更文挑战第18天】PyTorch的动态计算图在运行时构建,灵活且易于调试,适合模型开发,但执行效率相对较低,不易优化。静态计算图预定义,执行效率高,利于优化,适用于对效率要求高的场景,但灵活性和调试难度较大。两者在模型开发与部署阶段各有优势。
|
8月前
|
机器学习/深度学习 存储 PyTorch
使用pytorch构建图卷积网络预测化学分子性质
在本文中,我们将通过化学的视角探索图卷积网络,我们将尝试将网络的特征与自然科学中的传统模型进行比较,并思考为什么它的工作效果要比传统的方法好。
96 0