news 2026/3/7 14:22:00

掌握 CatBoost 中的不确定性

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
掌握 CatBoost 中的不确定性

原文:towardsdatascience.com/mastering-uncertainty-with-catboost-cdb330bc00cf

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/c9136e76eb8165fc08e5d7e84852b1b7.png

图片由 Ian Taylor 在 Unsplash 提供。

预测区间在回归分析中起着至关重要的作用,尤其是在目标不仅仅是点预测,而是评估预测的不确定性或变异性的情况下。与提供每个输入的单个估计值的点预测不同,预测区间提供了一个范围,其中真实值预计将以一定的概率落在其中。这一点尤其有价值,因为它考虑了任何预测建模中固有的不确定性。通过量化这种不确定性,预测区间提供了对可能结果更全面的理解。例如,在金融预测中,了解未来回报可能波动的范围对于风险管理投资策略至关重要。

此外,追求在回归模型中创建最窄的,或最“有效”的预测区间,可以提高模型输出的精度和可靠性。更窄的区间表明预测的确定性更高,假设区间是准确且一致地捕捉真实值。

通常,预测区间表示为:

[𝜇-𝘻𝜎, 𝜇+𝘻𝜎]

其中 𝜇 是均值(即均值预测),𝘡 是 𝘡 值的量数,𝜎 是标准差。因此,为了达到这个目的,我们可以通过设置 𝘻 = 1.64 来找到 90% 的预测区间,或者如果我们希望有一个更窄的区间,例如 95%,我们可以设置 𝘻 = 1.96。

CatBoost 中的不确定性

假设我们有一个包含特征 𝒙 的数据集。使用 RMSE 损失函数的传统回归框架仅限于预测 𝒙 的平均值。然而,如果目标是确定 𝒚 的方差,这反映了数据的不确定性或识别哪些预测可能不够精确,那么就必须转向能够预测均值和方差的概率回归模型。这种数据不确定性或 𝒚 的方差类似于所谓的随机不确定性**。**

为了解决这个问题,CatBoost 使用一个名为 RMSEWithUncertainty 的新型损失函数。这个函数通过优化负对数似然和采用自然梯度来使 CatBoost 能够近似正态分布的均值和方差。

此外,CatBoost 提供了一种*虚拟集成技术*来衡量知识不确定性认识不确定性。当模型输入来自在训练数据中不足够代表或与它显著不同的区域时,就会发生知识不确定性。

总结一下,CatBoost 使用总方差定律将总不确定性(或方差)分解为知识不确定性和预期数据不确定性的总和。换句话说,

总不确定性 = 知识不确定性 + 预期数据不确定性

示例

让我们考虑一个例子。我们将使用来自 scikit-learn 的加利福尼亚住房数据集来预测中值房价(MedHouseVal),并且进一步对我们的预测设置预测区间。为了加载数据集:

california_dataset=fetch_california_housing()california=pd.DataFrame(california_dataset.data,columns=california_dataset.feature_names)california_target=pd.DataFrame(california_dataset.target,columns=california_dataset.target_names)california['MedHouseVal']=california_target['MedHouseVal']print(california.shape)california.head()

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/0e87994dbdc2512e1883aa8143934f80.png

由作者生成的图像。

数据有 9 个数值属性,我们使用MedHouseVal作为我们的目标变量。我们现在可以准备这些数据,我们将数据分为训练、验证和测试,并为 CatBoost 准备池对象:

# Splitting data into training, validation, and test setsX=california.drop("MedHouseVal",axis=1)y=california["MedHouseVal"]X_train_full,X_test,y_train_full,y_test=train_test_split(X,y,test_size=0.2,random_state=42)X_train,X_val,y_train,y_val=train_test_split(X_train_full,y_train_full,test_size=0.25,random_state=42)# 0.25 x 0.8 = 0.2# Prepare the Pool objects for CatBoosttrain_pool=Pool(X_train,y_train)val_pool=Pool(X_val,y_val)test_pool=Pool(X_test)

现在数据已经准备好了,我们可以继续定义参数并训练和验证模型:

# Model parametersparams={'learning_rate':0.01,'random_state':42,'colsample_bylevel':0.5,'subsample':0.5,'max_bin':50,'max_depth':8,'loss_function':'RMSEWithUncertainty','task_type':'CPU','iterations':2000,'boosting_type':'Plain','bootstrap_type':'Bernoulli','verbose':500,}model=CatBoostRegressor(**params)model.fit(train_pool,eval_set=val_pool)

正如我们所见,对于损失函数,我们使用的是RMSEWithUncertainty,因此输出将包括每个预测的均值方差。注意,我们还定义了提升类型为Plain,这使训练期间启用经典梯度提升方案。对于bootstrap_type,我们选择了Bernoulli,它与随机梯度提升相关。

训练模型给出:

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/aeaece0f897d1f31c94063e9fea81381.png

由作者生成的图像。

在这里,学习和测试日志记录了来自RMSEWithUncertainty的损失。bestTest表示在测试数据集上达到的最佳分数,以及达到此分数的迭代次数由bestIteration给出。最后,Shrink model表示模型将被“收缩”到前 1273 次迭代。这意味着模型将丢弃在 1272 次迭代之后添加的任何额外复杂性,有效地使用模型在测试数据集上表现最佳时的状态。这一步骤是为了防止过拟合并确保模型保持良好的泛化能力。

我们现在可以在测试集上应用我们的模型并进行一些评估:

# Predict on test setpreds=model.predict(test_pool)mean_preds,var_preds=preds[:,0],preds[:,1]# RMSE Evaluationrmse=np.sqrt(mean_squared_error(y_test,mean_preds))print("RMSE on Test Set:",rmse)# Calculate Prediction Intervals and Coverageconfidence_multiplier=1.96# Assuming a certain confidence level, adjust as necessarylower_bound=(mean_preds-confidence_multiplier*np.sqrt(var_preds))upper_bound=(mean_preds+confidence_multiplier*np.sqrt(var_preds))# Coverage calculationcovered=np.sum((y_test>=lower_bound)&amp;(y_test<=upper_bound))coverage=covered/len(y_test)print("Coverage:",coverage)

在这部分中,我们可以看到,在应用model.predict()方法后,我们得到了均值和方差(与只有预测均值的预测相反)。使用均值,我们可以计算 RMSE,这会与点预测类似。

然而,如前所述,这个练习的目的是为我们的预测输出定义预测区间。为此,我们定义了lower_bound = 𝜇-𝘻𝜎 and upper_bound = 𝜇+𝘻𝜎。我们还使用了置信系数confidence_multiplier = 1.96`,这对应于正态分布的 95%预测区间。

为了评估我们的预测区间,我们使用了覆盖率。覆盖率指的是预测区间成功包含真实结果值的比例。这是一个用来评估区间如何捕捉实际观察结果的度量。我们的输出显示:

  • 测试集上的 RMSE:0.5276577335921109

  • 覆盖率:0.9193313953488372

这意味着大约 92%的预测值都在我们定义的下限和上限预测区间内,这相当不错!

我们可以将我们的预测总结在一个数据框中:

# Create a DataFrameresults_df=pd.DataFrame({'Actual':y_test,'Predicted':mean_preds,'Lower_Bound':lower_bound,'Upper_Bound':upper_bound})# Reset index for a cleaner look, especially if y_test is a Series with its own indexresults_df.reset_index(drop=True,inplace=True)# Display the first few rows to verifyresults_df.head()

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/182a07fed18f018ddea28a952b5b6bcb.png

由作者生成的图像。

对于测试集中的每一行,我们可以看到实际值、均值预测,这代表模型预测的期望值或均值。这是模型对数据集中每个实例的目标变量真实值的最佳猜测。

我们还可以看到预测的下限和上限区间。这些值取决于均值预测以及方差预测,这些预测量化了预测的不确定性或变异性。它衡量预测值围绕其均值(mean_preds)的分散程度。在预测区间的上下文中,方差对于确定区间的宽度至关重要。

我们也可以绘制结果:

# Density plot for actual valuessns.kdeplot(results_df['Actual'],label='Actual',fill=True,color='blue',alpha=0.7)# Density plot for predicted valuessns.kdeplot(results_df['Predicted'],label='Predicted',fill=True,color='red',alpha=0.5)# Density plot for lower boundssns.kdeplot(results_df['Lower_Bound'],label='Lower Bound',color='green',linestyle='--',alpha=0.7)# Density plot for upper boundssns.kdeplot(results_df['Upper_Bound'],label='Upper Bound',color='purple',linestyle='--',alpha=0.7)plt.title('Density Plot of Actual, Predicted, and Prediction Intervals')plt.xlabel('Median House Value')plt.ylabel('Density')plt.legend()plt.show()

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f53db3c2b4da10ad4bfe181d85cd898e.png

由作者生成的图像。

密度图可以帮助可视化实际值和预测值的分布,从而给出预测如何捕捉实际数据中变异性的感觉。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/2/28 20:17:39

RimSort:彻底解决《RimWorld》模组管理难题的效率革命工具

RimSort&#xff1a;彻底解决《RimWorld》模组管理难题的效率革命工具 【免费下载链接】RimSort 项目地址: https://gitcode.com/gh_mirrors/ri/RimSort 你是否曾因《RimWorld》模组加载顺序错误导致游戏崩溃&#xff1f;是否在数百个模组中艰难寻找冲突源&#xff1f;…

作者头像 李华
网站建设 2026/3/4 17:53:38

新手必看!Glyph视觉推理部署避坑指南

新手必看&#xff01;Glyph视觉推理部署避坑指南 Glyph不是又一个“上传图片→点几下→出结果”的轻量级工具&#xff0c;而是一套把长文本当图像来“看”的视觉推理新范式。它不靠堆显存扩上下文&#xff0c;而是把几千字的合同、论文或日志渲染成高分辨率图像&#xff0c;再…

作者头像 李华
网站建设 2026/3/5 12:42:27

玩转动物森友会:NHSE存档编辑工具全攻略

玩转动物森友会&#xff1a;NHSE存档编辑工具全攻略 【免费下载链接】NHSE Animal Crossing: New Horizons save editor 项目地址: https://gitcode.com/gh_mirrors/nh/NHSE 功能解析&#xff1a;为什么NHSE能让你的岛屿梦想成真&#xff1f; 你是否曾想过自定义动物森…

作者头像 李华
网站建设 2026/3/5 6:57:58

告别繁琐配置!用科哥构建的Paraformer镜像一键部署语音识别

告别繁琐配置&#xff01;用科哥构建的Paraformer镜像一键部署语音识别 你是否经历过这样的场景&#xff1a; 想快速验证一个语音识别模型&#xff0c;却卡在环境搭建上——CUDA版本不匹配、PyTorch编译报错、FunASR依赖冲突、模型权重下载失败……折腾半天&#xff0c;连第一…

作者头像 李华
网站建设 2026/3/4 20:04:34

MTK设备BROM模式故障排除技术指南

MTK设备BROM模式故障排除技术指南 【免费下载链接】mtkclient MTK reverse engineering and flash tool 项目地址: https://gitcode.com/gh_mirrors/mt/mtkclient 1. 问题诊断&#xff1a;BROM模式异常的识别与分析 1.1 典型故障现象 当MTK设备出现BROM模式访问问题时…

作者头像 李华
网站建设 2026/3/5 4:21:43

Z-Image Turbo画质增强算法逆向分析:高频细节增强与色彩校正逻辑

Z-Image Turbo画质增强算法逆向分析&#xff1a;高频细节增强与色彩校正逻辑 1. 本地极速画板&#xff1a;不只是界面&#xff0c;更是画质增强的起点 Z-Image Turbo 本地极速画板不是传统意义上“能出图就行”的Web工具。它从第一行代码开始&#xff0c;就把画质作为核心目标…

作者头像 李华