news 2026/4/24 11:31:17

别再死记硬背公式了!用Python手写一个感知机,从鸢尾花分类理解AI的‘第一课’

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记硬背公式了!用Python手写一个感知机,从鸢尾花分类理解AI的‘第一课’

用Python手写感知机:从鸢尾花分类看AI如何"学会"决策

当你第一次听说"机器学习"时,脑海中浮现的是不是一堆复杂的数学公式?那些Σ、∇符号和矩阵运算确实容易让人望而生畏。但今天我们要打破这个魔咒——用不到100行Python代码,亲手实现一个能自动学习分类规则的感知机模型。你会发现,AI的"第一课"其实就像教小朋友区分苹果和橘子一样直观。

1. 感知机:半个世纪前的AI基石

1960年代,Frank Rosenblatt发明的感知机(Perceptron)开启了模式识别的新纪元。这个看似简单的模型蕴含着机器学习最核心的思想:通过错误来学习。想象一下教孩子认猫:

  • 孩子指着狗说"猫"(错误分类)
  • 你纠正说"这是狗"(调整参数)
  • 孩子下次更可能正确识别(模型收敛)

感知机的工作方式与此惊人相似。它由三个关键部分组成:

class Perceptron: def __init__(self): self.weights = None # 决策边界的"倾斜程度" self.bias = 0 # 决策边界的"左右偏移" self.lr = 0.1 # 学习率(犯错后的调整幅度)

**权重(weights)**就像我们对不同特征的重视程度。在鸢尾花分类中,花瓣长度可能比萼片宽度更重要;**偏置(bias)**则相当于判断时的宽松程度——好比老师批改试卷时,60分及格和70分及格的区别。

2. 数据准备:鸢尾花的简化世界

我们使用经典的鸢尾花数据集,但做两个简化处理:

  1. 只保留setosa和versicolor两个品种(二分类问题)
  2. 仅使用萼片长度和宽度两个特征(方便可视化)
from sklearn.datasets import load_iris import numpy as np iris = load_iris() X = iris.data[:100, :2] # 前100个样本,取前两个特征 y = np.where(iris.target[:100] == 0, 1, -1) # 转换为1/-1标签

来看看数据的分布情况:

特征组合Setosa (标签1)Versicolor (标签-1)
萼片长度4.3-5.8 cm4.9-7.0 cm
萼片宽度2.3-4.4 cm2.0-3.4 cm

提示:实际项目中应该对特征进行标准化处理,但为了教学直观性,我们保留原始尺度

3. 核心算法:错误驱动的学习过程

感知机的训练过程就像蒙眼走迷宫:每次碰到墙(分类错误),就调整前进方向。具体实现如下:

def fit(self, X, y, epochs=100): n_samples, n_features = X.shape self.weights = np.zeros(n_features) for _ in range(epochs): for idx, x_i in enumerate(X): condition = y[idx] * (np.dot(x_i, self.weights) + self.bias) if condition <= 0: # 分类错误 update = self.lr * y[idx] self.weights += update * x_i self.bias += update

这段代码中藏着两个精妙之处:

  1. 错误判断条件y * (w·x + b) ≤ 0

    • 正确分类时:w·x + by同号,乘积为正
    • 错误分类时:两者异号,乘积为负
  2. 参数更新规则

    • w = w + η * y * x(η是学习率)
    • b = b + η * y

用几何解释:错误样本点在决策边界的错误一侧,更新规则将其"拉向"正确方向。例如:

  • 正样本被误判为负:w += η * x使得w·x增大
  • 负样本被误判为正:w -= η * x使得w·x减小

4. 可视化:看决策边界如何进化

让我们用matplotlib观察训练过程中决策边界的变化:

def plot_decision_boundary(model, X, y, epoch): x1_min, x1_max = X[:,0].min()-0.5, X[:,0].max()+0.5 x2_min, x2_max = X[:,1].min()-0.5, X[:,1].max()+0.5 xx1, xx2 = np.meshgrid(np.linspace(x1_min,x1_max,100), np.linspace(x2_min,x2_max,100)) Z = model.predict(np.c_[xx1.ravel(), xx2.ravel()]) Z = Z.reshape(xx1.shape) plt.contourf(xx1, xx2, Z, alpha=0.3) plt.scatter(X[:,0], X[:,1], c=y, edgecolors='k') plt.title(f'Epoch {epoch}') plt.xlabel('Sepal length') plt.ylabel('Sepal width')

训练过程中的关键阶段:

  1. 初始状态(随机权重):

    • 决策边界混乱,准确率约50%
  2. 中期调整(部分样本正确分类):

    • 边界开始分离两类样本
    • 仍有一些顽固的错误点
  3. 最终收敛

    • 所有训练样本正确分类
    • 边界处于两类之间的"中庸"位置

注意:如果数据不是线性可分的,感知机会在两者间反复震荡无法收敛

5. 超越基础:现代视角下的感知机

虽然原始感知机很简单,但它启发了现代深度学习的许多概念:

  1. 激活函数

    • 感知机的sign函数是阶跃函数
    • 现代神经网络使用sigmoid、ReLU等平滑函数
  2. 损失函数

    • 感知机最小化误分类点到超平面的距离
    • 现代方法常用交叉熵、MSE等
  3. 优化算法

    • 感知机使用原始梯度下降
    • 现代优化器如Adam、RMSprop更高效

用PyTorch实现感知机会发现惊人相似:

import torch class TorchPerceptron(torch.nn.Module): def __init__(self, input_dim): super().__init__() self.linear = torch.nn.Linear(input_dim, 1) def forward(self, x): return torch.sign(self.linear(x)).squeeze()

关键区别在于:

  • 自动计算梯度(autograd)
  • 可以使用GPU加速
  • 轻松扩展为多层网络

6. 实战建议:从玩具到真实项目

当你在真实数据上应用感知机时,记住这些经验:

  1. 特征工程比算法更重要

    • 对非线性数据尝试多项式特征
    from sklearn.preprocessing import PolynomialFeatures poly = PolynomialFeatures(degree=2) X_poly = poly.fit_transform(X)
  2. 超参数调优

    • 学习率太大导致震荡,太小收敛慢
    • 用网格搜索找最佳组合
    from sklearn.model_selection import GridSearchCV param_grid = {'lr': [0.001, 0.01, 0.1, 1]}
  3. 评估指标选择

    • 准确率对平衡数据集有效
    • 不平衡数据用F1-score或AUC-ROC
  4. 扩展到多分类

    • 一对多(One-vs-Rest)策略
    • 多类感知机变种

在Kaggle的Titanic数据集上,即使简单如感知机,经过恰当的特征工程也能达到75%+的准确率——这已经比随机猜测的50%好很多了。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/24 11:27:31

算法训练营第十一天|删除有序数组中的重复项

学习链接&#xff1a;https://zhuanlan.zhihu.com/p/29544395643 学习思路&#xff1a; 使用两个指针&#xff1a;慢指针 slow 和快指针 fast慢指针 slow 指向当前可以放置元素的位置快指针 fast 用于遍历数组对于每个元素&#xff0c;我们需要判断是否应该保留它&#xff1a…

作者头像 李华
网站建设 2026/4/24 11:26:28

2025届毕业生推荐的十大AI写作网站实际效果

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 在文本生成进程里&#xff0c;为有效弱化机器化表达&#xff0c;提升内容自然度&#xff0c…

作者头像 李华
网站建设 2026/4/24 11:25:38

高效数据迁移:基于脚本的MySQL→Hive OBS层搭建方案

1、 业务数据库的导入先在 mysql 创建数据库 jrxd之后将准备好的大数据sql导入mysql传统通过navicat导入速度过慢时&#xff0c;可以使用mysql命令行进行导入&#xff0c;即先将sql上传至虚拟机之后再使用mysql客户端中的source命令source /opt/modules/jrxd-new.sql;导入成功后…

作者头像 李华
网站建设 2026/4/24 11:25:37

如何快速优化游戏输入:3种SOCD清洁模式提升操作精度指南

如何快速优化游戏输入&#xff1a;3种SOCD清洁模式提升操作精度指南 【免费下载链接】socd Key remapper for epic gamers 项目地址: https://gitcode.com/gh_mirrors/so/socd 你是否在玩格斗游戏时&#xff0c;同时按下左右方向键导致角色卡顿&#xff1f;或者在平台跳…

作者头像 李华
网站建设 2026/4/24 11:23:28

如何快速解包Godot游戏资源:终极PCK文件提取工具指南

如何快速解包Godot游戏资源&#xff1a;终极PCK文件提取工具指南 【免费下载链接】godot-unpacker godot .pck unpacker 项目地址: https://gitcode.com/gh_mirrors/go/godot-unpacker 如果你正在寻找一个高效、免费的Godot游戏资源解包工具&#xff0c;那么godot-unpac…

作者头像 李华