图分类任务实战

简介: 图分类任务实战

图分类任务

前期准备

  • 这回咱们拿到的不是一个图,而是一堆图(188个)
  • 对这些图整体进行分类任务,数据集概述如下:
import torch
# 导入TUDataset,这是一个用于加载图数据集的类。
from torch_geometric.datasets import TUDataset #分子数据集:https://chrsmrrs.github.io/datasets/
# 创建一个名为dataset的图数据集对象。这一行指定了数据集的根目录和数据集的名称。数据集包含多个图,每个图表示一种分子结构,。
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
print()
print(f'Dataset: {dataset}:')  # 打印数据集的信息,包括数据集的名称。
print('====================')
print(f'Number of graphs: {len(dataset)}')  # 打印数据集中包含的图的数量。
print(f'Number of features: {dataset.num_features}')  # 打印每个图中节点的特征数量。
print(f'Number of classes: {dataset.num_classes}')  # 打印数据集中的类别数量。在分子数据集中,通常表示不同种类的分子。
data = dataset[0]  # Get the first graph object.
print()
"""
      edge_index=[2, 38]:这部分信息表示图中的边的索引。edge_index 是一个2xN的矩阵,其中每一列代表一条边,N 是边的总数。在这个例子中,图中有 38 条边,每个边由两个节点组成。
      x=[17, 7]:这部分信息表示图中的节点特征。x是一个节点特征矩阵,其中17表示图中有17个节点,每个节点的特征向量的维度为7.
      edge_attr=[38, 4]:这部分信息表示图中每条边的特征。edge_attr 是一个边特征矩阵,其中38表示有38条边,每个边的特征向量的维度为 4。
      y=[1]:这部分信息表示图的标签或类别。在这个例子中,图被标记为一个类别,其值为 1。
"""
print(data)
print('=============================================================')
# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')  # 打印第一个图中的节点数量。
print(f'Number of edges: {data.num_edges}')  # 打印第一个图中的边的数量。
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')  # 计算并打印平均节点度(每个节点连接的边的平均数量)。
print(f'Has isolated nodes: {data.has_isolated_nodes()}')  # 检查并打印图是否包含孤立节点(只与自己连接的节点)。
print(f'Has self-loops: {data.has_self_loops()}')  # 检查并打印图是否包含自环(节点与自己连接的边)。
print(f'Is undirected: {data.is_undirected()}')  # 检查并打印图是否是无向图。

输出结果:

Mini-batching

图神经网络中的 Mini-batching 流程:

1. 图数据准备: 首先,你需要准备一个大型的图数据集。这个图可以表示各种关系,如社交网络、知识图谱、通信网络等。图数据节点组成,节点代表实体代表它们之间的连接关系节点可以有特征,类似于传统神经网络中的输入特征

2. Mini-batch 划分: 大型图数据通常无法一次性输入到图神经网络中,因此你将图数据划分成多个小批量(mini-batch)数据。每个小批量包含一部分节点和它们之间的连接关系小批量的大小通常是一个超参数,例如 32 或 64,表示每次训练时处理的节点数量

3. 节点表示初始化: 在 Mini-batch 训练开始时,为每个节点初始化表示。这通常是一个随机向量,可以是节点特征的一部分。节点表示将随着训练逐渐更新,以捕获节点之间的信息传递。

4. 前向传播: 对于每个小批量数据,进行前向传播。这涉及计算节点表示的更新,以便考虑它们的邻居节点。节点表示的更新通常使用聚合函数如池化或均值池化)来汇总邻居节点的信息。

5. 汇总: 在前向传播期间,节点表示会根据邻居节点的信息进行汇总,以更新每个节点的表示。这使得节点能够考虑到它们的邻居节点的信息,以更好地捕获图中的关系。

6. 计算损失: 使用前向传播的结果计算损失函数,以度量模型的性能。损失函数通常比较模型的预测输出真实标签,然后计算误差

7. 反向传播:损失进行反向传播,计算每个节点表示的梯度。这些梯度将用于更新节点表示以改善模型的性能

8. 参数更新: 使用梯度下降等优化算法来更新图神经网络的参数,以减小损失函数的值。这会改善模型的性能,并使其能够更好地捕获图中的模式。

9. 重复迭代: 以上的步骤将在多次迭代中执行,通常称为"epoch"。在每个 epoch 中,会多次处理不同的 Mini-batch,直到模型收敛或达到预定的停止条件。

总之,Mini-batching 是一种有效处理大型图数据的方法,通过将数据划分成小批量并利用前向传播反向传播训练图神经网络。这有助于提高训练效率,使模型能够处理大规模的图数据

# 导入用于加载和准备图数据以进行训练的库
from torch_geometric.loader import DataLoader
# 常见DataLoader对象,用于加载训练数据
# batch_size=64:每次从训练数据中加载 64 个图作为一个批次。
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    # 当前小批量中包含的图的数量
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()
# DataBatch(edge_index=[2, 2420], x=[1091, 7], edge_attr=[2420, 4], y=[60], batch=[1091], ptr=[61])  
# data输出分析
# 1、edge_index=[2, 2420]:这部分信息表示批量中所有图的边索引。edge_index 是一个 2xN 的矩阵,其中 N 表示边的总数。
# 在这个批量中,共有 2420 条边,每一列代表一条边,前一行是起始节点,后一行是结束节点。
# 2、x=[1091, 7]:这部分信息表示批量中的节点特征。x 是一个节点特征矩阵,其中 1091 表示批量中有 1091 个节点,每个节点的特征向量的维度为 7。
# 3、edge_attr=[2420, 4]:这部分信息表示批量中的每条边的特征。edge_attr 是一个边特征矩阵,其中 2420 表示有 2420条边,每个边的特征向量的维度为 4。
# 4、y=[60]:这部分信息表示批量中的图的标签或类别。在这个示例中,批量中包含了60个图,每个图被标记为一个类别。
# 5、batch=[1091]:这部分信息是关于节点所属的子图(即批量中的哪个图)的信息。这表示在批量中的1091个节点中,有1091个节点属于同一个子图。
# 6、ptr=[61]:这部分信息指示批量中的每个图的节点范围。在这个示例中,有61个图,ptr 表示每个图的节点数目的范围,用于在批量中划分不同的图。

设计网络架构

  • 其实还是对各个节点进行特征编码,只不过现在多了一步聚合操作
  • 把各个节点特征汇总成全局特征就相当于得到了整个图的编码

from torch.nn import Linear  # 导入线性层的模块
import torch.nn.functional as F  # 导入pytorch中的函数模块,通常用于激活函数等操作
from torch_geometric.nn import GCNConv  # 导入GCVConv模块,用于实现图卷积层
from torch_geometric.nn import global_mean_pool  #  导入global_mean_pool模块,用于执行全局平均池化
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        # 设置随机种子,以确保结果的可重复性。
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # 创建一个线性层,将 GCN 模型的输出特征维度(hidden_channels)映射到数据集的类别数量
        self.lin = Linear(hidden_channels, dataset.num_classes)
    def forward(self, x, edge_index, batch):
        # 1.对各节点进行编码
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        # 2. 平均操作(使用全局平均池化对节点特征进行汇总,生成一个全局特征向量 x。)
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        # 3. 输出
        x = F.dropout(x, p=0.5, training=self.training)
#       # 使用线性层将全局特征映射到类别数量,生成最终的输出。
        x = self.lin(x)
        
        return x
model = GCN(hidden_channels=64)
print(model)

# 创建模型
model = GCN(hidden_channels=64)
# 定义优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
# 训练函数
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
# 测试模型
def test(loader):
     model.eval()  # 将模型设置为评估模式,以禁用一些特定于训练的操作,如 dropout。
    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.batch)  
        pred = out.argmax(dim=1)  # 找到每个样本中概率最高的类别,即模型的预测。
        correct += int((pred == data.y).sum())  # 将正确预测的样本数量累积到计数器中。
    return correct / len(loader.dataset)  # 计算并返回准确率,即正确预测的样本数量除以总样本数量。
# 训练和测试循环
for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}')
Epoch: 001, Train Acc: 0.6649
Epoch: 002, Train Acc: 0.6649
Epoch: 003, Train Acc: 0.6649
Epoch: 004, Train Acc: 0.6702
Epoch: 005, Train Acc: 0.6809
Epoch: 006, Train Acc: 0.7181
Epoch: 007, Train Acc: 0.7234
Epoch: 008, Train Acc: 0.7234
Epoch: 009, Train Acc: 0.7287
Epoch: 010, Train Acc: 0.7340
Epoch: 011, Train Acc: 0.7340
Epoch: 012, Train Acc: 0.7447
Epoch: 013, Train Acc: 0.7447
Epoch: 014, Train Acc: 0.7394
Epoch: 015, Train Acc: 0.7394
Epoch: 016, Train Acc: 0.7500
Epoch: 017, Train Acc: 0.7394
Epoch: 018, Train Acc: 0.7500
Epoch: 019, Train Acc: 0.7500
Epoch: 020, Train Acc: 0.7394
Epoch: 021, Train Acc: 0.7553
Epoch: 022, Train Acc: 0.7553
Epoch: 023, Train Acc: 0.7660
Epoch: 024, Train Acc: 0.7606
Epoch: 025, Train Acc: 0.7660
Epoch: 026, Train Acc: 0.7660
Epoch: 027, Train Acc: 0.7766
Epoch: 028, Train Acc: 0.7660
Epoch: 029, Train Acc: 0.7660
Epoch: 030, Train Acc: 0.7819
Epoch: 031, Train Acc: 0.7660
Epoch: 032, Train Acc: 0.7660
Epoch: 033, Train Acc: 0.7819
Epoch: 034, Train Acc: 0.7660
Epoch: 035, Train Acc: 0.7766
Epoch: 036, Train Acc: 0.7819
Epoch: 037, Train Acc: 0.7553
Epoch: 038, Train Acc: 0.7660
Epoch: 039, Train Acc: 0.7713
Epoch: 040, Train Acc: 0.7606
Epoch: 041, Train Acc: 0.7660
Epoch: 042, Train Acc: 0.7819
Epoch: 043, Train Acc: 0.7819
Epoch: 044, Train Acc: 0.7713
Epoch: 045, Train Acc: 0.7606
Epoch: 046, Train Acc: 0.7926
Epoch: 047, Train Acc: 0.7713
Epoch: 048, Train Acc: 0.7606
Epoch: 049, Train Acc: 0.7660
Epoch: 050, Train Acc: 0.7819
Epoch: 051, Train Acc: 0.7766
Epoch: 052, Train Acc: 0.7606
Epoch: 053, Train Acc: 0.7660
Epoch: 054, Train Acc: 0.7766
Epoch: 055, Train Acc: 0.7713
Epoch: 056, Train Acc: 0.7606
Epoch: 057, Train Acc: 0.7766
Epoch: 058, Train Acc: 0.7766
Epoch: 059, Train Acc: 0.7766
Epoch: 060, Train Acc: 0.7606
Epoch: 061, Train Acc: 0.7819
Epoch: 062, Train Acc: 0.7713
Epoch: 063, Train Acc: 0.7660
Epoch: 064, Train Acc: 0.7766
Epoch: 065, Train Acc: 0.7766
Epoch: 066, Train Acc: 0.7766
Epoch: 067, Train Acc: 0.7819
Epoch: 068, Train Acc: 0.7819
Epoch: 069, Train Acc: 0.7713
Epoch: 070, Train Acc: 0.7660
Epoch: 071, Train Acc: 0.7713
Epoch: 072, Train Acc: 0.7872
Epoch: 073, Train Acc: 0.7713
Epoch: 074, Train Acc: 0.7713
Epoch: 075, Train Acc: 0.7819
Epoch: 076, Train Acc: 0.7766
Epoch: 077, Train Acc: 0.7713
Epoch: 078, Train Acc: 0.7713
Epoch: 079, Train Acc: 0.7819
Epoch: 080, Train Acc: 0.7766
Epoch: 081, Train Acc: 0.7979
Epoch: 082, Train Acc: 0.7766
Epoch: 083, Train Acc: 0.7713
Epoch: 084, Train Acc: 0.7713
Epoch: 085, Train Acc: 0.7766
Epoch: 086, Train Acc: 0.7766
Epoch: 087, Train Acc: 0.7766
Epoch: 088, Train Acc: 0.7872
Epoch: 089, Train Acc: 0.7766
Epoch: 090, Train Acc: 0.7766
Epoch: 091, Train Acc: 0.7819
Epoch: 092, Train Acc: 0.7766
Epoch: 093, Train Acc: 0.7819
Epoch: 094, Train Acc: 0.7926
Epoch: 095, Train Acc: 0.7766
Epoch: 096, Train Acc: 0.7819
Epoch: 097, Train Acc: 0.7766
Epoch: 098, Train Acc: 0.7766
Epoch: 099, Train Acc: 0.7713
Epoch: 100, Train Acc: 0.8085
Epoch: 101, Train Acc: 0.7766
Epoch: 102, Train Acc: 0.7766
Epoch: 103, Train Acc: 0.7766
Epoch: 104, Train Acc: 0.7819
Epoch: 105, Train Acc: 0.7819
Epoch: 106, Train Acc: 0.8085
Epoch: 107, Train Acc: 0.7819
Epoch: 108, Train Acc: 0.7819
Epoch: 109, Train Acc: 0.7872
Epoch: 110, Train Acc: 0.7979
Epoch: 111, Train Acc: 0.7872
Epoch: 112, Train Acc: 0.7872
Epoch: 113, Train Acc: 0.7979
Epoch: 114, Train Acc: 0.7926
Epoch: 115, Train Acc: 0.7872
Epoch: 116, Train Acc: 0.8032
Epoch: 117, Train Acc: 0.7872
Epoch: 118, Train Acc: 0.7766
Epoch: 119, Train Acc: 0.8138
Epoch: 120, Train Acc: 0.7926
Epoch: 121, Train Acc: 0.7766
Epoch: 122, Train Acc: 0.8138
Epoch: 123, Train Acc: 0.7926
Epoch: 124, Train Acc: 0.7553
Epoch: 125, Train Acc: 0.7766
Epoch: 126, Train Acc: 0.7926
Epoch: 127, Train Acc: 0.7872
Epoch: 128, Train Acc: 0.7872
Epoch: 129, Train Acc: 0.7872
Epoch: 130, Train Acc: 0.8032
Epoch: 131, Train Acc: 0.7872
Epoch: 132, Train Acc: 0.7713
Epoch: 133, Train Acc: 0.8138
Epoch: 134, Train Acc: 0.7979
Epoch: 135, Train Acc: 0.7979
Epoch: 136, Train Acc: 0.7979
Epoch: 137, Train Acc: 0.7979
Epoch: 138, Train Acc: 0.7979
Epoch: 139, Train Acc: 0.8032
Epoch: 140, Train Acc: 0.8085
Epoch: 141, Train Acc: 0.8138
Epoch: 142, Train Acc: 0.8138
Epoch: 143, Train Acc: 0.8085
Epoch: 144, Train Acc: 0.8085
Epoch: 145, Train Acc: 0.7926
Epoch: 146, Train Acc: 0.8085
Epoch: 147, Train Acc: 0.8191
Epoch: 148, Train Acc: 0.7819
Epoch: 149, Train Acc: 0.8138
Epoch: 150, Train Acc: 0.8191
Epoch: 151, Train Acc: 0.7819
Epoch: 152, Train Acc: 0.8138
Epoch: 153, Train Acc: 0.8085
Epoch: 154, Train Acc: 0.8191
Epoch: 155, Train Acc: 0.8085
Epoch: 156, Train Acc: 0.8032
Epoch: 157, Train Acc: 0.8245
Epoch: 158, Train Acc: 0.7872
Epoch: 159, Train Acc: 0.8191
Epoch: 160, Train Acc: 0.8298
Epoch: 161, Train Acc: 0.8191
Epoch: 162, Train Acc: 0.8191
Epoch: 163, Train Acc: 0.8138
Epoch: 164, Train Acc: 0.8191
Epoch: 165, Train Acc: 0.8245
Epoch: 166, Train Acc: 0.8191
Epoch: 167, Train Acc: 0.8298
Epoch: 168, Train Acc: 0.8191
Epoch: 169, Train Acc: 0.8245
Epoch: 170, Train Acc: 0.8085
Epoch: 171, Train Acc: 0.8245
Epoch: 172, Train Acc: 0.8191
Epoch: 173, Train Acc: 0.8191
Epoch: 174, Train Acc: 0.8191
Epoch: 175, Train Acc: 0.8138
Epoch: 176, Train Acc: 0.8298
Epoch: 177, Train Acc: 0.8245
Epoch: 178, Train Acc: 0.8191
Epoch: 179, Train Acc: 0.8138
Epoch: 180, Train Acc: 0.8351
Epoch: 181, Train Acc: 0.8191
Epoch: 182, Train Acc: 0.8457
Epoch: 183, Train Acc: 0.8351
Epoch: 184, Train Acc: 0.8351
Epoch: 185, Train Acc: 0.8298
Epoch: 186, Train Acc: 0.8404
Epoch: 187, Train Acc: 0.8191
Epoch: 188, Train Acc: 0.8032
Epoch: 189, Train Acc: 0.8191
Epoch: 190, Train Acc: 0.8298
Epoch: 191, Train Acc: 0.8404
Epoch: 192, Train Acc: 0.8564
Epoch: 193, Train Acc: 0.8404
Epoch: 194, Train Acc: 0.8351
Epoch: 195, Train Acc: 0.8351
Epoch: 196, Train Acc: 0.8404
Epoch: 197, Train Acc: 0.8404
Epoch: 198, Train Acc: 0.8298
Epoch: 199, Train Acc: 0.8617
Epoch: 200, Train Acc: 0.8617
相关文章
点分类模型实战
点分类模型实战
|
数据可视化
R可视化学习(5) -- 脊线图
Ridgeline 图(脊线图),(有时称为Joyplot)可以同时显示几个组的数值分布情况,分布可以使用直方图或密度图来表示,它们都与相同的水平尺度对齐,并略有重叠。常常被用来可视化随时间或空间变化的多个分布/直方图变化。
290 0
R可视化学习(5) -- 脊线图
|
数据采集 数据可视化 算法
数据分析可视化常用图介绍以及相关代码实现(箱型图、Q-Q图、Kde图、线性回归图、热力图)
数据分析可视化常用图介绍以及相关代码实现(箱型图、Q-Q图、Kde图、线性回归图、热力图)
|
机器学习/深度学习
图(Graph)--经典图特征工程:链接分析
本文记录了关于学习图链接分析方法的个人笔记,以供参考学习
181 0
|
数据可视化 数据处理
R可视化学习(4) -- 棒棒糖图
棒棒糖图其实类似于柱状图加散点图的效果,因为他的形状就是由俩部分组成(点+线条),因此在ggplot中,我们只要通过geom_point()函数绘制"糖"的那一部分,geom_segment()函数绘制“棒棒”那一部分,就可轻松绘制出这种图形
191 0
|
机器学习/深度学习 Python
图(Graph)--经典图特征工程:节点分析
本文分享了关于学习参考多变博客文献后总结的 经典图特征工程 分析方法,以供参考学习
570 0
离散数学_十章-图 ( 2 ):图的术语和几种特殊的图(一)
离散数学_十章-图 ( 2 ):图的术语和几种特殊的图(一)
116 0
|
机器学习/深度学习
离散数学_十章-图 ( 2 ):图的术语和几种特殊的图(二)
离散数学_十章-图 ( 2 ):图的术语和几种特殊的图(二)
2297 0
|
机器学习/深度学习 DataWorks 算法
使用 PAI 进行分类分析 | 学习笔记
快速学习使用 PAI 进行分类分析
使用 PAI 进行分类分析 | 学习笔记
|
存储 C++
C++实现图 - 01 图的概述及实现
前面我们讲的数据结构都是针对于一对一或一对多的情形,如果涉及到多对多的复杂情况就要用到我们接下来讲解的图了,这一讲我们重点讲解邻接表、邻接矩阵、十字链表以及邻接多重表的代码实现。如果已经对图的概念比较熟悉的小伙伴,可以拉到下面看相关的代码实现。
318 0
C++实现图 - 01 图的概述及实现
下一篇
无影云桌面