机器学习实战——树回归

时间:2023-01-09 19:51:55

上一节介绍了线性回归模型,但是现实生活中很多问题是非线性的,不可能用全局线性模型来拟合数据。一种可行的方法是将数据集切分成很多易建模的数据,然后利用线性回归技术来进行拟合。这种切分方式下,树结构和回归法就相当有用。

一、 CART算法

CART即Classification And Regression Tree,分类回归树, 它使用二元切分来处理连续型变量。创建树的代码如下:r

CreateTree

找到最佳的带切分特征:

如果该节点不能再分,将该节点存为叶节点;
执行二元切分
在右子树调用CreateTree()方法
在左子树调用CreateTree()方法
这里唯一的问题就在于,如何选取最佳特征将数据集分成两部分。

二、 CART算法用于回归——回归树

回归树是构建以 分段常数为叶节点的树。之前在构建决策树的时候,我们采用熵来度量数据的混乱度,那么如何度量连续值的混乱度呢?我们可以先计算所有数据的均值,然后计算每条数据的值到均值的差值,一般差值取平方。这里的差值我们取的是总值。
回归树构建的要点在于每一次如何选取一个特征以及特征的一个值来对数据集做二元切分。回归树的策略是:遍历所有的特征,对每个特征,遍历这个特征里所有样例的取值,计算以这个取值作为切分时,生成的两个子数据集的平方误差,我们取子数据集平方误差最小的切分方式。伪码如下:

chooseBestSplit
bestS = inf
对于数据集中的每个特征
对于当前特征的每个取值
以此取值为切分值,将数据集切分成两部分
如果两部分子集中的元素个数太少 //剪枝
不切分,返回一个叶节点;
计算两部分子数据集的平方误差之和newS
if newS < bestS
更新当前找到的最好的切分特征和切分值
if S - bestS < tolS //剪枝
不切分,返回一个叶节点;
返回找到的最优切分特征和切分值;

在回归树中,模型返回的其实是目标变量的均值。


三、 树剪枝


1、 预剪枝
预剪枝算法对预先输入的阈值非常敏感,通常通过事先设定每个叶节点元素的数目最小值以及每次切分平方误差减少的最小值来做预剪枝。

2、 后剪枝
使用后剪枝应该先将数据分为 训练集和测试集。从上到下找到叶节点,然后用测试集判定将这些节点 合并是否能降低测试误差
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程;
计算将当前两个叶节点合并后的误差;
计算不合并的误差;
如果合并会降低误差的话,则将叶节点合并。

四、 模型树

用树对数据建模,除了把叶节点设定为常数值意外,还可以将叶节点设定为分段线性函数。 模型树的可解释性是其优于回归树的特点之一。另外模型树也具有 更高的预测精度
将回归树的代码稍加修改,在叶节点生成线性模型而不是常数值。前面回归树中用平方误差的方法来选取最佳切分特征和特征值,在这里这个方法不能再用。对于给定数据集,应该使用线性模型对其进行拟合,然后计算真实的目标值与模型预测值之间的差值平方和,以此为依据来选择最佳切分特征及切分值。
树回归方法在预测复杂数据时会比简单的线性模型更加有效