news 2026/3/26 3:08:20

线性拟合模型

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
线性拟合模型

线性拟合模型

一、数据准备部分

importnumpyasnpimportkerasimportmatplotlib.pyplotasplt train_X=np.asarray([30.0,40.0,60.0,80.0,100.0,120.0,140.0])train_Y=np.asarray([320.0,360.0,400.0,455.0,490.0,546.0,580.0])train_X/=100.0train_Y/=100.0
  • train_Xtrain_Y是人工构造的训练数据(x 和 y)。

  • 除以 100 是为了归一化(Normalization),将数据范围从 [30-140] 和 [320-580] 缩放到 [0.3-1.4] 和 [3.2-5.8]),有助于神经网络更快收敛。

  • 这是典型的监督学习回归问题:输入 x → 预测 y。

二、可视化函数

defplot_points(x,y,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')plt.scatter(x,y)plt.show()defplot_line(W,b,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')x=np.linspace(0.0,2.0,num=100)y=W*x+b plt.plot(x,y)plt.show()
  • plot_points:画散点图,展示原始数据。

  • plot_line:根据斜率W和截距b画出拟合直线。

三、模型构建

model=keras.models.Sequential()model.add(keras.layers.Dense(units=1,input_dim=1))
  • 只有一层:Dense全连接层
  • units=1:只有一个神经元(输出一个值)
  • input_dim=1:输入数据是一维的(一个特征)
  • 相当于数学公式:y = Wx + b,其中:
    • W:权重(weight),相当于斜率
    • b:偏置(bias),相当于截距

四、编译模型

model.compile(optimizer='sgd',loss='mean_squared_error')
  • optimizer='sgd':使用随机梯度下降优化器
    • SGD是最基础、最经典的优化算法
    • 相比adam,SGD更简单,适合这种简单线性问题
  • loss='mean_squared_error':使用均方误差作为损失函数
    • 计算公式:MSE = Σ(y_pred - y_true)² / n
    • 这是回归问题最常用的损失函数

五、训练模型

history=model.fit(x=train_X,y=train_Y,batch_size=1,epochs=10)
  • batch_size=1批大小为1(在线学习/随机梯度下降)
    • 每看一个样本就更新一次权重
    • 梯度更新频繁,波动较大
    • 内存占用小,适合小数据集
  • epochs=10:训练10轮
    • 把7个样本反复训练10遍
    • 总共训练 7 × 10 = 70 次更新

注意history会记录训练过程中的loss变化,可以用于后续分析

六. 结果可视化

plot_line(model.get_weights()[0][0][0],model.get_weights()[1][0],title_name='Current_Model')
  • model.get_weights()[0]:获取权重W(斜率)
    • [0][0][0]是因为权重的形状是(1,1),需要索引到具体数值
  • model.get_weights()[1]:获取偏置b(截距)
    • [0]是因为偏置的形状是(1,),需要索引到具体数值

这个模型在做什么?

1. 数学本质

这个模型其实就是用神经网络的方式来实现最小二乘法线性回归

  • 要找一条直线y = Wx + b
  • 让这条直线最接近所有数据点
  • "接近"的标准是:均方误差最小

2. 训练过程(SGD)

初始化:W=随机值,b=随机值for10:for每个样本(x_i,y_i):1.计算预测值:y_pred=W*x_i+b2.计算误差:error=y_pred-y_i3.计算梯度:dW=2*error*x_i# 对W的梯度db=2*error# 对b的梯度4.更新参数:W=W-learning_rate*dW b=b-learning_rate*db

完整代码:

importnumpyasnpimportkerasimportmatplotlib.pyplotasplt train_X=np.asarray([30.0,40.0,60.0,80.0,100.0,120.0,140.0])train_Y=np.asarray([320.0,360.0,400.0,455.0,490.0,546.0,580.0])train_X/=100.0train_Y/=100.0#用于对数据点进行可视化defplot_points(x,y,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')plt.scatter(x,y)plt.show()defplot_line(W,b,title_name):plt.title(title_name)plt.xlabel('x')plt.ylabel('y')x=np.linspace(0.0,2.0,num=100)y=W*x+b plt.plot(x,y)plt.show()plot_points(train_X,train_Y,title_name='Training Points')#建立线性拟合模型,由斜率和偏移两个参数构成,相当于神经元数为1的一层全连接model=keras.models.Sequential()model.add(keras.layers.Dense(units=1,input_dim=1))#成本函数采用均差误差,优化方法使用随机梯度下降model.compile(optimizer='sgd',loss='mean_squared_error')#模型迭代10个轮次,用单样本的方式进行优化history=model.fit(x=train_X,y=train_Y,batch_size=1,epochs=10)plot_line(model.get_weights()[0][0][0],model.get_weights()[1][0],title_name='Current_Model')

附解释可视化函数部分
1.散点图
def plot_points(x, y, title_name):

  • 定义一个名为plot_points的函数。

    x:横坐标数据(如你的 train_X)
    y:纵坐标数据(如你的 train_Y)
    title_name:图表的标题(字符串)

​ plt.title(title_name) # 设置图表标题
​ plt.xlabel(‘x’) # 设置x轴标签
​ plt.ylabel(‘y’) # 设置y轴标签
​ plt.scatter(x, y) # 绘制散点图
​ plt.show() # 显示图表

2.直线图
def plot_line(W, b, title_name):
plt.title(title_name) # 设置图表标题
plt.xlabel(‘x’) # 设置x轴标签
plt.ylabel(‘y’) # 设置y轴标签

​ x = np.linspace(0.0, 2.0, num=100) # 生成100个等间距的x值
​ np: numpy模块的别名
​ .linspace(): 生成等差数列(linear space)
​ 参数:
​ 0.0: 起始值(start)
​ 2.0: 结束值(stop)
​ num=100: 生成100个点

​ y = W * x + b # 计算对应的y值

​ plt.plot(x, y) # 绘制折线图(这里是直线)
​ .plot(): 绘制折线图
​ 参数:(x, y)坐标点

​ plt.show() # 显示图表

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/21 22:56:34

光伏设计新选择:鹧鸪云

在光伏电站开发领域,传统设计模式的痛点早已凸显:人工测量耗时费力,二维图纸难以还原场地实况,数据偏差动辄引发发电量预测失准、施工返工等连锁问题,严重制约项目推进效率与收益。如今,无人机与数字化技术…

作者头像 李华
网站建设 2026/3/24 10:36:40

大模型微调7种方法:零基础入门全指南

大模型微调是让通用预训练模型适配特定任务的核心技术,分为全量微调与参数高效微调(PEFT)两大类。对零基础学习者而言,PEFT方法因低资源需求、易上手的优势成为首选。以下详细解析7种主流微调方法,并梳理极简入门流程&…

作者头像 李华
网站建设 2026/3/23 9:48:58

如何实现pdf一页内容分割成多页打印?详细教程分享

做好的设计稿是A2尺寸,可打印机只支持A4怎么办?直接缩印的话字体会小到看不清楚。其实可以试试将PDF一页内容分割成多页,打印好后再拼接到一起,清晰度不受影响,还不用特意跑打印店。有同样需求的朋友赶紧码住学起来~一…

作者头像 李华
网站建设 2026/3/13 5:54:14

【学习笔记】《道德经》第56章

《道德经》第56章 学习整理 本整理基于James Legge经典英文译本,结合标准中文(参考王弼本),从英文学习角度系统呈现内容。结构分为三个部分: 逐句中英对照翻译现代日常口语版英文关键短语口语对应表及使用建议 一、逐句…

作者头像 李华
网站建设 2026/3/13 5:59:41

2025智能体(Agent)框架全景:构建自主智能的基石

在人工智能的演进历程中,2025年标志着一个关键转折点——智能体(Agent)框架不再仅仅是实验室里的概念验证,而是成为推动产业智能化转型的核心引擎。从数字助手到自主决策系统,从虚拟化身到物理机器人,智能体…

作者头像 李华
网站建设 2026/3/24 9:11:14

远程客服管理方案:客服系统如何实现分散团队的协同与监管

随着远程办公模式的普及,客服团队的地域分散化成为常态,随之而来的信息孤岛、响应延迟、服务标准不统一、监管失效等问题,严重影响服务质量与客户体验。一套具备高效协同与精准监管能力的客服系统,成为破解分散团队管理困境的核心…

作者头像 李华