基于Python的人工智能应用案例系列(8):鸢尾花分类
这个模型将包含两个隐藏层,使用ReLU作为激活函数。# 定义模型类return x我们将使用交叉熵损失函数(Cross Entropy Loss)和Adam优化器进行模型的训练。# 定义损失函数与优化器在本案例中,我们使用了经典的鸢尾花数据集,构建了一个简单但有效的神经网络模型来进行多分类任务。通过神经网络的构建、训练、测试和推理,我们实现了对不同类别鸢尾花的准确分类。这一过程展示了如何处理多类别
鸢尾花数据集(Iris Dataset)是机器学习入门最经典的分类数据集之一。它包含150个样本,4个特征,用于分类3种不同的鸢尾花类(Iris setosa, Iris virginica, Iris versicolor)。本篇文章将通过一个简单的深度学习模型来实现对鸢尾花的分类。
1. 数据加载与探索性数据分析(EDA)
加载数据
首先,加载鸢尾花数据集,并查看数据的基本信息。
import pandas as pd
# 加载数据集
df = pd.read_csv('../data/iris.csv')
df.head()
数据集包含以下4个特征:
- sepal_length:花萼长度(cm)
- sepal_width:花萼宽度(cm)
- petal_length:花瓣长度(cm)
- petal_width:花瓣宽度(cm) 目标列是
target
,表示鸢尾花的分类。
数据可视化
我们可以通过绘制特征之间的散点图来直观地观察不同鸢尾花类别在不同特征上的分布。
import matplotlib.pyplot as plt
# 绘制鸢尾花数据特征的散点图
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,7))
fig.tight_layout()
# 不同鸢尾花类别的颜色与标签
plots = [(0,1),(2,3),(0,2),(1,3)]
colors = ['b', 'r', 'g']
labels = ['Iris setosa', 'Iris virginica', 'Iris versicolor']
for i, ax in enumerate(axes.flat):
for j in range(3):
x = df.columns[plots[i][0]]
y = df.columns[plots[i][1]]
ax.scatter(df[df['target']==j][x], df[df['target']==j][y], color=colors[j])
ax.set(xlabel=x, ylabel=y)
fig.legend(labels=labels, loc=3, bbox_to_anchor=(1.0,0.85))
plt.show()
从图中可以看出,某些特征组合可以明显区分不同的鸢尾花类别。
2. 数据预处理与训练集/测试集划分
接下来,将数据集分为训练集和测试集。
from sklearn.model_selection import train_test_split
import torch
# 将数据和标签分离
X = df.drop('target', axis=1).values
y = df['target'].values
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=33)
# 转换为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float)
X_test = torch.tensor(X_test, dtype=torch.float)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)
使用DataLoader进行批处理
虽然数据集较小,但使用DataLoader
工具有助于更好地处理大数据集。
from torch.utils.data import DataLoader
trainloader = DataLoader(X_train, batch_size=60, shuffle=True)
testloader = DataLoader(X_test, batch_size=60, shuffle=False)
3. 模型构建
定义神经网络模型
这个模型将包含两个隐藏层,使用ReLU作为激活函数。
import torch.nn as nn
import torch.nn.functional as F
# 定义模型类
class Model(nn.Module):
def __init__(self, in_features=4, h1=8, h2=9, out_features=3):
super().__init__()
self.fc1 = nn.Linear(in_features, h1)
self.fc2 = nn.Linear(h1, h2)
self.out = nn.Linear(h2, out_features)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.out(x)
return x
定义损失函数和优化器
我们将使用交叉熵损失函数(Cross Entropy Loss)和Adam优化器进行模型的训练。
# 定义损失函数与优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
4. 模型训练
模型训练过程包括前向传播、计算损失、反向传播和优化参数。
# 模型训练
epochs = 150
losses = []
for epoch in range(epochs):
y_pred = model.forward(X_train)
loss = criterion(y_pred, y_train)
losses.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f'epoch: {epoch:2}, loss: {loss.item():10.8f}')
绘制损失函数曲线
我们可以绘制损失函数随迭代次数的变化曲线。
plt.plot(range(epochs), losses)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()
5. 模型测试与评估
现在,我们使用测试集评估模型的性能,计算分类的准确率。
# 模型测试
correct = 0
with torch.no_grad():
y_val = model.forward(X_test)
loss = criterion(y_val, y_test)
print(f'Test Loss: {loss.item()}')
for i, data in enumerate(X_test):
prediction = y_val[i].argmax().item()
true_label = y_test[i].item()
if prediction == true_label:
correct += 1
accuracy = correct / len(y_test) * 100
print(f'Accuracy: {accuracy:.2f}%')
6. 模型保存与加载
为了便于之后的模型推理,我们可以将训练好的模型保存到文件中。
# 保存模型
torch.save(model.state_dict(), 'models/IrisModel.pth')
在需要时,我们可以重新加载模型并进行推理。
# 加载模型
loaded_model = Model()
loaded_model.load_state_dict(torch.load('models/IrisModel.pth'))
loaded_model.eval()
7. 使用模型进行推理
我们可以创建一个新的鸢尾花样本并通过模型进行分类。
# 创建一个新样本
mystery_iris = torch.tensor([5.6, 3.7, 2.2, 0.5])
# 进行推理
with torch.no_grad():
prediction = loaded_model(mystery_iris)
print(f'Predicted class: {prediction.argmax().item()}')
结语
在本案例中,我们使用了经典的鸢尾花数据集,构建了一个简单但有效的神经网络模型来进行多分类任务。通过神经网络的构建、训练、测试和推理,我们实现了对不同类别鸢尾花的准确分类。这一过程展示了如何处理多类别分类问题,如何对模型进行训练和评估,以及如何保存和加载训练好的模型以便后续使用。
尽管鸢尾花数据集相对简单,但这一流程可以很好地拓展至更复杂的分类任务。通过该案例,你不仅能够理解神经网络的基本构建过程,还能掌握一些处理分类问题的核心技巧,如交叉熵损失函数的使用、数据的批处理以及模型的推理和持久化。未来,你可以在更大、更复杂的数据集上应用类似的技术来解决实际问题,从而进一步提高模型的预测性能和应用范围。
如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!
欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。
谢谢大家的支持!
更多推荐
所有评论(0)