PyTorch深度学习实战 |手算GCN (图神经网络)模型

本文涉及的产品
RDS DuckDB + QuickBI 企业套餐,8核32GB + QuickBI 专业版
简介: 本文介绍了使用PyTorch实现图神经网络(GNN)处理分子结构数据的实战方法。主要内容包括:1) GNN的基本原理,通过节点特征矩阵和邻接矩阵处理图结构数据;2) 分子图的表示方式,将SMILES字符串转换为PyTorch Geometric图对象;3) 图卷积运算过程,包括特征变换和邻接特征聚合;4) 代码实现示例,构建包含GCN层和全局池化的模型,对乙醇分子进行特征提取和分类预测。文章通过具体案例展示了GNN在化学领域的应用,为读者提供了从理论到实践的完整指导。

 💻图神经网络

      图神经网络(Graph Neural Network,GNN)是一种专门用于处理图结构数据的深度学习方

法。与传统的神经网络主要处理规则结构的数据(如图像和文本)不同,GNN能够处理各种不规

则的数据结构,如社交网络、分子结构等。GNN通过在图上定义节点之间的连接关系,利用节点

的邻居信息来更新节点的表示,实现对整个图的信息传递和学习。

image.gif


📘分子的图结构

图神经网络(GNN)处理的输入是图(Graph),而不是传统的像素矩阵或序列。因此,第一步

是将我们的目标分子——乙醇,抽象地转化为一个数学图。

节点与邻接关系

为了简化手算过程,我们只关注重原子:两个碳原子和一个氧原子。我们将氢原子的影响体现在

节点的特征中。

邻接矩阵 : 描述节点之间的连接关系(化学键)。

image.gif

节点特征矩阵:是 GCN 的输入。为了演示目的,我们为每个原子分配一个简单的特

征向量(例如一个独热编码和它的连接度):

image.gif

🔬 真实项目中分子图的表示方式

在真实的分子图神经网络项目中,虽然图的基本原理(节点和边)是一样的,但节点特征图的复

杂度会有巨大的差异。在 PyTorch Geometric (PyG) 或 Deep Graph Library (DGL) 等专业 GNN 库

中处理分子(例如基于 RDKit 的分子表示)时,图的定义会丰富得多,最后我们会详细的介绍一

下这部分的内容。


📘图卷积

    图卷积的输入数据是节点特征矩阵H和邻接矩阵A,下图我们展示了3个节点的图,每个节点特

征数为3的特征,图卷积在计算的时候有两个关键的步骤,分别是节点特征的线性变换邻接特征

的聚合

GCN 单层的核心公式是:

节点特征的线性变换

            就是使用线性层,对节点特征矩阵进行线性变换,提取特征,节点的特征数目发生变化。

X' 现在代表了每个原子经过权重变换后的新特征。计算过程如下:

image.gif

邻接特征的聚合

    这里为了方便理解,邻接矩阵没有使用稀疏表示,就是使用邻接矩阵adj进行特征聚合,将相邻

节点中的特征信息,传导到该节点上。

新的特征就等于聚合邻居信息变换特征的乘积

聚合邻居信息

这个矩阵中的每一行和每一列的元素,定义了消息如何从邻居节点传递并加权求和

image.gif

总结一下:一次的图卷积,本质上是特征的变换(或者说是特征数的变化),在这个过程中,节点

特征矩阵数据,包含每个节点的属性信息,每层都会通过 GCN 运算更新。归一化邻接矩阵结构

定义图的拓扑结构和消息传递路径,它是固定的。


📘分子图的表达方式

🔬 真实项目中分子图的表示方式

在真实项目中,分子表示的流程是:

SMILES 字符串      RDKit 解析     PyG/DGL 图对象

安装相关的包:

pip install rdkit

image.gif

pip install torch_geometric

image.gif

# 假设 PyTorch, RDKit, PyTorch Geometric (PyG) 库已安装
import torch
from rdkit import Chem
from torch_geometric.utils.smiles import from_smiles  # PyG中用于SMILES转换的实用函数 (或使用更早版本的'from_rdkit')
# --- 1. 定义和转换 ---
smiles_ethanol = "CCO" 
# 使用 PyG 的封装函数,一步完成解析、特征提取和图结构构建
# 这个函数内部自动完成了原子特征编码、键索引构建、以及 PyTorch 张量转换。
ethanol_data = from_smiles(smiles_ethanol)
# --- 2. 展示结果 ---
print("=" * 40)
print(f"乙醇分子 SMILES: {smiles_ethanol}")
print("PyG Data 对象结构 (封装结果)")
print("=" * 40)
# PyG Data 对象概览
print(ethanol_data) 
print("\n--- 关键张量尺寸分析 ---")
print(f"节点特征矩阵 (x): {ethanol_data.x.shape}")
print(f"邻接信息 (edge_index): {ethanol_data.edge_index.shape}")
print("-" * 40)

image.gif

========================================

乙醇分子 SMILES: CCO

PyG Data 对象结构 (封装结果)

========================================

Data(x=[3, 9], edge_index=[2, 4], edge_attr=[4, 3], smiles='CCO')


--- 关键张量尺寸分析 ---

节点特征矩阵 (x): torch.Size([3, 9])

邻接信息 (edge_index): torch.Size([2, 4])

----------------------------------------

📈 PyG Data 对象结构解读

3 (行): 代表图中的节点数量, 模型忽略了 6 个氢原子。9 (列): 代表每个节点的特征维度特征提取

器为每个重原子编码了 9 种不同的化学属性(例如:原子类型、价态、电荷、隐式氢数等)。2

(行): PyG 的固定格式,第一行是源节点索引,第二行是目标节点索引。4 (列): 代表有向边的总

数。


图卷积神经网络

以乙醇分子为例,模拟前向传播

原始3×9大小的张量

经过隐藏层通道数为16的图卷积,得到3×16大小的张量

再经过隐藏层通道数为16的图卷积,得到3×16大小的张量

再经过一个全局平均池化,得到1×16大小的特征矩阵,全局平均将节点特征聚合为图特征(每个图一个数来表示)

然后再通过一个线性层变成1×16大小的张量

🔍 代码实现

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
import sys
# 尝试导入 RDKit 和 PyG 转换工具
from rdkit import Chem
from torch_geometric.utils.smiles import from_smiles
# --- 1. 真实数据准备 (使用 PyG 封装函数) ---
smiles_ethanol = "CCO" 
# 一步转换:生成包含所有 9 个原子(C, C, O, 6H)的图结构
ethanol_data = from_smiles(smiles_ethanol)
 # 修复:确保节点特征是浮点类型 (解决 RuntimeError)
ethanol_data.x = ethanol_data.x.float()
# 生成 Batch Tensor:9 个节点都属于同一个图 (batch size=1)
N_NODES = ethanol_data.x.shape[0]
ethanol_data.batch = torch.zeros(N_NODES, dtype=torch.long)
# --- 2. 展示输入数据结构 ---
print("=" * 60)
print(f"乙醇分子 SMILES: {smiles_ethanol}")
print("PyG Data 对象结构 (GCN 模型输入 - 真实配置)")
print("=" * 60)
print(f"节点数 N: {ethanol_data.x.shape[0]}")
print(f"节点特征矩阵 (x): {ethanol_data.x.shape}")
print(f"邻接信息 (edge_index): {ethanol_data.edge_index.shape}")
print("-" * 60)
# --- 3. 定义 GCN 模型类 (SimpleGCN) ---
class SimpleGCN(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, num_classes)
    def forward(self, x, edge_index, batch):
        print(f"\n[A] 初始输入 x (H(0)): {x.shape}")
        # GCN 层 1
        x = self.conv1(x, edge_index)
        print(f"[B] GCNConv 1 输出 (H(1)): {x.shape}")
        x = F.relu(x)
        # GCN 层 2
        x = self.conv2(x, edge_index)
        print(f"[C] GCNConv 2 输出 (H(2)): {x.shape}")
        x = F.relu(x)
        # 全局读出/池化层
        x = global_mean_pool(x, batch)
        print(f"[D] Global Mean Pool 输出: {x.shape} <--- **图级特征**")
        # 线性分类层
        x = self.lin(x)
        print(f"[E] 最终分类层输出: {x.shape} <--- **预测结果**")
        return x
# --- 4. 模型实例化与运行 ---
# 定义模型参数
INPUT_DIM = ethanol_data.x.shape[1] # 自动获取真实/模拟的特征维度 (例如:11)
HIDDEN_DIM = 16 
OUTPUT_DIM = 1 
# 实例化模型
model = SimpleGCN(
    num_node_features=INPUT_DIM, 
    hidden_channels=HIDDEN_DIM, 
    num_classes=OUTPUT_DIM
)
# 执行前向传播
print("=" * 60)
print(f"【Simple GCN 前向传播过程(隐藏层维度 D_hidden={HIDDEN_DIM})】")
print("=" * 60)
output = model(
    ethanol_data.x, 
    ethanol_data.edge_index, 
    ethanol_data.batch
)

image.gif

============================================================

乙醇分子 SMILES: CCO

PyG Data 对象结构 (GCN 模型输入 - 真实配置)

============================================================

节点数 N: 3

节点特征矩阵 (x): torch.Size([3, 9])

邻接信息 (edge_index): torch.Size([2, 4])

------------------------------------------------------------

============================================================

【Simple GCN 前向传播过程(隐藏层维度 D_hidden=16)】

============================================================


[A] 初始输入 x (H(0)): torch.Size([3, 9])

[B] GCNConv 1 输出 (H(1)): torch.Size([3, 16])

[C] GCNConv 2 输出 (H(2)): torch.Size([3, 16])

[D] Global Mean Pool 输出: torch.Size([1, 16]) <--- **图级特征**

[E] 最终分类层输出: torch.Size([1, 1]) <--- **预测结果**


目录
相关文章
|
1小时前
|
机器学习/深度学习 数据采集 人工智能
田间杂草检测数据集分享(适用于YOLO系列深度学习分类检测任务)
本数据集含4000张真实农田图像(小麦/玉米/水稻田),YOLO格式标注杂草目标,覆盖多天气、光照与视角,适用于YOLO系列等目标检测模型训练,助力智能除草与精准农业研究。(239字)
202 16
|
1小时前
|
存储 搜索推荐 大数据
优路教育借助阿里云Flink+StarRocks+Paimon湖仓一体化构建职业教育业务全链路实时数据服务平台
优路教育大数据团队携手阿里云,基于实时计算 Flink + EMR Serverless StarRocks + DLF(Paimon) 构建了全链路实时数据服务平台,从学员画像、营销筛选到题库关联查询,实现了从“分钟级延迟”到“秒级响应”的质变,为成人教育行业的数据化转型提供了标杆实践。
|
1小时前
|
人工智能 缓存 弹性计算
阿里云服务器2核4G5M199元解析:独享型u1实例,性能、适用场景、购买和续费规则介绍
阿里云通用算力型u1实例(ecs.u1-c1m2.large)2核4G、5M带宽、80G ESSD Entry云盘,活动特惠价仅199元/年(官网价3498.36元),企业新老用户同享,续费同价至2027年3月31日,每人限购1台。该实例采用独享型架构,搭载Intel至强可扩展处理器,内网带宽1Gbit/s、收发包30万PPS、云盘IOPS 1万,性能稳定,适合企业官网、中小Web应用、轻量数据库及开发测试等场景。
|
1小时前
|
自然语言处理 前端开发 安全
2026 世界杯钓鱼即服务平台攻击机理与防御体系研究
2026世界杯前夕,“Ghost Stadium”中文钓鱼即服务平台发动大规模攻击,涉案4.7–10亿美元,受害超4.7万人,窃取FIFA凭证2500+条,注册恶意域名超4000个。该平台采用React+Layui实现像素级克隆、SSO模拟与多语言适配,构建覆盖社交广告、搜索、IM的立体攻击网络。本文基于实证分析,提出检测、响应、溯源、治理闭环防御体系,强调跨机构协同与动态对抗。(239字)
144 10
|
1小时前
|
数据采集 数据可视化 数据挖掘
表格魔法师:QoderWork CN 让脏数据秒变仪表盘
本文介绍如何使用阿里QoderWork CN桌面应用,通过内置xlsx技能自动化完成Excel数据清洗(统一日期格式、补全空值、去重等)与可视化(生成含仪表盘、日志、交互表格及图表的HTML报告),提升数据分析效率。
351 8
|
1小时前
|
人工智能 机器人 芯片
人工智能|YOLOv8实战
本内容为安全帽检测实战项目,基于YOLOv8模型,涵盖Kaggle数据获取、自定义yaml配置、模型训练(yolo_train.py)与测试(yolo_test.py),并提供服务器(FastAPI+Docker)、边缘(Jetson+TensorRT)及国产嵌入式(RK3588+RKNN)三类部署方案,支持工业场景实时智能识别。(239字)
164 1
|
1小时前
|
人工智能 数据可视化 安全
阿里云百炼Token Plan是什么?核心定义、功能及优惠订阅方案详解
随着AI大模型应用从个人尝鲜走向企业规模化落地,模型调用的成本管控、额度管理、团队协作与服务稳定性成为核心痛点。传统按量付费模式虽灵活,但易出现账单波动、预算不可控、高峰调用受限等问题,难以适配团队长期、高频、稳定的AI使用需求。阿里云百炼平台作为一站式大模型服务平台,推出的Token Plan订阅方案,正是为解决这些痛点而生。
257 0
|
1小时前
|
存储 人工智能 安全
|
1小时前
|
人工智能 弹性计算 运维
新手必看教程 阿里云部署Hermes Agent并配置百炼Token Plan完整实操指南
在AI智能体快速普及的当下,具备自主学习、长效记忆、多任务执行能力的智能框架逐渐成为个人办公、项目开发、自动化运维的核心工具。Hermes Agent作为一款热门开源自进化AI智能体,凭借宽松开源协议、跨会话持久记忆、自主技能迭代、多模型兼容等特色能力脱颖而出。它区别于传统对话类工具,不仅可以完成日常问答、内容创作,还能自主拆解复杂任务、沉淀使用习惯、复用过往工作经验,真正实现“越用越智能”,同时支持私有化部署,所有数据本地留存,隐私安全性突出。
182 1
|
1小时前
|
人工智能 缓存 自然语言处理
【AI 尝鲜实验室】上新 | Qoder x DevBox:让 AI 在云端放开手脚写代码
Qoder × 阿里云DevBox融合AI编程与云端开发:用自然语言描述需求,AI自动生成、调试并运行代码;所有依赖、构建、验证均在云端完成,本地零环境负担。2C2G轻量开发机低至0.1元/小时,适合快速验证、轻薄本开发及避免本地污染。