【Python机器学习系列】绘制多分类任务的ROC曲线-宏平均ROC曲线
qiyuwang 2024-10-10 11:28 21 浏览 0 评论
这是我的第345篇原创文章。
一、引言
ROC曲线是用于评估二分类模型性能的工具,它展示了模型在不同阈值下的真阳性率与假阳性率之间的关系,但是标准的ROC并不能运用于多分类任务种,于是扩展出了宏平均ROC曲线。
宏平均ROC曲线是多分类问题中对ROC曲线的扩展,在多分类任务中,我们需要计算每一类别相对于其他所有类别的ROC曲线,然后对所有这些ROC曲线进行平均,从而得到宏平均ROC曲线,其主要步骤如下:
- 逐类计算ROC曲线:对于每个类别,将其视为正类,其他所有类别视为负类,计算出相应的ROC曲线,也就是可以看作对每个类别进行独热编码
- 计算AUC值:计算每个类别对应的AUC值
- 平均化:对所有类别的AUC值进行平均,从而得到宏平均AUC值,同时,将各类别的ROC曲线取平均,得到宏平均ROC曲线
宏平均ROC曲线的优点在于它平等地考虑了每个类别的性能,适用于类别数量不平衡的情况,不过,由于它对所有类别进行了简单平均,如果某些类别比其他类别更加重要,宏平均ROC可能无法完全反映分类器的实际性能。
二、实现过程
2.1 准备数据
data = pd.read_csv(r'data.csv')
df = pd.DataFrame(data)
print(df.head())
该多分类数据存在3个类别:
2.2 提取目标变量
target = 'Type'
features = df.columns.drop(target)
print(data["Type"].value_counts()) # 顺便查看一下样本是否平衡
2.3 划分数据集
# df = shuffle(df)
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)
2.4 归一化
mm1 = MinMaxScaler() # 特征进行归一化
X_train_m = mm1.fit_transform(X_train)
2.5 模型的构建
model = RandomForestClassifier()
2.6 模型的训练
model.fit(X_train_m, y_train)
2.7 模型的推理
X_test_m = mm1.transform(X_test)
y_pred = model.predict(X_test_m)
y_scores = model.predict_proba(X_test_m)
print(y_pred)
acc = accuracy_score(y_test, y_pred) # 准确率acc
print(f"acc: \n{acc}")
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
print(f"cm: \n{cm}")
cr = classification_report(y_test, y_pred) # 分类报告
print(f"cr: \n{cr}")
2.8 模型的评价
acc = accuracy_score(y_test, y_pred) # 准确率acc
print(f"acc: \n{acc}")
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
print(f"cm: \n{cm}")
cr = classification_report(y_test, y_pred) # 分类报告
print(f"cr: \n{cr}")
结果:
把混淆矩阵进行可视化:
2.8 绘制ROC曲线
计算宏平均ROC:
# 将y标签转换成one-hot形式
ytest_one_rf = label_binarize(y_test, classes=[1, 2, 3])
# 宏平均法计算AUC
rf_AUC = {}
rf_FPR = {}
rf_TPR = {}
for i in range(ytest_one_rf.shape[1]):
rf_FPR[i], rf_TPR[i], thresholds = roc_curve(ytest_one_rf[:, i], y_scores[:, i])
rf_AUC[i] = auc(rf_FPR[i], rf_TPR[i])
print(rf_AUC)
# 合并所有的FPR并排序去重
pass
# 计算宏平均TPR
rf_TPR_all = np.zeros_like(rf_FPR_final)
for i in range(ytest_one_rf.shape[1]):
rf_TPR_all += np.interp(rf_FPR_final, rf_FPR[i], rf_TPR[i])
rf_TPR_final = rf_TPR_all / ytest_one_rf.shape[1]
# 计算最终的宏平均AUC
rf_AUC_final = auc(rf_FPR_final, rf_TPR_final)
AUC_final_rf = rf_AUC_final # 最终AUC
print(f"Macro Average AUC with Random Forest: {AUC_final_rf}")
利用随机森林模型对测试集进行预测,并计算每个类别的预测概率。然后,将实际标签 ytest 转换为 one-hot 编码形式,以便进行多分类的 ROC 曲线分析,接着,通过逐类别计算 ROC 曲线和 AUC 值,并保存到字典中,最后,通过合并所有类别的 FPR 值并计算宏平均 TPR,从而得到最终的宏平均 AUC 值,用于评估随机森林模型在多分类任务中的整体性能。
绘制随机森林分类器在多分类任务中的 ROC 曲线,并计算并展示了每个类别的 AUC 值以及宏平均 ROC 曲线的 AUC:
plt.figure(figsize=(10, 5), dpi=300)
# 使用不同的颜色和线型
plt.plot(rf_FPR[0], rf_TPR[0], color='#1f77b4', linestyle='-', label='Class 1 ROC AUC={:.4f}'.format(rf_AUC[0]), lw=2)
plt.plot(rf_FPR[1], rf_TPR[1], color='#ff7f0e', linestyle='-', label='Class 2 ROC AUC={:.4f}'.format(rf_AUC[1]), lw=2)
plt.plot(rf_FPR[2], rf_TPR[2], color='#2ca02c', linestyle='-', label='Class 3 ROC AUC={:.4f}'.format(rf_AUC[2]), lw=2)
# 宏平均ROC曲线
plt.plot(rf_FPR_final, rf_TPR_final, color='#000000', linestyle='-', label='Macro Average ROC AUC={:.4f}'.format(rf_AUC_final), lw=3)
# 45度参考线
plt.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=2, label='45 Degree Reference Line')
plt.xlabel('False Positive Rate (FPR)', fontsize=15)
plt.ylabel('True Positive Rate (TPR)', fontsize=15)
plt.title('Random Forest Classification ROC Curves and AUC', fontsize=18)
plt.grid(linestyle='--', alpha=0.7)
plt.legend(loc='lower right', framealpha=0.9, fontsize=12)
plt.savefig('RF_optimized.pdf', format='pdf', bbox_inches='tight')
plt.show()
结果:
作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。
原文链接:
相关推荐
- # 安装打开 ubuntu-22.04.3-LTS 报错 解决方案
-
#安装打开ubuntu-22.04.3-LTS报错解决方案WslRegisterDistributionfailedwitherror:0x800701bcError:0x80070...
- 利用阿里云镜像在ubuntu上安装Docker
-
简介:...
- 如何将Ubuntu Kylin(优麒麟)19.10系统升级到20.04版本
-
UbuntuKylin系统使用一段时间后,有新的版本发布,如何将现有的UbuntuKylin系统升级到最新版本?可以通过下面的方法进行升级。1.先查看相关的UbuntuKylin系统版本情况。使...
- Ubuntu 16.10内部代号确认为Yakkety Yak
-
在正式宣布Ubuntu16.04LTS(XenialXerus)的当天,Canonical创始人MarkShuttleworth还非常开心的在个人微博上宣布Ubuntu下个版本16.10的内...
- 如何在win11的wsl上装ubuntu(怎么在windows上安装ubuntu)
-
在Windows11的WSL(WindowsSubsystemforLinux)上安装Ubuntu非常简单。以下是详细的步骤:---...
- Win11学院:如何在Windows 11上使用WSL安装Ubuntu
-
IT之家2月18日消息,科技媒体pureinfotech昨日(2月17日)发布博文,介绍了3中简便的方法,让你轻松在Windows11系统中,使用WindowsSubs...
- 如何查看Linux的IP地址(如何查看Linux的ip地址)
-
本头条号每天坚持更新原创干货技术文章,欢迎关注本头条号"Linux学习教程",公众号名称“Linux入门学习教程"。...
- 怎么看电脑系统?(怎么看电脑系统配置)
-
要查看电脑的操作系统信息,可以按照以下步骤操作,根据不同的操作系统选择对应的方法:一、Windows系统通过系统属性查看右键点击桌面上的“此电脑”(或“我的电脑”)图标,选择“属性”。在打开的...
- 如何查询 Linux 内核版本?这些命令一定要会!
-
Linux内核是操作系统的核心,负责管理硬件资源、调度进程、处理系统调用等关键任务。不同的内核版本可能支持不同的硬件特性、提供新的功能,或者修复了已知的安全漏洞。以下是查询内核版本的几个常见场景:...
- 深度剖析:Linux下查看系统版本与CPU架构
-
在Linux系统管理、维护以及软件部署的过程中,精准掌握系统版本和CPU架构是极为关键的基础操作。这些信息不仅有助于我们深入了解系统特性、判断软件兼容性,还能为后续的软件安装、性能优化提供重要依据。接...
- 504 错误代码解析与应对策略(504错误咋解决)
-
在互联网的使用过程中,用户偶尔会遭遇各种错误提示,其中504错误代码是较为常见的一种。504错误并非意味着网站被屏蔽,它实际上是指服务器在规定时间内未能从上游服务器获取响应,专业术语称为“Ga...
- 猎聘APP和官网崩了?回应:正对部分职位整改,临时域名可登录
-
10月12日,有网友反映猎聘网无法打开,猎聘APP无法登录。截至10月14日,仍有网友不断向猎聘官方微博下反映该情况,而猎聘官方微博未发布相关情况说明,只是在微博内对反映该情况的用户进行回复,“抱歉,...
- 域名解析的原理是什么?域名解析的流程是怎样的?
-
域名解析是网站正常运行的关键因素,因此网站管理者了解域名解析的原理和流程对于做好域名管理、解决常见解析问题,保障网站的正常运转十分必要。那么域名解析的原理是什么?域名解析的流程是怎样的?接下来,中科三...
- Linux无法解析域名的解决办法(linux 不能解析域名)
-
如果由于误操作,删除了系统原有的dhcp相关设置就无法正常解析域名。 此时,需要手动修改配置文件: /etc/resolv.conf 将域名解析服务器手动添加到配置文件中 该文件是DNS域名解...
- 域名劫持是什么?(域名劫持是什么)
-
域名劫持是互联网攻击的一种方式,通过攻击域名解析服务器(DNS),或伪造域名解析服务器(DNS)的方法,把目标网站域名解析到错误的地址从而实现用户无法访问目标网站的目的。说的直白些,域名劫持,就是把互...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- # 安装打开 ubuntu-22.04.3-LTS 报错 解决方案
- 利用阿里云镜像在ubuntu上安装Docker
- 如何将Ubuntu Kylin(优麒麟)19.10系统升级到20.04版本
- Ubuntu 16.10内部代号确认为Yakkety Yak
- 如何在win11的wsl上装ubuntu(怎么在windows上安装ubuntu)
- Win11学院:如何在Windows 11上使用WSL安装Ubuntu
- 如何查看Linux的IP地址(如何查看Linux的ip地址)
- 怎么看电脑系统?(怎么看电脑系统配置)
- 如何查询 Linux 内核版本?这些命令一定要会!
- 深度剖析:Linux下查看系统版本与CPU架构
- 标签列表
-
- navicat无法连接mysql服务器 (65)
- 下横线怎么打 (71)
- flash插件怎么安装 (60)
- lol体验服怎么进 (66)
- ae插件怎么安装 (62)
- yum卸载 (75)
- .key文件 (63)
- cad一打开就致命错误是怎么回事 (61)
- rpm文件怎么安装 (66)
- linux取消挂载 (81)
- ie代理配置错误 (61)
- ajax error (67)
- centos7 重启网络 (67)
- centos6下载 (58)
- mysql 外网访问权限 (69)
- centos查看内核版本 (61)
- ps错误16 (66)
- nodejs读取json文件 (64)
- centos7 1810 (59)
- 加载com加载项时运行错误 (67)
- php打乱数组顺序 (68)
- cad安装失败怎么解决 (58)
- 因文件头错误而不能打开怎么解决 (68)
- js判断字符串为空 (62)
- centos查看端口 (64)