RGCN模型成功运行案例

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,5000CU*H 3个月
简介: # 创建模型in_feats = 3hid_feats = 4out_feats = 2rel_num = 4model = RGCN(in_feats, hid_feats, out_feats, rel_num)# 随机生成特征features = torch.randn((10, 3))# 计算输出output = model(g, features, rel_type)print(output)
import torch
from torch import nn
import dgl
from dgl.nn.pytorch import RelGraphConv
# 定义图结构
edges = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_src = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_dst = torch.tensor([1, 2, 3, 0, 4, 5, 6, 7, 8, 9])
rel_type = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3])
graph = (edges_src, edges_dst)
# 将元组图结构转换为DGLGraph对象
g = dgl.graph(graph)
# 定义模型
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_num):
        super(RGCN, self).__init__()
        self.conv1 = RelGraphConv(in_feats, hid_feats, rel_num)
        self.conv2 = RelGraphConv(hid_feats, out_feats, rel_num)
    def forward(self, g, feats, rel_type):
        h = self.conv1(g, feats, rel_type)
        h = torch.relu(h)
        h = self.conv2(g, h, rel_type)
        return h
# 创建模型
in_feats = 3
hid_feats = 4
out_feats = 2
rel_num = 4
model = RGCN(in_feats, hid_feats, out_feats, rel_num)
# 随机生成特征
features = torch.randn((10, 3))
# 计算输出
output = model(g, features, rel_type)
print(output)

上面的代码主要是实现了一个基于关系图卷积网络(RGCN)的模型。

代码实现的详细解释如下:

  1. 首先,使用 PyTorch 定义了一个 RGCN 模型,并通过构造函数中的 in_feats, hid_feats, out_featsrel_num 来指定输入特征维度,隐藏层维度,输出特征维度和关系类型数量。


class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_num):
        super(RGCN, self).__init__()
        self.conv1 = RelGraphConv(in_feats, hid_feats, rel_num)
        self.conv2 = RelGraphConv(hid_feats, out_feats, rel_num)
  1. 在模型的前向传递函数中,将输入特征 feats 和关系类型 rel_type 通过 RelGraphConv 进行关系图卷积,同时使用 ReLU 激活函数对隐藏层的特征进行非线性变换。


def forward(self, g, feats, rel_type):
    h = self.conv1(g, feats, rel_type)
    h = torch.relu(h)
    h = self.conv2(g, h, rel_type)
    return h
  1. 接着,使用 PyTorch 创建了一个大小为 10x3 的张量 features 作为输入特征。
features = torch.randn((10, 3))
  1. 将定义好的元组图结构 graph 转换为 DGLGraph 对象 g


g = dgl.graph(graph)
  1. 最后,将 g, featuresrel_type 作为输入,通过调用模型的前向传递函数 forward 计算输出特征 output,并将其打印出来。
output = model(g, features, rel_type)
print(output)

在计算输出时,模型会将输入的元组图结构 graph 转换为 DGLGraph 对象 gg 的结构与 graph 是相同的,但是由于它是 DGLGraph 对象,因此可以使用 DGL 提供的一系列 API 来进行图的操作和计算,例如采样、聚合和传播等。在 RGCN 中,DGL 提供的关系图卷积层(RelGraphConv)用于实现图卷积计算。


----------------介绍整个代码实现了什么功能?


这段代码实现了一个基于关系图卷积神经网络(R-GCN)的节点分类任务,具体实现过程如下:

  1. 定义图结构:使用 PyTorch 张量定义了一张包含 10 条边的有向图,并指定每条边的起点、终点和关系类型。
  2. 将元组图结构转换为 DGLGraph 对象:使用 DGL 库提供的函数将元组图结构转换为 DGLGraph 对象,以便后续在其上运行图卷积神经网络。
  3. 定义 R-GCN 模型:使用 PyTorch 搭建了一个包含两个 R-GCN 层的模型,每个层都使用了相同的关系类型数,并指定了输入特征维度、隐藏层维度和输出特征维度。
  4. 随机生成特征:使用 PyTorch 生成了一个形状为 (10, 3) 的张量作为节点特征。
  5. 计算输出:将定义好的模型、DGLGraph 对象和节点特征输入模型,计算输出。输出的形状为 (10, 2),即 10 个节点对应的分类概率。最后将输出打印出来。


解释下面的含义:


edges = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_src = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_dst = torch.tensor([1, 2, 3, 0, 4, 5, 6, 7, 8, 9])
rel_type = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3])

这里定义了一个有向图,其中图中包含10个节点,通过分别指定edges_src和edges_dst中的源节点和目标节点来定义每条边。rel_type是用来指定每个边的关系类型,共有4种类型,分别用0、1、2、3表示。具体来说,这里定义了以下10条有向边:


  • 0 -> 1 (关系类型为0)
  • 1 -> 2 (关系类型为0)
  • 2 -> 3 (关系类型为0)
  • 3 -> 0 (关系类型为1)
  • 4 -> 1 (关系类型为1)
  • 5 -> 2 (关系类型为1)
  • 6 -> 3 (关系类型为2)
  • 7 -> 4 (关系类型为2)
  • 8 -> 5 (关系类型为2)
  • 9 -> 6 (关系类型为3)


上述代码定义了一个图结构,其中包含10个节点和10条边。变量含义如下:

  • edges: 表示图中10条边的编号,取值为 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  • edges_src: 表示每条边的源节点编号,取值为 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],与edges一一对应。
  • edges_dst: 表示每条边的目标节点编号,取值为 [1, 2, 3, 0, 4, 5, 6, 7, 8, 9],与edges一一对应。
  • rel_type: 表示每条边的关系类型,取值为 [0, 0, 0, 1, 1, 1, 2, 2, 2, 3],与edges一一对应。
  • graph: 表示由节点和边构成的图结构,它是一个元组,包含了两个张量 edges_srcedges_dst


最后这段代码实现了以下功能:

# 创建模型
in_feats = 3
hid_feats = 4
out_feats = 2
rel_num = 4
model = RGCN(in_feats, hid_feats, out_feats, rel_num)
# 随机生成特征
features = torch.randn((10, 3))
# 计算输出
output = model(g, features, rel_type)
print(output)
  1. 创建一个RGCN模型对象 model,该模型具有 in_feats 个输入特征、hid_feats 个隐藏特征、out_feats 个输出特征和 rel_num 种不同的关系类型。
  2. 使用 torch.randn 随机生成一个 $10 \times 3$ 大小的特征矩阵 features
  3. 将特征矩阵 features、关系类型张量 rel_type 和转换后的DGL图 g 作为输入,通过 model 模型计算输出特征矩阵 output
  4. 输出 output


输出

tensor([[ 0.5757, -0.0934],
        [ 0.9615,  1.4563],
        [-3.5925, -0.7869],
        [ 1.2882, -0.3457],
        [ 2.5402, -0.2980],
        [ 0.1554,  0.6599],
        [ 3.8173,  2.2265],
        [ 0.8300,  1.1929],
        [ 2.6410,  3.7959],
        [-1.5862, -0.5873]], grad_fn=<AddBackward0>)


























目录
相关文章
|
23天前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
40 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
23天前
|
机器学习/深度学习 并行计算 数据可视化
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
这篇文章介绍了如何使用PaddleClas框架完成多标签分类任务,包括数据准备、环境搭建、模型训练、预测、评估等完整流程。
66 0
目标分类笔记(二): 利用PaddleClas的框架来完成多标签分类任务(从数据准备到训练测试部署的完整流程)
|
5月前
|
Serverless PyTorch 文件存储
函数计算产品使用问题之如何使用并运行PyTorch
函数计算产品作为一种事件驱动的全托管计算服务,让用户能够专注于业务逻辑的编写,而无需关心底层服务器的管理与运维。你可以有效地利用函数计算产品来支撑各类应用场景,从简单的数据处理到复杂的业务逻辑,实现快速、高效、低成本的云上部署与运维。以下是一些关于使用函数计算产品的合集和要点,帮助你更好地理解和应用这一服务。
|
6月前
|
人工智能 测试技术 开发者
大模型自动生成并运行代码的体验与优化
随着近两年大模型的不断发展,它们在各个领域展示出了惊人的能力,可以说是在各个领域到了“开花结果”的阶段。比如最近技术圈比较火的阿里云的通义千问已经可以自己写代码、跑代码了,作为开发者,我觉得这种能力不仅提高了开发效率,还推动了编程实践向更高层次的转变和发展。但是,在使用大模型自动生成代码时,我们也会面临一些挑战,其中之一是代码可能会曲解开发者的需求。那么本文就来分享一下个个人的体验以及如何优化这种情况。
715 2
大模型自动生成并运行代码的体验与优化
|
5月前
|
存储 设计模式 C语言
技术笔记:QOM模型初始化流程
技术笔记:QOM模型初始化流程
30 0
|
6月前
|
自然语言处理
【大模型】如何使用提示工程来改善 LLM 输出?
【5月更文挑战第5天】【大模型】如何使用提示工程来改善 LLM 输出?
|
6月前
|
监控 负载均衡 测试技术
大模型开发:描述一个你之前工作中的模型部署过程。
完成大型语言模型训练后,经过验证集评估和泛化能力检查,进行模型剪枝与量化以减小规模。接着导出模型,封装成API,准备服务器环境。部署模型,集成后端服务,确保安全,配置负载均衡和扩容策略。设置监控和日志系统,进行A/B测试和灰度发布。最后,持续优化与维护,根据线上反馈调整模型。整个流程需团队协作,保证模型在实际应用中的稳定性和效率。
118 3
|
6月前
|
机器学习/深度学习 数据采集 人工智能
人工智能,应该如何测试?(四)模型全生命周期流程与测试图
本文补充了完整的业务和测试流程,包括生命周期流程图,强调测试人员在模型测试中的角色。主要测试活动有:1) 离线模型测试,使用训练集、验证集和测试集评估模型;2) 线上线下一致性测试,确保特征工程的一致性;3) A/B Test,逐步替换新旧模型以观察效果;4) 线上模型监控,实时跟踪用户行为变化;5) 数据质量测试,验证新数据质量以防影响模型效果。
130 0
|
6月前
|
机器学习/深度学习 JSON 自然语言处理
python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
|
11月前
优化模型案例
优化模型案例