Today is the 23rd article in the machine learning series, and our topic is the CART algorithm, one of the top ten data mining algorithms. CART stands for Classification And Regression Tree. Like ID3 and C4.5, which we introduced earlier, CART is a classic implementation of the decision tree model. Decision trees have three classic implementations; we have already covered ID3 and C4.5, so today we round out the set with the last one.
from collections import Counter

import numpy as np


def gini_index(dataset):
    dataset = np.array(dataset)
    n = dataset.shape[0]
    if n == 0:
        return 0
    # Gini impurity: sigma p(1-p) = 1 - sigma p^2
    counter = Counter(dataset[:, -1])
    ret = 1.0
    for k, v in counter.items():
        ret -= (v / n) ** 2
    return ret
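To sanity-check the formula in the comment, here is the same Gini computation run on a tiny made-up dataset (the rows and labels below are illustrative, not from the article):

```python
from collections import Counter

import numpy as np


def gini_index(dataset):
    # Gini impurity of the last column (the label): 1 - sum p_k^2
    dataset = np.array(dataset)
    n = dataset.shape[0]
    if n == 0:
        return 0
    counter = Counter(dataset[:, -1])
    return 1.0 - sum((v / n) ** 2 for v in counter.values())


# a tiny two-feature dataset: each row is [x0, x1, label]
data = [[1, 2, 0], [2, 1, 0], [3, 4, 1], [4, 3, 1]]
print(gini_index(data))  # 1 - 0.5^2 - 0.5^2 = 0.5

pure = [[1, 2, 0], [2, 1, 0]]
print(gini_index(pure))  # 0.0 for a pure node
```

An evenly mixed node gives the maximum two-class impurity of 0.5, and a node with a single label gives 0, which matches the intuition that Gini measures how often a randomly drawn sample would be mislabeled.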
def split_gini(dataset, idx, threshold):
    left, right = [], []
    n = dataset.shape[0]
    # split on the threshold, then compute the Gini index of each half
    for data in dataset:
        if data[idx] < threshold:
            left.append(data)
        else:
            right.append(data)
    left, right = np.array(left), np.array(right)
    # weight each half's Gini by the fraction of samples it holds
    return left.shape[0] / n * gini_index(left) + right.shape[0] / n * gini_index(right)
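For intuition on the weighted sum, here is a compact, self-contained restatement of the split on a one-feature toy dataset (both the helper and the data are repeated or invented here so the snippet runs on its own):

```python
from collections import Counter

import numpy as np


def gini_index(dataset):
    dataset = np.array(dataset)
    n = dataset.shape[0]
    if n == 0:
        return 0
    counter = Counter(dataset[:, -1])
    return 1.0 - sum((v / n) ** 2 for v in counter.values())


def split_gini(dataset, idx, threshold):
    # weighted Gini after splitting feature idx at the given threshold
    left = [d for d in dataset if d[idx] < threshold]
    right = [d for d in dataset if d[idx] >= threshold]
    n = len(dataset)
    return len(left) / n * gini_index(left) + len(right) / n * gini_index(right)


# single feature + label; the split at 2.5 separates the classes perfectly
data = [[1, 0], [2, 0], [3, 1], [4, 1]]
print(split_gini(np.array(data), 0, 2.5))  # both halves pure -> 0.0
print(split_gini(np.array(data), 0, 1.5))  # 3/4 * (1 - (1/3)^2 - (2/3)^2) = 1/3
```

The perfect split drives the weighted Gini to 0, while the lopsided split leaves an impure right half, so CART's "pick the smallest weighted Gini" rule prefers the former.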
def choose_feature_to_split(dataset):
    n = len(dataset[0]) - 1
    m = len(dataset)
    # track the best Gini index, the chosen feature and its threshold
    bestGini = 1.0
    feature = -1
    thred = None
    for i in range(n):
        threds = get_thresholds(dataset, i)
        for t in threds:
            # try every candidate threshold and keep the split
            # with the smallest weighted Gini index
            gini = split_gini(dataset, i, t)
            if gini < bestGini:
                bestGini, feature, thred = gini, i, t
    return feature, thred
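The helper get_thresholds is called above but not shown in this excerpt. A common convention for numeric features, and purely an assumption here, is to take the midpoints between adjacent distinct values of the feature as candidate thresholds:

```python
import numpy as np


def get_thresholds(dataset, idx):
    # assumed implementation: candidate thresholds are the midpoints
    # between adjacent distinct sorted values of feature idx
    values = np.unique(np.asarray(dataset, dtype=float)[:, idx])
    return [(a + b) / 2 for a, b in zip(values, values[1:])]


data = [[1, 0], [3, 0], [3, 1], [7, 1]]
print(get_thresholds(data, 0))  # [2.0, 5.0]
```

Midpoints are a reasonable choice because any threshold strictly between two adjacent feature values produces the same partition, so enumerating one representative per gap covers every possible binary split of that feature.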