首页 » 机器学习实战 » 机器学习实战全文在线阅读

《机器学习实战》9.3 将CART算法用于回归

关灯直达底部

要对数据的复杂关系建模,我们已经决定借用树结构来帮助切分数据,那么如何实现数据的切分呢?怎么才能知道是否已经充分切分呢?这些问题的答案取决于叶节点的建模方式。回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括。

为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。第3章使用树进行分类,会在给定节点时计算数据的混乱度。那么如何计算连续型数值的混乱度呢?事实上,在数据集上计算混乱度是非常简单的。首先计算所有数据的均值,然后计算每条数据的值到均值的差值。为了对正负差值同等看待,一般使用绝对值或平方值来代替上述差值。上述做法有点类似于前面介绍过的统计学中常用的方差计算。唯一的不同就是,方差是平方误差的均值(均方差),而这里需要的是平方误差的总值(总方差)。总方差可以通过均方差乘以数据集中样本点的个数来得到。

有了上述误差计算准则和上一节中的树构建算法,下面就可以开始构建数据集上的回归树了。

9.3.1 构建树

构建回归树,需要补充一些新的代码,使程序清单9-1中的函数createTree得以运转。首先要做的就是实现chooseBestSplit函数。给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外,该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。因此,函数chooseBestSplit只需完成两件事:用最佳方式切分数据集和生成相应的叶节点。

从程序清单9-1可以看出,除了数据集以外,函数chooseBestSplit还有leafTypeerrTypeops这三个参数。其中leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,而ops是一个用户定义的参数构成的元组,用以完成树的构建。

下面的代码中,函数chooseBestSplit最复杂,该函数的目标是找到数据集切分的最佳位置。它遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。该函数的伪代码大致如下:

对每个特征:    对每个特征值:        将数据集切分成两份        计算切分的误差        如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差返回最佳切分的特征和阈值  

下面给出上述三个函数的具体实现代码。打开regTrees.py文件并加入程序清单9-2中的代码。

程序清单9-2 回归树的切分函数

def regLeaf(dataSet):    return mean(dataSet[:,-1])def regErr(dataSet):    return var(dataSet[:,-1]) * shape(dataSet)[0]def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):    tolS = ops[0]; tolN = ops[1]    #❶(以下两行) 如果所有值相等则退出    if len(set(dataSet[:,-1].T.tolist[0])) == 1:        return None, leafType(dataSet)    m,n = shape(dataSet)    S = errType(dataSet)    bestS = inf; bestIndex = 0; bestValue = 0    for featIndex in range(n-1):        for splitVal in set(dataSet[:,featIndex]):            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue            newS = errType(mat0) + errType(mat1)            if newS < bestS:                bestIndex = featIndex                bestValue = splitVal                bestS = newS    #❷(以下两行)如果误差减少不大则退出    if (S - bestS) < tolS:        return None, leafType(dataSet)    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):        #❸(以下两行)如果切分出的数据集很小则退出          return None, leafType(dataSet)    return bestIndex,bestValue  

上述程序清单中的第一个函数是regLeaf,它负责生成叶节点。当chooseBestSplit函数确定不再对数据进行切分时,将调用该regLeaf函数来得到叶节点的模型。在回归树中,该模型其实就是目标变量的均值。

第二个函数是误差估计函数regErr。该函数在给定数据上计算目标变量的平方误差。当然也可以先计算出均值,然后计算每个差值再平方。但这里直接调用均方差函数var更加方便。因为这里需要返回的是总方差,所以要用均方差乘以数据集中样本的个数。

第三个函数是chooseBestSplit,它是回归树构建的核心函数。该函数的目的是找到数据的最佳二元切分方式。如果找不到一个“好”的二元切分,该函数返回 None并同时调用createTree方法来产生叶节点,叶节点的值也将返回None。接下来将会看到,在函数chooseBestSplit中有三种情况不会切分,而是直接创建叶节点。如果找到了一个“好”的切分方式,则返回特征编号和切分特征值。

函数chooseBestSplit一开始为ops设定了tolStolN这两个值。它们是用户指定的参数,用于控制函数的停止时机。其中变量tolS是容许的误差下降值,tolN是切分的最少样本数。接下来通过对当前所有目标变量建立一个集合,函数chooseBestSplit会统计不同剩余特征值的数目。如果该数目为1,那么就不需要再切分而直接返回❶。然后函数计算了当前数据集的大小和误差。该误差S将用于与新切分误差进行对比,来检查新切分能否降低误差。下面很快就会看到这一点。

这样,用于找到最佳切分的几个变量就被建立和初始化了。下面就将在所有可能的特征及其可能取值上遍历,找到最佳的切分方式。最佳切分也就是使得切分后能达到最低误差的切分。如果切分数据集后效果提升不够大,那么就不应进行切分操作而直接创建叶节点❷。另外还需要检查两个切分后的子集大小,如果某个子集的大小小于用户定义的参数tolN,那么也不应切分。最后,如果这些提前终止条件都不满足,那么就返回切分特征和特征值❸。

9.3.2 运行代码

下面在一些数据上看看上节代码的实际效果,以图9-1的数据为例,我们的目标是从该数据生成一棵回归树。

将程序清单9-2中的代码添加到regTree.py文件并保存,然后在Python提示符下输入:

>>>reload(regTrees)<module /'regTrees/' from /'regTrees.pyc/'>>>> from numpy import *  

图9-1的数据存储在文件ex00.txt中。

>>> myDat=regTrees.loadDataSet(/'ex00.txt/')>>> myMat = mat(myDat)>>> regTrees.createTree(myMat){/'spInd/': 0, /'spVal/': matrix([[ 0.48813]]),/'right/': -0.044650285714285733,/'left/': 1.018096767241379}  

图9-1 基于CART算法构建回归树的简单数据集

再看一个多次切分的例子,考虑图9-2的数据集。

图9-2 用于测试回归树的分段常数数据集

图9-2的数据保存在一个以tab键分隔的文本文档ex0.txt中数据。为从上述数据中构建一棵回归树,在Python提示符下敲入如下命令:

>>> myDat1=regTrees.loadDataSet(/'ex0.txt/')>>> myMat1=mat(myDat1)>>> regTrees.createTree(myMat1){/'spInd/': 1, /'spVal/': matrix([[ 0.39435]]), /'right/': {/'spInd/': 1, /'spVal/':matrix([[ 0.197834]]), /'right/': -0.023838155555555553, /'left/':1.0289583666666664}, /'left/': {/'spInd/': 1, /'spVal/': matrix([[ 0.582002]]),/'right/': 1.9800350714285717, /'left/': {/'spInd/': 1, /'spVal/': matrix([[0.797583]]), /'right/': 2.9836209534883724, /'left/': 3.9871632000000004}}}  

可以检查一下该树的结构以确保树中包含5个叶节点。读者也可以在更复杂的数据集上构建回归树并观察实验结果。

到现在为止,已经完成回归树的构建,但是需要某种措施来检查构建过程否得当。下面将介绍树剪枝(tree pruning)技术,它通过对决策树剪枝来达到更好的预测效果。