因果模型评估实战:从NOTEARS源码拆解FDR/SHD计算逻辑
在因果推断领域,评估模型性能是验证算法有效性的关键环节。NOTEARS论文中提出的count_accuracy函数实现了多种评估指标的计算,其中**FDR(误发现率)和SHD(结构汉明距离)**是最常用的两种。本文将深入源码,逐行解析这些指标背后的数学含义和实现细节。
1. 评估指标基础概念
1.1 混淆矩阵与基本术语
在因果图评估中,我们需要明确几个核心概念:
- 真阳性(TP):预测边存在且方向正确
- 反向边(Reverse):预测边存在但方向相反
- 假阳性(FP):预测边存在但真实图中不存在
- 假阴性(FN):真实边存在但未被预测到
这些概念构成了评估的基础框架。以邻接矩阵表示的有向无环图(DAG)为例:
# 真实图邻接矩阵示例 B_true = np.array([ [0, 1, 0], # X1 -> X2 [0, 0, 1], # X2 -> X3 [0, 0, 0] # X3无出边 ]) # 预测图邻接矩阵示例 B_est = np.array([ [0, 1, 1], # 正确预测X1->X2,错误预测X1->X3 [1, 0, 1], # 反向预测X2->X1,正确预测X2->X3 [0, 0, 0] ])1.2 常见评估指标定义
| 指标 | 公式 | 解释 |
|---|---|---|
| FDR | (反向边+假阳性)/预测正例数 | 错误发现的比例 |
| TPR | 真阳性/真实正例数 | 召回率/敏感度 |
| FPR | (反向边+假阳性)/真实负例数 | 错误预警比例 |
| SHD | 多余边+缺失边+反向边数 | 结构差异总量 |
2. NOTEARS评估函数深度解析
2.1 输入验证与预处理
count_accuracy函数首先进行严格的输入验证:
def count_accuracy(B_true, B_est): # 验证B_est取值合法性 if (B_est == -1).any(): # CPDAG情况 if not ((B_est == 0) | (B_est == 1) | (B_est == -1)).all(): raise ValueError('B_est should take value in {0,1,-1}') if ((B_est == -1) & (B_est.T == -1)).any(): raise ValueError('undirected edge should only appear once') else: # DAG情况 if not ((B_est == 0) | (B_est == 1)).all(): raise ValueError('B_est should take value in {0,1}') if not is_dag(B_est): raise ValueError('B_est should be a DAG')注意:-1表示CPDAG中的无向边,需要确保无向边不会在矩阵中重复出现
2.2 关键索引提取
函数通过NumPy操作提取各种边的索引位置:
d = B_true.shape[0] pred_und = np.flatnonzero(B_est == -1) # 无向边位置 pred = np.flatnonzero(B_est == 1) # 预测有向边位置 cond = np.flatnonzero(B_true) # 真实有向边位置 cond_reversed = np.flatnonzero(B_true.T) # 真实反向边位置 cond_skeleton = np.concatenate([cond, cond_reversed]) # 无向骨架这里使用了几个关键NumPy函数:
flatnonzero:返回扁平化数组中非零元素的索引concatenate:合并多个索引数组
3. 核心指标计算逻辑
3.1 真阳性与假阳性识别
# 真阳性:预测有向边且方向正确 true_pos = np.intersect1d(pred, cond, assume_unique=True) # 无向边视为真阳性(宽松评估) true_pos_und = np.intersect1d(pred_und, cond_skeleton, assume_unique=True) true_pos = np.concatenate([true_pos, true_pos_und]) # 假阳性:预测存在但真实不存在 false_pos = np.setdiff1d(pred, cond_skeleton, assume_unique=True) false_pos_und = np.setdiff1d(pred_und, cond_skeleton, assume_unique=True) false_pos = np.concatenate([false_pos, false_pos_und])关键函数解析:
intersect1d:求两个数组的交集setdiff1d:求第一个数组有而第二个数组没有的元素
3.2 反向边检测
extra = np.setdiff1d(pred, cond, assume_unique=True) reverse = np.intersect1d(extra, cond_reversed, assume_unique=True)这段代码精妙地实现了反向边检测:
- 首先找出预测有但真实没有的边(
extra) - 然后检查这些边是否是真实图中反向存在的边
3.3 指标比率计算
pred_size = len(pred) + len(pred_und) # 预测正例总数 cond_neg_size = 0.5 * d * (d - 1) - len(cond) # 真实负例总数 fdr = float(len(reverse) + len(false_pos)) / max(pred_size, 1) tpr = float(len(true_pos)) / max(len(cond), 1) fpr = float(len(reverse) + len(false_pos)) / max(cond_neg_size, 1)提示:分母使用max(...,1)避免除以零错误
4. 结构汉明距离(SHD)实现
4.1 SHD计算原理
SHD衡量两个图结构差异的总和,包括:
- 多余边(预测有但真实没有)
- 缺失边(真实有但预测没有)
- 反向边(方向预测错误)
NOTEARS中的实现:
pred_lower = np.flatnonzero(np.tril(B_est + B_est.T)) cond_lower = np.flatnonzero(np.tril(B_true + B_true.T)) extra_lower = np.setdiff1d(pred_lower, cond_lower, assume_unique=True) missing_lower = np.setdiff1d(cond_lower, pred_lower, assume_unique=True) shd = len(extra_lower) + len(missing_lower) + len(reverse)4.2 关键技巧解析
np.tril:取矩阵的下三角部分,避免重复计算- 邻接矩阵相加:将有向图转换为无向骨架
- 通过集合操作计算多余和缺失边
实际项目中,SHD计算可以这样验证:
from cdt.metrics import SHD import numpy as np # 生成随机邻接矩阵 np.random.seed(42) tar = np.random.randint(2, size=(5,5)) pred = np.random.randint(2, size=(5,5)) # 计算SHD print("CDT库计算结果:", SHD(tar, pred)) print("NOTEARS计算结果:", count_accuracy(tar, pred)['shd'])5. 实际应用中的注意事项
5.1 评估指标的选择策略
不同场景下应侧重不同指标:
- 因果发现:优先关注FDR,控制错误发现
- 因果效应估计:关注TPR,确保重要关系不被遗漏
- 算法比较:使用SHD综合评估结构差异
5.2 常见陷阱与解决方案
样本量影响:
- 小样本时FDR可能被高估
- 解决方案:使用bootstrap计算置信区间
稠密图问题:
- 高密度图的SHD绝对值会增大
- 解决方案:考虑标准化SHD(除以可能边数)
CPDAG评估:
- 无向边的处理需要特殊规则
- NOTEARS采用宽松策略(视为正确)
# 处理CPDAG评估的实用技巧 def adjust_for_cpdag(B_true, B_est): # 将无向边视为双向边 B_est_skeleton = (B_est != 0).astype(int) B_true_skeleton = (B_true != 0).astype(int) # 计算骨架准确率 skeleton_tpr = np.sum(B_est_skeleton & B_true_skeleton) / np.sum(B_true_skeleton) return skeleton_tpr5.3 性能优化建议
对于大规模图(节点数>1000),原始实现可能效率低下,可以考虑:
- 使用稀疏矩阵存储邻接关系
- 并行化集合运算
- 近似计算策略
from scipy.sparse import csr_matrix def sparse_count_accuracy(B_true, B_est): # 转换为稀疏矩阵 B_true_sparse = csr_matrix(B_true) B_est_sparse = csr_matrix(B_est) # 使用稀疏矩阵运算优化性能 # ...后续实现类似但使用稀疏矩阵操作在实际项目中,我发现对大规模基因调控网络(通常有上万个节点)进行评估时,原始实现可能需要数小时完成,而经过稀疏矩阵优化后,评估时间可以缩短到几分钟。特别是在计算SHD时,只比较下三角矩阵的策略可以减少近一半的计算量。