import matplotlib.pyplot as plt
import numpy as np

# 设置图片清晰度
plt.rcParams['figure.dpi'] = 300

# 绘制简单神经网络
# 定义每层神经元的位置
input_layer = np.array([(0, i) for i in range(3)])
hidden_layer = np.array([(1, i) for i in range(4)])
output_layer = np.array([(2, i) for i in range(1)])

# 绘制神经元
plt.scatter(input_layer[:, 0], input_layer[:, 1], s=200, label='输入层')
plt.scatter(hidden_layer[:, 0], hidden_layer[:, 1], s=200, label='隐藏层')
plt.scatter(output_layer[:, 0], output_layer[:, 1], s=200, label='输出层')

# 绘制连接
for i in input_layer:
    for h in hidden_layer:
        plt.plot([i[0], h[0]], [i[1], h[1]], 'k-')
for h in hidden_layer:
    for o in output_layer:
        plt.plot([h[0], o[0]], [h[1], o[1]], 'k-')

# 添加权重标注（简单示意）
for i in input_layer:
    for h in hidden_layer:
        mid_x = (i[0] + h[0]) / 2
        mid_y = (i[1] + h[1]) / 2
        plt.text(mid_x, mid_y, 'w', ha='center', va='center')

# 添加标签和标题
plt.legend()
plt.title('简单神经网络示例')
plt.axis('off')
plt.show()