"""Train a small RNN to predict the next sample of a sine wave.

Running this file as a script seeds PyTorch, builds sliding-window samples
from ``sin(x)``, trains ``SimpleRNN`` with full-batch Adam, then plots the
training loss and the model's predictions against the true targets.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


# ==========================
# 1. Sine-wave data
# ==========================
def generate_sine_data(seq_length=50, total_points=1000):
    """Build (window, next-value) samples from a sampled sine wave.

    ``sin(x)`` is sampled at ``total_points`` evenly spaced points on
    ``[0, 100]``. For every start index ``i`` the input window is
    ``y[i:i + seq_length]`` and the target is the following sample
    ``y[i + seq_length]``.

    Args:
        seq_length: Number of consecutive samples in each input window.
        total_points: Number of points sampled from the sine wave.

    Returns:
        ``(X, Y)`` as float64 arrays: ``X`` has shape ``(n, seq_length)`` and
        ``Y`` has shape ``(n,)``, where ``n = max(total_points - seq_length, 0)``.
        When no full window fits, both are empty but keep those shapes.

    Raises:
        ValueError: If ``seq_length`` is negative.
    """
    if seq_length < 0:
        raise ValueError(f"seq_length must be non-negative, got {seq_length}")
    x = np.linspace(0, 100, total_points)
    y = np.sin(x)
    n_samples = max(len(y) - seq_length, 0)
    # Row i holds the indices i, i+1, ..., i+seq_length-1 of one input window;
    # fancy indexing then gathers every window in one vectorized step and
    # keeps the (n_samples, seq_length) shape even when n_samples == 0.
    window_idx = np.arange(n_samples)[:, None] + np.arange(seq_length)
    # Each target is the sample immediately after its window.
    return y[window_idx], y[seq_length:]


def to_tensors(X, Y):
    """Convert window/target arrays into float32 tensors shaped for ``SimpleRNN``.

    Args:
        X: Array of input windows, shape ``(n, seq_length)``.
        Y: Array of targets, shape ``(n,)``.

    Returns:
        ``(X_t, Y_t)`` with shapes ``(n, seq_length, 1)`` and ``(n, 1)``; the
        trailing axis is the single input/output feature.
    """
    X_t = torch.as_tensor(X, dtype=torch.float32).unsqueeze(-1)
    Y_t = torch.as_tensor(Y, dtype=torch.float32).unsqueeze(-1)
    return X_t, Y_t


# ==========================
# 2. RNN model
# ==========================
class SimpleRNN(nn.Module):
    """``nn.RNN`` encoder followed by a linear read-out of the final time step.

    Args:
        input_size: Features per time step in the input sequence.
        hidden_size: Size of the RNN hidden state.
        output_size: Size of the prediction produced per sequence.
    """

    def __init__(self, input_size=1, hidden_size=32, output_size=1):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq_length, input_size)`` to ``(batch, output_size)``."""
        # batch_first=True, so out is (batch, seq_length, hidden_size).
        out, _ = self.rnn(x)
        # Predict from the hidden output of the last time step only.
        return self.fc(out[:, -1, :])


# ==========================
# 3. Training
# ==========================
def train_model(model, X, y, epochs=200, lr=0.01, log_every=20):
    """Train ``model`` full-batch with Adam on an MSE loss.

    Args:
        model: The module to train (updated in place).
        X: Input tensor passed to ``model`` as a single batch.
        y: Target tensor with the same shape as ``model(X)``.
        epochs: Number of full-batch optimisation steps.
        lr: Adam learning rate.
        log_every: Print the loss every this many epochs; ``0`` disables logging.

    Returns:
        List of per-epoch loss values (Python floats).
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = []
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        losses.append(loss_value)
        if log_every and (epoch + 1) % log_every == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss_value:.4f}')
    return losses


def predict(model, X):
    """Run ``model`` on ``X`` in eval mode without tracking gradients."""
    model.eval()
    with torch.no_grad():
        return model(X)


# ==========================
# 4. Visualisation
# ==========================
def plot_losses(losses):
    """Plot the per-epoch training loss curve."""
    # Imported lazily so the data/model/training code stays importable
    # (and testable) in environments without matplotlib.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 4))
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()


def plot_predictions(y_true, y_pred, n_points=200):
    """Plot the first ``n_points`` targets against the model's predictions.

    Args:
        y_true: Target tensor (as produced by ``to_tensors``).
        y_pred: Prediction tensor (as produced by ``predict``).
        n_points: How many leading samples to draw.
    """
    import matplotlib.pyplot as plt  # lazy import; see plot_losses

    plt.figure(figsize=(12, 6))
    plt.plot(y_true[:n_points].detach().cpu().numpy(), label='True')
    plt.plot(y_pred[:n_points].detach().cpu().numpy(), label='Predicted')
    plt.legend()
    plt.title('RNN Sine Wave Prediction')
    plt.show()


def main(seq_length=50, epochs=200, seed=42):
    """Generate data, train ``SimpleRNN``, and plot loss and predictions."""
    # Seed before constructing the model so weight initialisation is reproducible.
    torch.manual_seed(seed)
    X_np, y_np = generate_sine_data(seq_length=seq_length)
    X, y = to_tensors(X_np, y_np)

    model = SimpleRNN()
    losses = train_model(model, X, y, epochs=epochs)
    plot_losses(losses)

    # NOTE(review): predictions are made on the same windows used for training,
    # so this plot shows in-sample fit rather than generalisation.
    predictions = predict(model, X)
    plot_predictions(y, predictions)


if __name__ == "__main__":
    main()