import torch
import torch.nn as nn


class LSTMForecaster(nn.Module):
    """Stacked LSTM followed by a linear head applied to the final time step.

    The LSTM is built with ``batch_first=True``, so batched inputs are laid
    out as ``(batch, seq_len, input_size)``. Only the top LSTM layer's output
    at the last time step is fed to the linear head.

    Args:
        input_size: Number of features per time step.
        hidden_size: Number of features in the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Width of the linear head's output (presumably the number
            of forecast targets/steps -- confirm against callers).
        dropout: Dropout probability between stacked LSTM layers, forwarded
            to ``nn.LSTM``. Defaults to 0.0 (no dropout). PyTorch applies it
            to every layer's output except the last, so it only has an effect
            when ``num_layers > 1``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        output_size: int,
        *,
        dropout: float = 0.0,
    ):
        super().__init__()
        # The attribute names ``lstm`` and ``fc`` become state_dict key
        # prefixes; keep them stable so saved checkpoints still load.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the LSTM over ``x`` and project the last time step's output.

        Args:
            x: Batched ``(batch, seq_len, input_size)`` tensor, or an
                unbatched ``(seq_len, input_size)`` tensor.

        Returns:
            ``(batch, output_size)`` for batched input, or ``(output_size,)``
            for unbatched input.
        """
        out, _ = self.lstm(x)
        # ``out`` is (batch, seq_len, hidden) when batched and
        # (seq_len, hidden) when unbatched; the leading ``...`` lets one
        # index pick the last time step in both layouts.
        last_step = out[..., -1, :]
        return self.fc(last_step)