我在 PyTorch 中运行 LSTM,但据我所知,它目前只使用了序列长度 = 1。当我把输入重塑(reshape)为序列长度 4 或其他数值时,会得到输入与目标长度不匹配的错误;如果同时对输入和目标进行重塑,模型又会报错说不支持多目标(multi-target)标签。 我的训练数据集有 66512 行、16839 列,目标共有 3 个类别。我想使用批大小 200、序列长度 4,即每个序列包含 4 行数据。 请建议如何调整我的模型和/或数据,以便模型能够使用不同的序列长度(例如 4)运行。
# Number of samples per mini-batch; passed to the DataLoader below.
batch_size=200
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
# Number of consecutive rows grouped into one LSTM input sequence.
# 1 reproduces the one-row-per-sequence layout; set it to e.g. 4 to feed
# 4-row sequences. NOTE(review): assumes consecutive rows of train_data are
# time-ordered and belong together — confirm before using seq_len > 1.
seq_len = 1

label_columns = ['Label1', 'Label2', 'Label3']

# One-hot label columns -> class indices (int64), the 1-D target format that
# nn.CrossEntropyLoss expects.
train_target = torch.tensor(train_data[label_columns].to_numpy(dtype='float32'))
train_target = torch.argmax(train_target, dim=1)

# Every remaining column is an input feature.
train = torch.tensor(train_data.drop(label_columns, axis=1).to_numpy(dtype='float32'))

# Fold rows into non-overlapping sequences of shape
# (num_sequences, seq_len, num_features). Trailing rows that do not fill a
# whole sequence are dropped.
num_sequences = train.size(0) // seq_len
usable_rows = num_sequences * seq_len
train_sequences = train[:usable_rows].reshape(num_sequences, seq_len, train.size(1))

# Keep exactly one label per sequence (the label of its last row). The model
# classifies from the last time step only, so a (num_sequences, seq_len)
# target would be rejected by CrossEntropyLoss as a multi-target.
train_sequence_target = train_target[:usable_rows].reshape(num_sequences, seq_len)[:, -1]

# shuffle=True shuffles whole sequences; row order inside a sequence is kept.
train_tensor = TensorDataset(train_sequences, train_sequence_target)
train_loader = DataLoader(dataset=train_tensor, batch_size=batch_size, shuffle=True)
print(train.shape)
print(train_target.shape)
torch.Size([66512, 16839])
torch.Size([66512])
import torch.nn as nn
class LSTMModel(nn.Module):
    """LSTM sequence classifier that predicts from the last time step.

    The input is batch-first: ``(batch_size, seq_len, input_dim)``. Any
    ``seq_len >= 1`` is accepted; only the LSTM output at the final time
    step is passed to the linear readout layer.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Number of output logits (classes).
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        # Hidden dimensions
        self.hidden_dim = hidden_dim
        # Number of hidden layers
        self.layer_dim = layer_dim
        # Building LSTM (batch_first=True -> input is (batch, seq, feature))
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # Readout layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return class logits of shape ``(batch_size, output_dim)``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_dim)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected input of shape (batch_size, seq_len, input_dim), "
                "got {}".format(tuple(x.shape))
            )
        # No explicit (h0, c0): nn.LSTM defaults both to zeros on the same
        # device/dtype as its input, so the model does not depend on a global
        # `device` variable and the initial states need no gradient.
        out, _ = self.lstm(x)
        # Index hidden state of last time step
        return self.fc(out[:, -1, :])
# Model hyperparameters
input_dim = 16839   # features per time step
hidden_dim = 100    # LSTM hidden state size
output_dim = 3      # number of output logits (classes)
layer_dim = 1       # number of stacked LSTM layers
num_epochs = 1

model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# CrossEntropyLoss takes raw logits and class-index targets.
criterion = nn.CrossEntropyLoss()
# NOTE(review): 0.1 is 100x Adam's default learning rate (1e-3) — confirm
# this is intentional.
learning_rate = 0.1
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Print how many parameter tensors the model has, then each tensor's shape.
# The list is materialized once instead of once per loop iteration.
parameters = list(model.parameters())
print(len(parameters))
for parameter in parameters:
    print(parameter.size())
6
torch.Size([400, 16839])
torch.Size([400, 100])
torch.Size([400])
torch.Size([400])
torch.Size([3, 100])
torch.Size([3])
# Train for num_epochs passes over train_loader and report, once per epoch,
# the mean training loss and the training accuracy.
for epoch in range(num_epochs):
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    # Loop variables are named inputs/targets so the full dataset tensors
    # (train / train_target) are not overwritten by the last mini-batch.
    for inputs, targets in train_loader:
        # Inputs need no gradient of their own; only parameters are trained.
        inputs = inputs.to(device)
        targets = targets.to(device)
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
        # Forward pass to get output/logits
        outputs = model(inputs)
        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, targets)
        # Getting gradients w.r.t. parameters
        loss.backward()
        # Updating parameters
        optimizer.step()

        # Accumulate statistics; weight the batch loss by its size so that a
        # smaller final batch does not skew the epoch mean.
        batch_count = targets.size(0)
        loss_sum += loss.item() * batch_count
        num_correct += (outputs.argmax(dim=1) == targets).sum().item()
        num_seen += batch_count

    # max(..., 1) avoids a ZeroDivisionError if the loader yielded nothing.
    mean_loss = loss_sum / max(num_seen, 1)
    accuracy = num_correct / max(num_seen, 1)
    print('Epoch: {}. Loss: {}. Accuracy: {}'.format(epoch, round(mean_loss, 4), round(accuracy, 4)))
问题来源:StackOverflow,地址:https://stackoverflow.com/questions/59381695/lstm-in-pytorch-how-to-add-change-sequence-length-dimension
class torch.nn.LSTM(*args, **kwargs) -- 参数列表: -- input_size: x 的特征维度 -- hidden_size: 隐层的特征维度 -- num_layers: LSTM 层数,默认为1 -- bias: 是否采用 bias, 如果为False,则不采用。默认为True -- batch_first: 如果为True, 则输入输出的数据格式为 [batch_size, seq_len, feature_dim],默认为False -- dropout: 在除最后一层外的每一层输出上应用 dropout, 默认为0 -- bidirectional: 是否采用双向,默认为False -- 输入数据(以下形状对应 batch_first=False;batch_first=True 时 input 和 output 的前两维互换): -- input: [seq_len, batch_size, input_size], 输入的特征矩阵 -- h_0: [num_layers * num_directions, batch_size, hidden_size], 初始时 h 状态, 默认为0 -- c_0: [num_layers * num_directions, batch_size, hidden_size], 初始时 cell 状态, 默认为0 -- 输出数据: -- output: [seq_len, batch_size, num_directions * hidden_size], 最后一层的所有隐层输出 -- h_n : [num_layers * num_directions, batch_size, hidden_size], 所有层的最后一个时刻隐层状态 -- c_n : [num_layers * num_directions, batch_size, hidden_size], 所有层的最后一个时刻的 cell 状态 -- W,b参数: -- weight_ih_l[k]: 与输入x相关的第k层权重 W 参数, W_ii, W_if, W_ig, W_io -- weight_hh_l[k]: 与上一时刻 h 相关的第k层权重参数, W_hi, W_hf, W_hg, W_ho -- bias_ih_l[k]: 与输入x相关的第k层 b 参数, b_ii, b_if, b_ig, b_io -- bias_hh_l[k]: 与上一时刻 h 相关的第k层 b 参数, b_hi, b_hf, b_hg, b_ho -- 对于本问题: 模型使用 batch_first=True,因此要使用序列长度 4,需要把输入整理成 [batch_size, 4, 16839],并且每个序列只保留一个类别标签(目标形状为 [batch_size]),这样 CrossEntropyLoss 才不会报输入与目标长度不匹配或多目标的错误。
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。