VQ-VAE:矢量量化变分自编码器,离散化特征学习模型

简介: VQ-VAE 是变分自编码器(VAE)的一种改进。这些模型可以用来学习有效的表示。本文将深入研究 VQ-VAE 之前,不过,在这之前我们先讨论一些概率基础和 VAE 架构。

VQ-VAE 是变分自编码器(VAE)的一种改进。这些模型可以用来学习有效的表示。本文将深入研究 VQ-VAE 之前,不过,在这之前我们先讨论一些概率基础和 VAE 架构。

后验和先验分布

证据下界(ELBO)

在机器学习模型中,大多数后验分布都相当复杂。我们使用变分推理这一基于优化的方法来近似这些分布。ELBO 是变分推理中一个至关重要的目标函数。其推导方式如下。

重构项用于评估解码器从潜在变量重构输入的能力。KL散度项则充当正则化机制。

变分自编码器(VAE)

标准的自编码器将输入映射到潜在空间中的单个点。然而,VAE的编码器输出概率分布的参数(均值和方差)。模型从这个分布中采样一个点,然后将其输入到解码器中。

我们使用ELBO作为损失函数。

VAE存在后验崩溃的问题:模型中的正则化项开始主导损失函数,后验分布变得与先验分布相似。解码器变得过于强大,忽略了潜在表示。因此后验分布将不包含有关潜在变量的信息。

在VQ-VAE中,通过矢量量化步骤避免了后验崩溃。

矢量量化变分自编码器(VQ-VAE)

离散表示可以有效地用来提高机器学习模型的性能。人类语言本质上是离散的,使用符号表示。我们可以使用语言来解释图像。因此在机器学习中使用潜在空间的离散表示是一个自然的选择。

首先,编码器生成嵌入。然后从码本中为给定嵌入选择最佳近似。码本由离散向量组成。使用L2距离进行最近邻查找。

在反向传播过程中,通过嵌入选择步骤的梯度流动并非易事。编码器的输出嵌入和解码器的输入嵌入具有相同的维度。所以直接将解码器输入的梯度复制到编码器输出(红色箭头)。这样可以产生一个良好的梯度近似。

在训练过程中,梯度可以推动编码器嵌入(绿色圆圈)靠近不同的离散表示(紫色圆圈)。

优化编码器、解码器和嵌入(即码本)。损失函数可以用以下方式表达。

第一个术语是重构损失(类似于标准的VAE)。它衡量解码器在生成与输入分布相似的输出方面的表现。如果输入是正态分布的,这一项将是简单的均方误差。

sg 是停止梯度操作符,用来停止参数学习。由于从解码器到编码器的直接路径,重构损失项不会向嵌入提供学习信号。所以使用第二项来优化码本,将嵌入推向编码器表示。

第三项是commitment损失。它防止嵌入任意增长。

解码器仅由第一项优化。第一项和第三项优化编码器。第二项优化码本。

在训练期间,先验保持均匀。因此,ELBO的KL散度项是恒定的。

Pytorch实现

矢量量化器可以通过以下方式实现。

 classVectorQuantizer(nn.Module):
     def__init__(self, num_embeddings, embedding_dim, commitment_cost):
         super(VectorQuantizer, self).__init__()

         self._embedding_dim=embedding_dim
         self._num_embeddings=num_embeddings

         self._embedding=nn.Embedding(self._num_embeddings, self._embedding_dim)
         self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
         self._commitment_cost=commitment_cost

     defforward(self, inputs):
         # convert inputs from BCHW -> BHWC
         inputs=inputs.permute(0, 2, 3, 1).contiguous()
         input_shape=inputs.shape

         # Flatten input
         flat_input=inputs.view(-1, self._embedding_dim)

         # Calculate distances
         distances= (torch.sum(flat_input**2, dim=1, keepdim=True) 
                     +torch.sum(self._embedding.weight**2, dim=1)
                     -2*torch.matmul(flat_input, self._embedding.weight.t()))

         # Encoding
         encoding_indices=torch.argmin(distances, dim=1).unsqueeze(1)
         encodings=torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
         encodings.scatter_(1, encoding_indices, 1)

         # Quantize and unflatten
         quantized=torch.matmul(encodings, self._embedding.weight).view(input_shape)

         # Loss
         e_latent_loss=F.mse_loss(quantized.detach(), inputs)
         q_latent_loss=F.mse_loss(quantized, inputs.detach())
         loss=q_latent_loss+self._commitment_cost*e_latent_loss

         quantized=inputs+ (quantized-inputs).detach()
         avg_probs=torch.mean(encodings, dim=0)
         perplexity=torch.exp(-torch.sum(avg_probs*torch.log(avg_probs+1e-10)))

         # convert quantized from BHWC -> BCHW
         returnloss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings

我们将输入扁平化,并保持嵌入空间的维数为_embedding_dim。假设输入为 16,32,32,64 BHWC/ batch, height, width, channels 。被压扁成[16384,64]。

 # Flatten input
 flat_input = inputs.view(-1, self._embedding_dim)

然后计算从每个嵌入向量到每个码本向量的距离的平方。假设(N, D)是编码器的输出,(K, D)是码本。得到(N, K)大小的结果。

 distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

接下来,我们跨dim = 1(跨码本)执行简单的argmin,获得与编码器输出距离最小的嵌入。我们生成N个大小为K的一元向量。

 encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
 encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)

将嵌入表与这个独热向量相乘以提取最接近的码本向量。这就是量化过程。

接下来定义损失项(重建损失除外)。Mse代表均方误差,.detach作为停止梯度操作。

 e_latent_loss = F.mse_loss(quantized.detach(), inputs)
 q_latent_loss = F.mse_loss(quantized, inputs.detach())
 loss = q_latent_loss + self._commitment_cost * e_latent_loss

最后确保梯度可以直接从解码器流向编码器。

 quantized = inputs + (quantized - inputs).detach()

从数学上讲,左右两边是相等的(+输入和-输入将相互抵消)。在反向传播过程中,.detach部分将被忽略

以上就是VQ VAE的完整实现,原始的完整代码可以在这里找到:

https://avoid.overfit.cn/post/85355d48ece84f77b7c1b02f60de9c8f
最后论文:ArXiv. /abs/1711.00937

作者:Kavishka Abeywardana

目录
相关文章
夸克网盘的文件怎么保存到百度网盘?
夸克网盘的文件怎么保存到百度网盘?
9739 2
夸克网盘的文件怎么保存到百度网盘?
|
机器学习/深度学习 人工智能 达摩院
AIGC玩转卡通化技术实践
伴随着持续不断的AIGC浪潮,越来越多的AI生成玩法正在被广大爱好者定义和提出,图像卡通化(动漫化)基于其还原效果高,风格种类丰富等特点而备受青睐。早在几年前,伴随着GAN网络的兴起,卡通化就曾经风靡一时。而今,伴随着AIGC技术的兴起和不断发展,扩散生成模型为卡通化风格和提供了更多的创意和生成的可能性。本文就将详细介绍达摩院开放视觉团队的卡通化技术实践。
|
5月前
|
人工智能 监控 算法
Transformer模型训练全解析:从数据到智能的炼金术
模型训练是让AI从数据中学习规律的过程,如同教婴儿学语言。预训练相当于通识教育,为模型打下通用知识基础;后续微调则针对具体任务。整个过程包含数据准备、前向传播、损失计算、反向更新等步骤,需克服过拟合、不稳定性等挑战,结合科学与艺术,最终使模型具备智能。
|
1月前
|
JSON NoSQL Redis
OpenClaw核心源码解读:从Gateway到Pi-embedded的完整调用链分析
本文直击OpenClaw实战痛点,剖析其“云端大脑(Orchestrator)+协议桥(Gateway)+本地执行端(Pi-embedded)”三层解耦架构,详解指令流转、沙箱隔离、节点注册与长连接避坑要点,助开发者快速定位超时、不响应等常见问题。
|
机器学习/深度学习 并行计算 PyTorch
【pytorch】【202504】关于torch.nn.Linear
小白从开始这段代码展示了`nn.Linear`的使用及其背后的原理。 此外,小白还深入研究了PyTorch的核心类`torch.nn.Module`以及其子类`torch.nn.Linear`的源码。`grad_fn`作为张量的一个属性,用于指导反向传播 进一步地,小白探讨了`requires_grad`与叶子节点(leaf tensor)的关系。叶子节点是指在计算图中没有前驱操作的张量,只有设置了`requires_grad=True`的叶子节点才会在反向传播时保存梯度。 最后,小白学习了PyTorch中的三种梯度模式 通过以上学习小白对PyTorch的自动求导机制有了更深刻的理解。
478 6
|
机器学习/深度学习 缓存 自然语言处理
深入解析Tiktokenizer:大语言模型中核心分词技术的原理与架构
Tiktokenizer 是一款现代分词工具,旨在高效、智能地将文本转换为机器可处理的离散单元(token)。它不仅超越了传统的空格分割和正则表达式匹配方法,还结合了上下文感知能力,适应复杂语言结构。Tiktokenizer 的核心特性包括自适应 token 分割、高效编码能力和出色的可扩展性,使其适用于从聊天机器人到大规模文本分析等多种应用场景。通过模块化设计,Tiktokenizer 确保了代码的可重用性和维护性,并在分词精度、处理效率和灵活性方面表现出色。此外,它支持多语言处理、表情符号识别和领域特定文本处理,能够应对各种复杂的文本输入需求。
1481 6
深入解析Tiktokenizer:大语言模型中核心分词技术的原理与架构
|
机器学习/深度学习 数据采集 人工智能
GAN的主要介绍
【10月更文挑战第6天】
|
机器学习/深度学习 数据可视化 Linux
深度学习模型可视化工具——Netron使用介绍
深度学习模型可视化工具——Netron使用介绍
3494 2