【PyTorch深度强化学习】TD3算法(双延迟-确定策略梯度算法)的讲解及实战(超详细 附源码)

简介: 【PyTorch深度强化学习】TD3算法(双延迟-确定策略梯度算法)的讲解及实战(超详细 附源码)

需要源码请点赞关注收藏后评论区留言~~~

一、双延迟-确定策略梯度算法

在DDPG算法基础上,TD3算法的主要目的在于解决AC框架中,由函数逼近引入的偏差和方差问题。一方面,由于方差会引起过高估计,为解决过高估计问题,TD3将截断式双Q学习(clipped Double Q-Learning)应用于AC框架;另一方面,高方差会引起误差累积,为解决误差累积问题,TD3分别采用延迟策略更新和添加噪声平滑目标策略两种技巧。

过高估计问题解决方案

   从策略梯度方法已知,基于PG的强化学习存在过高估计问题,但由于DDPG评论家的目标值不是取最优动作值函数的,所以不存在最大化操作。此时,将Double DQN思想直接用于DDPG的评论家,构造如下目标函数:

y=r+γQ(s′,μ(s′,θ),w′)(bbb.11)(bbb.11)y=r+γQ(s′,μ(s′,θ),w′)

   实际上,这样的处理效果并不好,这是因为在连续动作空间中,策略变化缓慢,行动者更新较为平缓,使得预测QQ值与目标QQ值相差不大,无法避免过高估计问题。


   考虑将Double Q-Learning思想应用于DDPG,采用两个独立的评论家Qw1Qw1、Qw2Qw2和两个独立的行动者μθ1μθ1、μθ2μθ2,以50%的概率利用Q1Q1产生动作,然后更新Q2Q2估计值,而另外50%的概率正好相反。构建更新所需的两个目标值分别为:

{y1=r+γQ(s′,μ(s′,θ1),w′2)y2=r+γQ(s′,μ(s′,θ2),w′1)(bbb.12)(bbb.12){y1=r+γQ(s′,μ(s′,θ1),w2′)y2=r+γQ(s′,μ(s′,θ2),w1′)

   但由于样本均来自于同一经验池,不能保证样本数据完全独立,所以两个行动者的样本具有一定相关性,在一定的情况下,甚至会加剧高估问题。针对此种情形,秉持“宁可低估,也不要高估”的想法,对Double Q-Learning进行修改,构建基于Clipped Double Q-learning方法的目标值:

y=r+γmini=1,2Q(s′,μ(s′,θ1),w′i)(bbb.13)(bbb.13)y=r+γmini=1,2Q(s′,μ(s′,θ1),wi′)

   如式(bbb.13)所示,目标值只使用了一个行动者网络μθ1μθ1,取两个评论家网络Qw1Qw1和Qw2Qw2的最小值来作为值函数估计值。

   在更新评论家网络Qw1Qw1和Qw2Qw2时,均采用式(bbb.13)目标值y,共用如下损失函数:

L(wi)=𝔼s,a,r,s′∼[y−Q(s,a,wi)]2(bbb.14)(bbb.14)L(wi)=Es,a,r,s′∼D[y−Q(s,a,wi)]2

   该算法相比于原算法的区别仅在于多了一个和原评论家Qw1Qw1同步更新的辅助评论家Qw2Qw2,在更新目标值y时取最小值。不过这一修改仍然会让人疑惑,Qw1Qw1和Qw2Qw2只有初始参数不同,后面的更新都一样,这样形成的两个类似的评论家能否有效消除TD误差带来的偏置估计。

累积误差问题解决方案

   在函数逼近问题中,TD(0)算法的过高估计问题会进一步加剧,每次更新都会产生一定量的TD误差δ(s,a)δ(s,a):

Q(s,a,w)=r+γ𝔼[Q(s′,a′,w)]−δ(s,a)(bbb.15)(bbb.15)Q(s,a,w)=r+γE[Q(s′,a′,w)]−δ(s,a)

   经过多次迭代更新后,误差会被累积:

Q(St,At,w)=Rt+1+γ𝔼[Q(St+1,At+1,w)]−δt+1=Rt+1+γ𝔼[Rt+2+γ𝔼[Q(St+2,At+2,w)]−δt+2]−δt+1⋯⋯=𝔼Si∼ρβ,Ai∼μ[∑T−1γi−t(Ri+1−δi+1)](bbb.16)(bbb.16)Q(St,At,w)=Rt+1+γE[Q(St+1,At+1,w)]−δt+1=Rt+1+γE[Rt+2+γE[Q(St+2,At+2,w)]−δt+2]−δt+1⋯⋯=ESi∼ρβ,Ai∼μ[∑T−1γi−t(Ri+1−δi+1)]

   由此可见,估计的方差与未来奖励、未来TD误差的方差成正比。当折扣因子γγ较大时,每次更新都可以引起方差的快速提升,所以通常TD3设置较小的折扣系数γγ。

延迟的策略更新

   TD3目标网络的更新方式与DDPG相同,都采用软更新,尽管软更新比硬更新更有利于算法的稳定性,但AC算法依然会失败,其原因通常在于行动者和评论家的更新是相互作用的结果:评论家提供的值函数估计值不准确,就会使行动者将策略往错误方向改进;行动者产生了较差的策略,就会进一步加剧评论家误差累积问题,两者不断作用产生恶性循环。

   为解决以上问题,TD3考虑对策略进行延时更新,减少行动者的更新频率,尽可能等待评论家训练收敛后再进行更新操作。延时更新操作可以有效减少累积误差,从而降低方差;同时,也能减少不必要的重复更新操作,一定程度上提升效率。在实际应用时,TD3采取的操作是每隔评论家更新dd次后,再对行动者进行更新。

目标策略平滑操作

   上节中通过延时更新策略来减小误差累积,接下来考虑误差本身。首先,误差的根源是值函数逼近所产生的偏差,在机器学习中,消除估计偏差的常用方法就是对参数更新进行正则化,同样的,这一思想也可以应用在强化学习中。

   一个很自然的想法是,相似的动作应该拥有相似的价值,动作空间中目标动作周围的一小片区域的价值若能足够平滑,就可以有效减少误差的产生。TD3的具体做法是,为目标动作添加截断噪声:

ã ←μ(s′,θ′)+εε∼clip(N(0,σ),−c,c)(bbb.17)(bbb.17)a~←μ(s′,θ′)+εε∼clip⁡(N(0,σ),−c,c)

   该噪声处理也是一种正则化方式。通过这种平滑操作,可以增加算法的泛化能力,缓解过拟合问题,减少价值被过高估计的一些不良状态对策略学习的干扰。

二、TD3算法流程


  算法bbb.2 TD3算法(Lillicrap al. 2016)


  初始化:
     1. 初始化预测价值网络Qw1Qw1和Qw2Qw2,网络参数分别为w1w1和w2w2
     2. 初始化目标价值网络Qw′1Qw1′和Qw′2Qw2′,网络参数分别为w′1w1′和w′2w2′
     3. 初始化预测策略网络μθμθ和目标策略网络μθ′μθ′,网络参数分别为θθ和θ′θ′
     4. 同步参数w′1←w1w1′←w1,w′2←w2w2′←w2,θ′←θθ′←θ
     5. 经验池D的容量为NN
     6. 总迭代次数MM,折扣因子γγ,τ=0.0001τ=0.0001,随机小批量采样样本数量nn


     7. for ee=1 to MM do:
     8.   初始化状态设置为S0S0
     9.   repeat(情节中的每一时间步t=0,1,2,…t=0,1,2,…):
     10.     根据当前的预测策略网络和探索噪声来选择动作根据当前的预测策略网络和探索噪声来选择动作At=μ(St,θ)+εtAt=μ(St,θ)+εt,
          其中εt∼t(0,σ)εt∼Nt(0,σ)
     11.     执行动作AtAt,获得奖赏Rt+1Rt+1和下一状态St+1St+1
     12.     将经验转换(St,At,Rt+1,St+1)(St,At,Rt+1,St+1)存储在经验池D中
     13.     从经验池D中随机采样小批量的nn个经验转移样本(Si,Ai,Ri+1,Si+1)(Si,Ai,Ri+1,Si+1),计算:
          (1)扰动后的动作ã i+1←μ(Si+1,θ′)+εia~i+1←μ(Si+1,θ′)+εi,其中εi∼clip(t(0,σ̃ ),−c,c)εi∼clip⁡(Nt(0,σ~),−c,c)
          (2)更新目标yi=Ri+1+γmini=1,2Q(Si+1,ã i+1,w′i)yi=Ri+1+γmini=1,2Q(Si+1,a~i+1,wi′)
     14.     使用MBGD,根据最小化损失函数来更新价值网络(评论家网络)参数ww:

∇wL(w)≈1N∑iN(yi−Q(Si,Ai,w))∇wQ(Si,Ai,w)∇wL(w)≈1N∑iN(yi−Q(Si,Ai,w))∇wQ(Si,Ai,w)

     15.     if tt mod dd then
     16.       使用MBGA法,根据最大化目标函数来更新策略网络(行动者网络)参数θθ:

∇θĴ β(θ)≈1N∑i∇θμ(Si,θ)∇aQ(Si,a,w)|||||a=μ(Si,θ)∇θJ^β(θ)≈1N∑i∇θμ(Si,θ)∇aQ(Si,a,w)|a=μ(Si,θ)

     17.       软更新目标网络:{w′←τw+(1−τ)w′θ′←τθ+(1−τ)θ′{w′←τw+(1−τ)w′θ′←τθ+(1−τ)θ′
     18.   until t=T−1

三、实验环境

实验环境:OpenAI Gym工具包中的MuIoCo环境,用了其中四个连续控制任务,包括Ant,HalfCheetah,Walker2d,Hopper

每次训练 均运行1000000步,并每取5000步作为一个训练阶段,每个训练阶段结束,对所学策略进行测试评估 与环境交互十个情节并取平均返回值

结果如下图

可以发现在Ant和Walker2d任务中TD3由于采用了Clipped Double Q-Learning机制 较好的缓解了高估问题 减少了由于高估问题导致的不良状态对于策略更新乃至后续训练的不良影响,动作值逼近相对更为准确,因而相对DDPG而言,不容易陷入局部最优,Agent与环境交互所获得的回报,相比较会大幅提升,总而言之,与DDPG相比,TD3算法训练各阶段波动性更小,算法整体更加稳定

四、代码

部分源码如下

import numpy as np
import torch
import gym
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ReplayBuffer(object):
    def __init_
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, state_dim))
        self.action = np.zeros((max_size, action_dim))
        self.next_state = np.zeros((max_size, state_dim))
        self.reward = np.zeros((max_size, 1))
        self.not_done = np.zeros((max_size, 1))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def add(self, state, action, next_state, reward, done):
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
    def sample(self, batch_size):
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device)
        )
class Actor(nn.Module):
        def __init__(self, state_dim, action_dim, max_action):
            super(Actor, self).__init__()
            self.l1 = nn.Linear(state_dim, 256)
            self.l2 = nn.Linear(256, 256)
            self.l3 = nn.Linear(256, action_dim)
            self.max_action = max_action
        def forward(self, state):
            a = F.relu(self.l1(state))
            a = F.relu(self.l2(a))
            return self.max_action * torch.tanh(self.l3(a))
class Critic(nn.Module):
        def __init__(self, state_dim, action_dim):
            super(Critic, self).__init__()
            # Q1 architecture
            self.l1 = nn.Linear(state_dim + action_dim, 256)
            self.l2 = nn.Linear(256, 256)
            self.l3 = nn.Linear(256, 1)
            # Q2 architecture
            self.l4 = nn.Linear(state_dim + action_dim, 256)
            self.l5 = nn.Linear(256, 256)
            self.l6 = nn.Linear(256, 1)
        def forward(self, state, action):
            sa = torch.cat([state, action], 1)
            q1 = F.relu(self.l1(sa))
            q1 = F.relu(self.l2(q1))
            q1 = self.l3(q1)
            q2 = F.relu(self.l4(sa))
            q2 = F.relu(self.l5(q2))
            q2 = self.l6(q2)
            return q1, q2
        def Q1(self, state, action):
            sa = torch.cat([state, action], 1)
            q1 = F.relu(self.l1(sa))
            q1 = F.relu(self.l2(q1))
            q1 = self.l3(q1)
            return q1
actor1=Actor(17,6,1.0)
for ch in actor1.children():
    print(ch)
print("*********************")
critic1=Critic(17,6)
for ch in critic1.children():
    print(ch)
class TD3(object):
    def __init__(
        self,
        state_dim,
        action_dim,
        max_action,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_freq=2
    ):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.total_it = 0
    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()
    def train(self, replay_buffer, batch_size=100):
        self.total_it += 1
        # Sample replay buffer
        state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(action) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)
            next_action = (
                self.actor_target(next_state) + noise
            ).clamp(-self.max_action, self.max_action)
            # Compute the target Q value
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + not_done * self.discount * target_Q
        # Get current Q estimates
        current_Q1, current_Q2 = self.critic(state, action)
        # Compute critic loss
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # Delayed policy updates
        if self.total_it % self.policy_freq == 0:
            # Compute actor losse
            actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
    def save(self, filename):
        torch.save(self.critic.state_dict(), filename + "_critic")
        torch.save(self.critic_optimizer.state_dict(), filename + "_critic_optimizer")
        torch.save(self.actor.state_dict(), filename + "_actor")
        torch.save(self.actor_optimizer.state_dict(), filename + "_actor_optimizer")
    def load(self, filename):
        self.critic.load_state_dict(torch.load(filename + "_critic"))
        self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer"))
        self.critic_target = copy.deepcopy(self.critic)
        self.actor.load_state_dict(torch.load(filename + "_actor"))
        self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
        self.actor_target = copy.deepcopy(self.actor)
# Runs policy for X episodes and returns average reward
# A fixed seed is used for the eval environment
def eval_policy(policy, env_name, seed, eval_episodes=10):
    eval_env = gym.make(env_name)
    eval_env.seed(seed + 100)
    avg_reward = 0.
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        while not done:
            action = policy.select_action(np.array(state))
            state, reward, done, _ = eval_env.step(action)
            avg_reward += reward
    avg_reward /= eval_episodes
    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward
policy = "TD3"
env_name = "Walker2d-v4"  # OpenAI gym environment name
seed = 0  # Sets Gym, PyTorch and Numpy seeds
start_timesteps = 25e3  # Time steps initial random policy is used
eval_freq = 5e3  # How often (time steps) we evaluate
max_timesteps = 1e6  # Max time steps to run environment
expl_noise = 0.1  # Std of Gaussian exploration noise
batch_size = 256  # Batch size for both actor and critic
discount = 0.99  # Discount factor
tau = 0.005  # Target network update rate
policy_noise = 0.2  # Noise added to target policy during critic update
noise_clip = 0.5  # Range to clip target policy noise
policy_freq = 2  # Frequency of delayed policy updates
save_model = "store_true"  # Save model and optimizer parameters
load_model = ""  # Model load file name, "" doesn't load, "default" uses file_name
file_name = f"{policy}_{env_name}_{seed}"
print("---------------------------------------")
print(f"Policy: {policy}, Env: {env_name}, Seed: {seed}")
print("---------------------------------------")
if not os.path.exists("./results"):
    os.makedirs("./results")
if save_model and not os.path.exists("./models"):
    os.makedirs("./models")
env = gym.make(env_name)
# Set seeds
env.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
kwargs = {
    "state_dim": state_dim,
    "action_dim": action_dim,
    "max_action": max_action,
    "discount": discount,
    "tau": tau,
    "policy_noise": policy_noise * max_action,
    "noise_clip": noise_clip * max_action,
    "policy_freq": policy_freq
}
policy = TD3(**kwargs)
if load_model != "":
    policy_file = file_name if load_model == "default" else load_model
    policy.load(f"./models/{policy_file}")
replay_buffer = ReplayBuffer(state_dim, action_dim)
# Evaluate untrained policy
evaluations = [eval_policy(policy, env_name, seed)]
state, done = env.reset(), False
episode_reward = 0
episode_timesteps = 0
episode_num = 0
for t in range(int(max_timesteps)):
    episode_timesteps += 1
    # Select action randomly or according to policy
    if t < start_timesteps:
        action = env.action_space.sample()
    else:
        action = (
                policy.select_action(np.array(state))
                + np.random.normal(0, max_action * expl_noise, size=action_dim)
        ).clip(-max_action, max_action)
l = float(done) if episode_timesteps < env._max_episode_steps else 0
    # Store data in replay buffer
    replay_buffer.add(state, action, next_state, reward, done_bool)
    state = next_state
    episode_reward += reward
    # Train agent after collecting sufficient data
    if t >= start_timesteps:
        policy.train(replay_buffer, batch_size)
    if done:
end(eval_policy(policy, env_name, seed))
        np.save(f"./results/{file_name}", evaluations)
    if save_model:
        policy.save(f"./models/{file_name}")
state_dim

创作不易 觉得有帮助请点赞关注收藏~~~

相关文章
|
21天前
|
大数据 UED 开发者
实战演练:利用Python的Trie树优化搜索算法,性能飙升不是梦!
在数据密集型应用中,高效搜索算法至关重要。Trie树(前缀树/字典树)通过优化字符串处理和搜索效率成为理想选择。本文通过Python实战演示Trie树构建与应用,显著提升搜索性能。Trie树利用公共前缀减少查询时间,支持快速插入、删除和搜索。以下为简单示例代码,展示如何构建及使用Trie树进行搜索与前缀匹配,适用于自动补全、拼写检查等场景,助力提升应用性能与用户体验。
38 2
|
25天前
|
算法 搜索推荐 开发者
别再让复杂度拖你后腿!Python 算法设计与分析实战,教你如何精准评估与优化!
在 Python 编程中,算法的性能至关重要。本文将带您深入了解算法复杂度的概念,包括时间复杂度和空间复杂度。通过具体的例子,如冒泡排序算法 (`O(n^2)` 时间复杂度,`O(1)` 空间复杂度),我们将展示如何评估算法的性能。同时,我们还会介绍如何优化算法,例如使用 Python 的内置函数 `max` 来提高查找最大值的效率,或利用哈希表将查找时间从 `O(n)` 降至 `O(1)`。此外,还将介绍使用 `timeit` 模块等工具来评估算法性能的方法。通过不断实践,您将能更高效地优化 Python 程序。
35 4
|
9天前
|
存储 算法 安全
ArrayList简介及使用全方位手把手教学(带源码),用ArrayList实现洗牌算法,3个人轮流拿牌(带全部源码)
文章全面介绍了Java中ArrayList的使用方法,包括其构造方法、常见操作、遍历方式、扩容机制,并展示了如何使用ArrayList实现洗牌算法的实例。
11 0
|
2月前
|
算法 安全 数据安全/隐私保护
Android经典实战之常见的移动端加密算法和用kotlin进行AES-256加密和解密
本文介绍了移动端开发中常用的数据加密算法,包括对称加密(如 AES 和 DES)、非对称加密(如 RSA)、散列算法(如 SHA-256 和 MD5)及消息认证码(如 HMAC)。重点讲解了如何使用 Kotlin 实现 AES-256 的加密和解密,并提供了详细的代码示例。通过生成密钥、加密和解密数据等步骤,展示了如何在 Kotlin 项目中实现数据的安全加密。
71 1
|
2月前
|
机器学习/深度学习 存储 算法
强化学习实战:基于 PyTorch 的环境搭建与算法实现
【8月更文第29天】强化学习是机器学习的一个重要分支,它让智能体通过与环境交互来学习策略,以最大化长期奖励。本文将介绍如何使用PyTorch实现两种经典的强化学习算法——Deep Q-Network (DQN) 和 Actor-Critic Algorithm with Asynchronous Advantage (A3C)。我们将从环境搭建开始,逐步实现算法的核心部分,并给出完整的代码示例。
118 1
|
2月前
|
算法 安全 数据安全/隐私保护
Android经典实战之常见的移动端加密算法和用kotlin进行AES-256加密和解密
本文介绍了移动端开发中常用的数据加密算法,包括对称加密(如 AES 和 DES)、非对称加密(如 RSA)、散列算法(如 SHA-256 和 MD5)及消息认证码(如 HMAC)。重点展示了如何使用 Kotlin 实现 AES-256 的加密和解密,提供了详细的代码示例。
58 2
|
2天前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
17 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
5天前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
18 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
19天前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
27 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
1天前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
14 2