机器学习笔记(9)---决策树

时间:2023-02-13 07:41:09

决策树

决策这一节相对KNN算法来说难了点,因为本节需要先理解熵和信息增益的概念,理解后再看就比较容易了。不过我也是先看的代码,在看代码的过程中没明白它为什么要这么做,然后再去查相关的书籍,再把熵和信息增益的概念理解了,再去看代码,就明白了。

基本概念

基本概念不懂没关系,先去看源码。然后再回顾和总结。

香农熵(也叫信息熵)简称熵,其计算公式如下:

H = i = 1 n p ( x i ) l o g 2 p ( x i )

其中的 p ( x i ) 是该类别的概率。

熵越小纯度越高,熵越大越杂乱无章。比如左手一把盐,右手一瓶水,此时熵很小,但如果把盐倒到水里,那么此时熵就很大了。

决策树的根结点的熵是最大的,我们的目标就是进行分类,让节点的熵变成0,就表示节点都是同一类的了,需要关注的是在这个过程中使熵变小的最快的分类是最优分类,我们要做的就是找到这样的分类。

信息熵和信息增益

信息熵计算

例子1

以MLiA书上的数据,计算信息熵和信息增益的过程

机器学习笔记(9)---决策树

如上图的数据集,最终的类别只有两类:是鱼类和不是鱼类,分别占2/5和3/5。

按公式 H = i = 1 n p ( x i ) l o g 2 p ( x i ) (MLiA P35)可知,这里的n=2,即共2个类别, x 1 表示是鱼类; x 2 表示不是鱼类。 p ( x i ) 表示某一类的概率,所以:

p ( x 1 ) = 2 / 5 = 0.4 , p ( x 2 ) = 3 / 5 = 0.6

(1) H = i n p ( x i ) log 2 p ( x i ) (2) = ( 0.4 l o g 2 0.4 + 0.6 l o g 2 0.6 ) (3) ( 0.528 0.442 ) (4) 0.97

以上计算可通过python验证:

机器学习笔记(9)---决策树

例子2

在例子1的基础上,把数据增加一类,比如把第1条样本的结果分类为“可能”,那么一共就有三类:“是鱼类”、“不是鱼类”,“可能是鱼类”,这三者占比分别为1/5,1/5,3/5。

此时的信息熵为:

(1) H = i n p ( x i ) log 2 p ( x i ) (2) = ( 0.2 l o g 2 0.2 + 0.2 l o g 2 0.2 + 0.6 l o g 2 0.6 ) (3) ( 0.464 0.464 0.442 ) (4) 1.37

对比例子1,可发现类别越多,数据就越不纯,信息熵就越大;反之,类别越少,数据就越纯,信息熵就越小,当只有一个类别时,信息熵就是0。

信息增益

从信息熵的概念我们可以知道,熵越小表示纯度越高,直到熵为0时就表示某一类已经完全分好类了。而在每一次分类时,我们需要找到一个类别,以这个类别分类后的熵最小,就是我们想要的,当前熵最小也就是上一级熵减当前类别的熵最大,把这个差就叫信息增益,所以我们的目标就是找信息增益最大的即可。

还是以这个数据为例子:

机器学习笔记(9)---决策树

上一节已经计算出来当前熵H=0.97,那第一个分类到底是拿“不浮出水面是否可以生存”(后续简称第0列属性)这个属性去分类,还是拿“是否有脚蹼”(后续称第1列属性)去分类呢?这就需要遍历这两个属性并计算每个属性信息增益,找到信息增益最大的属性作为最优的分类属性。下面我们分别计算这两个属性的信息增益。

计算第0列属性信息增益

它有两个可能的取值:{是,否},使用该属性对样本进行划分,可得到2个子集,分别记为D1 (不浮出水面是否可以生存=是),D2(不浮出水面是否可以生存=否)。D1和D2分别占3/5和2/5。

子集D1包含编号为{1,2,3}的3个样例,其中正例(是鱼类)占2/3,反例占1/3;子集D2 包含编号为{4,5}的2个样例,都是反例。所以按照第0列属性划分之后获得的两个分支结点的信息熵为:

H ( D 1 ) = ( 2 3 l o g 2 2 3 + 1 3 l o g 2 1 3 ) = 0.918
H ( D 2 ) = ( 1 l o g 2 1 ) = 0

信息增益为:

(1) G a i n ( D , ) = H v = 1 2 | D v | D H ( D v ) (2) = 0.97 ( 3 5 0.918 + 2 5 0 ) (3) 0.419

D :当前样本集合.

D v :以某一属性进行划分,这个属性中的某类别样本就是 D v ,比如以“不浮出水面是否可以生存”来划分,这个属性值为“是”的 D v = 3 ,为“否”的 D v = 2 ,所以 | D v | D 就好理解了,前者就是3/5,后者就是2/5。

H ( D ) :表示样本的信息熵。西瓜书上记作 E n t ( D )

如果想了解的更清楚,可参考西瓜书4.2节划分选择。

计算第1列属性信息增益

​ 同样的。该列属性有两个可能的取值:{是,否},使用该属性对样本进行划分,可得到2个子集,分别记为 D 1 (是否有脚蹼=是), D 2 (是否有脚蹼=否)。 D 1 D 2 分别占4/5和1/5。

​ 子集 D 1 包含编号为{1,2,4,5}的4个样例,其中正例(是鱼类)占1/2,反例占1/2;子集 D 2 包含编号为{3}的1个样例,都是反例。所以按照第1列属性划分之后获得的两个分支结点的信息熵为:

H ( D 1 ) = ( 1 2 l o g 2 1 2 + 1 2 l o g 2 1 2 ) = 1.0
H ( D 2 ) = ( 1 l o g 2 1 ) = 0

信息增益为:

(1) G a i n ( D , ) = H v = 1 2 | D v | D H ( D v ) (2) = 0.97 ( 4 5 1.0 + 1 5 0 ) (3) = 0.17

因为 Gain(D,不浮出水面是否可以生存) >Gain(D,是否有脚蹼) ,所以以第0列来分类。

可结合着代码一起看,选择信息增益最好的代码为chooseBestFeatureToSplit()函数,另外一个例子可看西瓜书上4.2.1节的例子。

MLiA中的决策树代码

贴上带自己注释的完整代码,供参考。

#coding:utf-8

#原文地址:http://blog.csdn.net/rosetta
#python3.6
#author:sweird
#date:2018.2.5

from math import log
import operator
import matplotlib.pyplot as plt

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    #change to discrete values
    return dataSet, labels

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet: #the the number of unique elements and their occurance
        currentLabel = featVec[-1] #取最后一列值作为Lable
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1 #统计每一类的个数,比如'yes"类别2个,“no"类别3个
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries #计算该分类的概率,比如yes分类概率为2/5,no分类概率为3/5
        shannonEnt -= prob * log(prob,2) #log(prob,2),以2为底的log。这里计算香农熵。使用书本P35最底下的公式。
        #从这个公式里可以看出,如果只有一个类别,那么prob=1.0,熵就是0,也就是说只有一类那就不需要再分类了!
    return shannonEnt

def calcShannonEnt_test():
    myDat, labels = createDataSet()
    print(myDat)

    shannonEnt = calcShannonEnt(myDat)
    print(shannonEnt)#0.97

    #增加一个类别后看熵的变化
    myDat[0][-1]='maybe'
    print(myDat)
    shannonEnt2=calcShannonEnt(myDat)
    print(shannonEnt2)#1.37

    #只有一个类别的熵
    myDat[0][-1]='no'
    myDat[1][-1]='no'
    print(myDat)
    shannonEnt2=calcShannonEnt(myDat)
    print(shannonEnt2)#0.0

    #所以从上面的实验可知,熵越小纯度越高,熵越大越杂乱无章。

#dataSet = {list}[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
#axis用来取dataSet中每一个元素中第0列一样的
#value表示选择axis这一列的值是多少。
#比如axis=0,value=1,被选择的数据就是[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no']],返回除其自身以外的数据[[1, 'yes'], [1, 'yes'], [0, 'no']]
#如果axis=0,value=0,被选择的数据就是[[0, 1, 'no'], [0, 1, 'no']],返回[[1, 'no'], [1, 'no']]
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def splitDataSet_test():
    myDat,labels = createDataSet()
    print(myDat)

    retDataSet1 = splitDataSet(myDat,0,1)
    retDataSet2 = splitDataSet(myDat,0,0)
    print(retDataSet1)
    print(retDataSet2)

#这个可看西瓜书上4.2.1节的解释,涉及到熵的计算和信息增益的计算。
#返回最好的用于划分数据集的特征,0表示dataSet数据集的第0列(即是否可浮出水面),1表示第1列(即是否有脚蹼)
#该函数的主要作用:遍历样本的所有属性(带标签的),计算按照该属性分类后的信息增益,选择最大的信息增益所在的属性来分类。
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  #计算特征的个数,这里是2个(是否可浮出水面可以生存和是否有脚蹼)
    baseEntropy = calcShannonEnt(dataSet)#计算熵,值为0.97,这个在calcShannonEnt_test()中已经学过了。
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #iterate over all the features
        featList = [example[i] for example in dataSet]#取出每个样本第i个属性,放到featList中。
        #dataSet = {list}[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]#
        #如果i=0时
        #featList = [1, 1, 1, 0, 0]
        #如果i=1时
        #featList = [1, 1, 0, 1, 1]
        uniqueVals = set(featList) #去重后的放到uniqueVals set中,如uniqueVals = {0, 1}
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature                      #returns an integer

def chooseBestFeatureToSplit_test():
    myDat, lables = createDataSet()
    bestFeature = chooseBestFeatureToSplit(myDat)
    #疑问:信息熵是越大越好还是?
    #信息增益和信息熵的关系?

    print(myDat)
    print(bestFeature)

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

#这块代码看起来有点费劲。
#主要功能是给定实验数据和标签,创建一棵决策数。注意labels会被改写。
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]#stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

def createTree_test():
    myDat,labels=createDataSet()#注意,这里的labels不是结果类别,而是属性类别。
    myTree = createTree(myDat,labels)

    print(myTree)


def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1:{'feet':{0:'no', 1:'yes'}}}}}},
                  {'脐部':{"凹陷":'好瓜', "稍凹":{"根蒂":{"稍蜷":{"色泽":{"青绿":"好瓜", "乌黑":"好瓜", "浅白":"好瓜"}}, "蜷缩":"坏瓜", "硬挺":"好瓜"}}, "平坦":"坏瓜"}},
                  ]#第四条参考西瓜书P83 图4.7,自己写的数据,然后可以正常显示到图中。
    return listOfTrees[i]

#该函数就是决策树预测函数。
#输入已经创建好的决策树和属性类别标签,以及待预测样本的属性。
#输出该标本属于哪一类。
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

def classify_test():
    myDat,ori_labels = createDataSet()#书上使用了种方法,实际上只要labels就可以,myDat没有用。
    print("myDat", myDat)
    # labels = ['no surfacing', 'flippers']#可注释上述两句,放开这句也能达到效果。
    print("labels", ori_labels)

    labels = ori_labels.copy()
    # myTree = retrieveTree(0)
    # print("myTree 1", type(myTree),myTree)
    myTree=createTree(myDat,labels)#该函数会改变labels值,所以上面进行了一次拷贝,因为原始标签ori_labels要在classify()中使用。
    #myTree也可以使用retrieveTree()获取手动创建的决策树,用于测试。
    print("myTree 2", type(myTree),myTree)

    result = classify(myTree,ori_labels,[1,0])
    print(result)

def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'wb')#python3改成一定要用二进制存储,所以打开属性一定要有‘b'。
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')#同样读取时也要加'b’
    return pickle.load(fr)

def store_trees_test():
    filename = "myTreeStorage.txt"
    myDat,ori_labels = createDataSet()
    labels = ori_labels.copy()
    myTree=createTree(myDat,labels)

    print("myTree", myTree)
    storeTree(myTree,filename)

    result = grabTree(filename)
    print("result", result)

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, fontproperties="SimHei")
    #变量createPlot.ax1是调用plotNode函数的地方传进来的,python中的变量默认全局有效。
    #annototate在图形中增加带箭头的注释。可参考“pyplot_test.py”代码“pyplot的文本显示”一节。
    #详细可参考:https://matplotlib.org/api/_as_gen/matplotlib.pyplot.annotate.html?highlight=annotate#matplotlib.pyplot.annotate
    #第1个参数nodeTxt是要注释的内容
    #第2个参数xy=()被注释的地方
    #第4个参数xytext=()是插入文本的地方
    #fontproperties="SimHei",支持中文。

#使用递归遍历的方法,获取叶子个数。
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in list(secondDict.keys()):
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in list(secondDict.keys()):
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties="SimHei")

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in list(secondDict.keys()):
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

#lenses_test()用于对给定的隐形眼睛数据集(数据来源于UCI数据库,UCI数据库是加州大学欧文分校(University of CaliforniaIrvine)提出的用于机器学习的数据库,这个数据库目前共有335个数据集,其数目还在不断增加)创建决策树,并使用Matplot画出决策树,然后使用classify对给定的输入预测结果。
#有四种分类属性,分别是:age(年龄)、prescript(症状)、astigmatic(是否散光)、tearRate(眼泪数量)
#age(年龄)有三种值:young(年经的),pre(翻译成啥?), presbyopic(老花眼)
#prescript(症状):hyper(高度近视)和myope(普通近视)
#astigmatic(是否散光)
#tearRate(眼泪数量):normal(正常)和reduced(减少)
#预测的结果有三种:hard(硬材质)、soft(软材质)和no lenses(不适合佩戴隐形眼镜)
def lenses_test():
    fr = open("lenses.txt")
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    ori_lensesLabels = ['age','prescript','astigmatic','tearRate']#有四种分类属性,分别是:年龄、症状、是否散光、眼泪数量。

    lensesLabels = ori_lensesLabels.copy()#因为createTree()函数会修改lensesLabels,所以这里做一个拷贝。
    lensesTree = createTree(lenses,lensesLabels)

    print("lensesTree:",lensesTree)
    print("lensesLabels:",lensesLabels)

    createPlot(lensesTree)#使用matplotlib画图。
    #进行分类预测的时候,可肉眼看决策树,也可使用下面的classify()函数进行。

    result = classify(lensesTree,ori_lensesLabels,["presbyopic", "myope", "no", "reduced"])#输入一条实例的属性为presbyopic(年龄不知道翻译成啥。),myope(普通近视),no(不散光),reduced(眼泪减少),输出是否需要佩戴隐藏眼睛,以及隐形眼睛的材质。
    #结果有三类:hard(硬材质)、soft(软材质)和no lenses(不适合佩戴隐形眼镜)
    print(result)


if __name__ == '__main__':
    #3.1决策树构造
    #3.1.1信息增益。计算给定数据集的香农熵
    # calcShannonEnt_test()

    #3.1.2划分数据集
    #按照给定特征划分数据集
    # splitDataSet_test()
    #选择最好的数据集划分方式
    # chooseBestFeatureToSplit_test()

    #3.1.3递归构建决策树
    #其实到本节为止,整棵决策树已经画出来了,只是不太直观而已。
    # createTree_test()

    #3.2节使用mattplot把决策树直观的展示出来。可参见“treePlotter.py”

    #3.3测试和存储分类器
    #3.3.1使用决策数进行分类。
    # classify_test()
    #3.3.2决策树的存储。
    # 因为创建一棵决策树会很慢,所以可以先把决策树保存到硬盘上,在用到的时候读取出来。
    # store_trees_test()

    #3.4.使用决策树预测隐形眼镜类型。
    lenses_test()


决策树总结

学完以后才发现,决策树其实也是很简单的,它无非就是两个关键概念信息熵和信息增益,而实际上把信息熵的概念理解了,信息增益就好理解的。这两个概念的理解可看本大单节开始部分的内容。

MLiA第3章决策树内容,就是对给定的数据按类别遍历划分,然后计算出划分后的信息熵,再计算出信息增益,算出最好的分类类别,然后按此类别分类,最终构造一棵决策数。最后输入决策数、属性值(比如差别西瓜时有:挤部、根蒂和色泽)以及待预测的样本的属性值,返回该本的类别(比如是好瓜还是坏瓜)