决策树算法

时间:2024-01-31 11:59:48

1、决策树的工作原理

(1)找到划分数据的特征,作为决策点

(2)利用找到的特征对数据进行划分成n个数据子集。

(3)如果同一个子集中的数据属于同一类型就不再划分,如果不属于同一类型,继续利用特征进行划分。

(4)指导每一个子集的数据属于同一类型停止划分。

2、决策树的优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关的特征数据

                     缺点:可能产生过度匹配的问题

                     适用数据类型:数值型(需要进行离散化)和标称型

ps:产生过度匹配的原因:

过度匹配的现象:一个假设在训练数据上能够获得比其他假设更好的拟合,但是在训练数据外的数据集上却不能很好的拟合数据。此时我们就叫这个假设出现了overfitting的现象。

原因:在决策树模型搭建中,我们使用的算法对于决策树的生长没有合理的限制和修剪的话,决策树的*生长有可能每片叶子里只包含单纯的事件数据或非事件数据,可以想象,这种决策树当然可以完美匹配(拟合)训练数据,但是一旦应用到新的业务真实数据时,效果是一塌糊涂。

ps:数值型和标称型数据

标称型:标称型目标变量的结果只在有限目标集中取值,如真与假(标称型目标变量主要用于分类)
数值型:数值型目标变量则可以从无限的数值集合中取值,如0.100,42.001等 (数值型目标变量主要用于回归分析)

3、决策树创建分支createBranch()的伪代码实现

检测数据集中的每个子项是否属于同一分类

if so  return 类标签

else

寻找划分数据集的最好特征

划分数据集

创建分支节点

for 每个划分的子集

调用createBranch并增加返回结果到分支节点中(递归调用)

return 分支节点

4、决策树的目标就是将散乱的数据进行划分成有序的数据,那么这个划分前后信息的变化就是信息增益,也就是信息熵

那么对于每个类别分类前后都有相应的信息增益,所以就要计算所有类别的信息期望值


(n表示分类的数目)
下面用具体的代码实现平均信息熵的计算过程:
from math import log
import operator
#计算所有已经分类好的子集的信息商
def calcShannonEnt(dataSet):
    #计算给定的数据集的长度,也就是类别的数目
    numEntries = len(dataSet)
    #创建一个空字典
    labelCounts = {}
    #遍历所有分类的子集
    for featVec in dataSet: #the the number of unique elements and their occurance
        #取出每个子集的键值,也就是对应的类标签,-1的索引值表示最后一个
        currentLabel = featVec[-1]
        #如果当前类标签不在标签库中,就将当前子集的标签加入到标签库中
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        #如果已经存在于标签库中,就将标签库对应的加1
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    #遍历标签库中所有的标签
    for key in labelCounts:
        #根据键值对取出每个类别的次数/总的类别数,也就是每个类别的概率p(i)
        prob = float(labelCounts[key])/numEntries
        #计算所有类别的期望值,也就是平均信息熵
        shannonEnt -= prob * log(prob,2) #log base 2
    return shannonEnt

下面创建一个已经分类好的数据集
#创建一个数据集
def createDataSet():
    dataSet = [[1, 1, \'yes\'],
               [1, 1, \'yes\'],
               [1, 0, \'no\'],
               [0, 1, \'no\'],
               [0, 1, \'no\']]
    labels = [\'no surfacing\',\'flippers\']
    #change to discrete values
    return dataSet, labels

再命令行里面进行测试,计算这些数据集的平均信息熵
>>> import tree
>>> myDat,labels=tree.createDataSet()
>>> myDat
[[1, 1, \'yes\'], [1, 1, \'yes\'], [1, 0, \'no\'], [0, 1, \'no\'], [0, 1, \'no\']]
>>> labels
[\'no surfacing\', \'flippers\']
>>> tree.calcShannonEnt()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: calcShannonEnt() takes exactly 1 argument (0 given)
>>> tree.calcShannonEnt(myDat)
0.9709505944546686
>>>
熵越高,则混合的数据也越多,我们可以在数据集中添加更多的分类,观察熵是如何变化的。
这里我们就不尝试了,到这里我们已经学会了如何计算数据集的无序程度

这里我们先对矩阵的基本知识进行一个补充
d=[]
a = [[1, 1, \'yes\'],
     [1, 1, \'yes\'],
     [1, 0, \'no\'],
     [0, 1, \'no\'],
     [0, 1, \'no\']]
for i in a:
    # print i[:2]
    # print "---"
    # print i[:]#所有的元素
    # print i[0:]#所有的元素
    # print i[1:]#第二列之后所有的 取值为1 ,。。。。
    # print i[2:]#第三列之后所有的 取值为2,。。。。
    # print i[:1]#第一列 取值为 0 
    # print i[:2]#前两列 取值为0,1
    print i[:3]#前三列  取值为0,1,2
要知道矩阵里面的[:]是左闭右开区间

下面先看一个简单的例子
a = [[1, 1, \'yes\'],
     [1, 1, \'yes\'],
     [1, 0, \'no\'],
     [0, 1, \'no\'],
     [0, 1, \'no\']]
for i in a:
    if i[1] == 0:
        b = i[:1]
        print b
        print "=="
        b.extend(i[2:])
        print b
        print "==="
        d.append(b)
        print d
[1]
==
[1, \'no\']
===
[[1, \'no\']]
相当于判断每个元素的第二个字符是否等于0,如果等于,则将这个元素的剩下的字符 组成新的列表矩阵
下面给出按照给定的标准划分数据集的代码
看了上面的简单的例子相信下面的函数应该很容易懂了
#按照给定的标准对数据进行划分
#dataSet:待划分的数据集
#axis:划分的标准
#value:返回值
def splitDataSet(dataSet, axis, value):
# python吾言在函数中传递的是列表的引用,在函数内部对列表对象的修改。
# 将会影响该列表对象的整个生存周期。为了消除这个不良
# 影响,我们需要在函数的开始声明一个新列表对象
# 因为该函数代码在同一数据集上被调用多次,
# 为了不修改原始数据集,创建一个新的列表对象0
    retDataSet = []
    #遍历数据集中的每个元素,这里的每个元素也是一个列表
    for featVec in dataSet:
        #如果满足分类标准
        #axis=0 value=1 如果每个元素的第一个字符为1
        #axis=0 value=0 如果每个元素的第一个字符为0
        if featVec[axis] == value:
            #取出前axis列的数据
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            # list.append(object)向列表中添加一个对象object
            # list.extend(sequence)把一个序列seq的内容添加到列表中
            #把featVec元素的axis+1列后面的数据取出来添加到reducedFeatVec
            reducedFeatVec.extend(featVec[axis+1:])
            #将reducedFeatVec作为一个对象添加到retDataSet
            retDataSet.append(reducedFeatVec)
    return retDataSet


下面我们在命令行里面进行测试
>>> import tree
>>> myDat,labels = tree.createDataSet()
>>> myDat
[[1, 1, \'yes\'], [1, 1, \'yes\'], [1, 0, \'no\'], [0, 1, \'no\'], [0, 1, \'no\']]
>>> tree.splitDataSet(myDat,0,1)
[[1, \'yes\'], [1, \'yes\'], [0, \'no\']]
>>> tree.splitDataSet(myDat,0,0)
[[1, \'no\'], [1, \'no\']]
>>> tree.splitDataSet(myDat,1,0)
[[1, \'no\']]
>>> tree.splitDataSet(myDat,1,0)

可以看出我们的分类标准不一样,最终的结果也就不一样,这一步是根据我们的标准选出符合我们确定的标准的数据

现在我们可以循环计算分类后的香农熵以及splitDataSet()函数来寻找最好的分类标准

#计算所有已经分类好的子集的信息商
def calcShannonEnt(dataSet):
    #计算给定的数据集的长度,也就是类别的数目
    numEntries = len(dataSet)
    #创建一个空字典
    labelCounts = {}
    #遍历所有分类的子集
    for featVec in dataSet: #the the number of unique elements and their occurance
        #取出每个子集的键值,也就是对应的类标签,-1的索引值表示最后一个
        currentLabel = featVec[-1]
        #如果当前类标签不在标签库中,就将当前子集的标签加入到标签库中
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        #如果已经存在于标签库中,就将标签库对应的加1
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    #遍历标签库中所有的标签
    for key in labelCounts:
        #根据键值对取出每个类别的次数/总的类别数,也就是每个类别的概率p(i)
        prob = float(labelCounts[key])/numEntries
        #计算所有类别的期望值,也就是平均信息熵
        shannonEnt -= prob * log(prob,2) #log base 2
    return shannonEnt

#寻找最好的分类标准
def chooseBestFeatureToSplit(dataSet):
    #计算原始数据的特征属性的个数=len-标签列
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    #计算原始数据的原始熵
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    #遍历每个特征属性,遍历每一列
    for i in range(numFeatures):        #iterate over all the features
        #遍历数据集中每行除去最后一列标签的每个元素
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        #取出每行的特征属性存入set集合
        uniqueVals = set(featList)       #get a set of unique values
        newEntropy = 0.0
        #遍历每个特征属性
        for value in uniqueVals:
            #dataset:带分类的数据集 i:分类标准  value:返回值
            #取出符合分类标准的数据
            subDataSet = splitDataSet(dataSet, i, value)
            #符合分类标准的数据长度/数据集的总长度=p(i)
            prob = len(subDataSet)/float(len(dataSet))
            #计算分类之后平均信息熵---》香农熵
            newEntropy += prob * calcShannonEnt(subDataSet)
            #分类前后的香农熵之差
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
       #如果差大于最好的时候的差
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            #则继续分类,注意的是香农熵越大,越无序,所以我们要找的是差值最大的时候,也就是分类最有序的时候
            #将前一次的赋值给bestInfoGain,下一次和前一次比较,如果下一次的差小于前者则停止
            bestInfoGain = infoGain         #if better than current best, set to best
            #将当前的分类标准i赋值给bestFeature
            bestFeature = i
    return bestFeature        

我们在命令行里面测试
>>> reload(tree)
<module \'tree\' from \'E:\python2.7\tree.py\'>
>>> myDat,labels=trees.createDataSet()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name \'trees\' is not defined
>>> myDat,labels=tree.createDataSet()
>>> myDat
[[1, 1, \'yes\'], [1, 1, \'yes\'], [1, 0, \'no\'], [0, 1, \'no\'], [0, 1, \'no\']]
>>> tree.chooseBestFeatureToSplit(myDat)
0
>>>

我们通过chooseBestFeatureToSplit()得到第0个特征是最好的分类标准
subDataSet = splitDataSet(dataSet, i, value) 也就是这里的i就是选取的分类标准 value就是第i个特征值的值

找到了最好的分类标准,那么我们还需要找到次数最多的类别名称
下面我们选出次数最多的类别名称
#选出次数最多的类别名称
def majorityCnt(classList):
    #创建一个空字典,用于统计每个类别出现的次数
    classCount={}
    #遍历所有分类的子集中类别的出现的次数
    for vote in classList:
        #如果子集中没有该类别标签,则将该类别添加到字典classCount中
        if vote not in classCount.keys():
            classCount[vote] = 0
        #否则将该字典里面的标签的次数加1
        classCount[vote] += 1
        #iteritems()返回一个迭代器,工作效率高,不需要额外的内存
        #items()返回字典的所有项,以列表的形式返回
        #这里通过迭代返回每个类别出现的次数
        #key=operator.itemgetter(1)获取每个迭代器的第二个域的值,也就是次数,按照 类别出现的次数降序排列
        #reverse是一个bool变量,表示升序还是降序排列,默认为false(升序排列),定义为True时表示降序排列
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    #取出类别最高的类别以及对应的次数
    return sortedClassCount[0][0]
下面开始创建树
#创建决策树
#dataset:数据集     labels:包含了数据集中的所有标签
def createTree(dataSet, labels):
    #创建一个classList的列表,取出每个元素的最后一列:标签类[\'yes\', \'yes\', \'no\', \'no\', \'no\']
    classList = [example[-1] for example in dataSet]
    \'\'\'
    a = [[1, 1, \'yes\'],
     [1, 1, \'yes\'],
     [1, 0, \'no\'],
     [0, 1, \'no\'],
     [0, 1, \'no\']]
    classList = [example[-1] for example in a]
    print classList
    print classList[0]
    print classList.count(classList[0])
    print len(classList)
    # [\'yes\', \'yes\', \'no\', \'no\', \'no\']
    # yes
    # 2
    # 5
    \'\'\'
    #统计标签列表中的第一个元素的个数是否等于标签列表的长度
    #相等就意味着所有的元素属于同一个类别,那么就可以不再划分,这是最简单的情况
    if classList.count(classList[0]) == len(classList):
        #如果相等就返回标签列表的第一个元素
        return classList[0]  # stop splitting when all of the classes are equal
    #或者数据集的第一个元素的长度等于1,表示该元素只有一个特征值,同样停止划分
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        #返回次数最多的类别的名称
        return majorityCnt(classList)
    #在数据集中寻找最好的分类标准:最鲜明的特征属性
    #chooseBestFeatureToSplit()函数返回的是一个整数,表示第几个特征
    bestFeat = chooseBestFeatureToSplit(dataSet)
    #从标签库中将该特征选出来
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    #从标签列表中删除该特征属性
    del (labels[bestFeat])
    #选出数据集中每行的第bestFeat个元素组成一个列表
    #取出第bestFeati列元素
    featValues = [example[bestFeat] for example in dataSet]
    #从列表中创建一个不重复的集合
    uniqueVals = set(featValues)
    #遍历这些不重复的特征值
    for value in uniqueVals:
        #复制所有的标签,这里创建一个新的标签是为了防止函数调用createtree()时改变原始列表的内容
        subLabels = labels[:]  # copy all of labels, so trees don\'t mess up existing labels
        #递归调用创建决策树的函数
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

命令行:
>>> import tree
>>> myDat,labels =tree.createDataSet()
>>> myTree
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name \'myTree\' is not defined
>>> myTree=tree.createTree(myDat,labels)
>>> myTree
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
>>>

下面利用matplolib绘制树形图

import matplotlib.pyplot as plt

#决策点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
#子节点
leafNode = dict(boxstyle="round4", fc="0.8")
#箭头
arrow_args = dict(arrowstyle="<-")

#使用文本注解绘制树节点
#            节点注释  当前子节点   父节点    节点类型
def plotNode(nodeTxt, centerPt,  parentPt, nodeType):

    #nodeTxt:注释的内容,xy:设置箭头尖的坐标 ,被注释的地方(x,y)
    # xytext:xytext设置注释内容显示的起始位置,文字注释的地方,
    #arrowprops用来设置箭头
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords=\'axes fraction\',
                            xytext=centerPt, textcoords=\'axes fraction\',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
\'\'\'
annotate
# 添加注释  
# 第一个参数是注释的内容  
# xy设置箭头尖的坐标  
# xytext设置注释内容显示的起始位置  
# arrowprops 用来设置箭头  
# facecolor 设置箭头的颜色  
# headlength 箭头的头的长度  
# headwidth 箭头的宽度  
# width 箭身的宽度  
\'\'\'

def createPlot():
    #创建一个白色画布
   fig = plt.figure(1, facecolor=\'white\')
    #清除画布
   fig.clf()
   #在画布上创建1行1列的图形
   createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
   plotNode(\'a decision node\', (0.5, 0.1), (0.1, 0.5), decisionNode)
   plotNode(\'a leaf node\', (0.8, 0.1), (0.3, 0.8), leafNode)
   plt.show()

在命令行里面测试
>>> from imp import reload
>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> treeplot.createPlot()
>>>

效果图如下







想要构造一颗完整的注解树,还需要知道树的x坐标和y坐标,对应的x坐标即为子节点的个数,y坐标即为树的深度

#获取叶节点数目
def getNumLeafs(myTree):
    #初始化叶节点的值为0
    numLeafs = 0
    #取出树的第一个关键字{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
    #myTree.keys()[0]=no surfacing
    #myTree[firstStr]={0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}=secondDict
    #secondDict.keys()=[0,1]
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        #secondDict[1]是一个字典 如果子节点为字典类型
        if type(secondDict[key]).__name__==\'dict\':#test to see if the nodes are dictonaires, if not they are leaf nodes
            #递归调用getNumLeafs(myTree)
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

#获取树的层数
def getTreeDepth(myTree):
    #初始化树的最大深度为0
    maxDepth = 0
    # #取出树的第一个关键字{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
    #myTree.keys()[0]=no surfacing
    #myTree[firstStr]={0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}=secondDict
    #secondDict.keys()=[0,1]
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # secondDict[1]是一个字典 如果子节点为字典
        if type(secondDict[key]).__name__==\'dict\':#test to see if the nodes are dictonaires, if not they are leaf nodes
            #将深度加1之后继续递归调用函数getTreeDepth()
            thisDepth = 1 + getTreeDepth(secondDict[key])
            #如果是叶子节点,深度为1
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
命令行:
>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> treeplot.retrieveTree(0)
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
>>> treeplot.getNumLeafs(myTree)
3(树的子节点数,相当于树的宽度)
>>> treeplot.getTreeDepth(myTree)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "E:\Python36\treeplot.py", line 77, in getTreeDepth
    firstStr = myTree.keys()[0]
TypeError: \'dict_keys\' object does not support indexing(要注意的是python3不能直接解析字典的keys列表,需要手动将keys转为list列表)
>>> reload(myTree)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "E:\Python36\lib\imp.py", line 315, in reload
    return importlib.reload(module)
  File "E:\Python36\lib\importlib\__init__.py", line 139, in reload
    raise TypeError("reload() argument must be a module")
TypeError: reload() argument must be a module
>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> treeplot.getTreeDepth
<function getTreeDepth at 0x0000020A95A799D8>
>>> treeplot.getTreeDepth(myTree)
2(树的深度)
>>>

利用递归画出整个树
#在父子节点中添加文本信息
def plotMidText(cntrPt, parentPt, txtString):
    #父节点和子节点的中点坐标
    #parentPt[0]:父节点的x坐标  cntrPt[0]:左孩子节点的x坐标
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

#画树
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    #获取所有的叶子节点的个数,决定了x轴的宽度
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    #获取树的深度,决定了y轴的高度
    depth = getTreeDepth(myTree)
    # #取出树的第一个关键字{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
    # myTree.keys()[0]=no surfacing
    #取出第一个关键字,作为第一个节点的文本注释
    firstStr = myTree.keys()[0]     #the text label for this node should be this
    #==============参考博客地址:https://www.cnblogs.com/fantasy01/p/4595902.html========================#
    #plotTree.xOff即为最近绘制的一个叶子节点的x坐标
    #plotTree.yOff 最近绘制的一个叶子节点的y的y坐标
    #在确定当前节点位置时每次只需确定当前节点有几个叶子节点,
    # 因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW*1(因为总长度为1),
    #因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW*1,
    # 但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW*1,
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    #在当前节点和父节点之间添加文本信息
    plotMidText(cntrPt, parentPt, nodeTxt)
    #使用文本注解绘制节点
    #            节点注释  当前子节点   父节点    节点类型
   # def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    #myTree[firstStr]={0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}=secondDict
    secondDict = myTree[firstStr]
    #当前节点y坐标的偏移,绘制一层就减少树的1/深度
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    #遍历第二层的节点
    for key in secondDict.keys():
        #如果该节点是一个字典
        if type(secondDict[key]).__name__==\'dict\':#test to see if the nodes are dictonaires, if not they are leaf nodes
            #以该节点继续画子节点,
            #def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it\'s a leaf node print the leaf node
            #如果该节点不是一个字典,那就是一个子节点
            #计算当前节点的x坐标的偏移
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            #def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            #plotMidText(cntrPt, parentPt, txtString):
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it\'s a tree, and the first element will be another dict
#创建绘图区
def createPlot(inTree):
    fig = plt.figure(1, facecolor=\'white\')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    #计算树形图的全局尺寸
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW;
    plotTree.yOff = 1.0;
    #递归调用plotTree函数
    plotTree(inTree, (0.5,1.0), \'\')
    plt.show()



命令行测试:




>>> reload(treeplot)
<module \'treeplot\' from \'E:\\Python36\\treeplot.py\'>
>>> myTree=treeplot.retrieveTree(0)
>>> treeplot.createPlot(myTree)




为此我们已经可以构建一个决策树并且可以将我们构建的决策树通过matplotlib的库函数画出来
那么下面我们将利用他来进行实际的分类

#{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
#使用决策树的分类函数
\'\'\'#>>> tree.classify(myTree,labels,[1,0])
\'no\'
>>> tree.classify(myTree,labels,[1,1])
\'yes\'\'\'
def classify(inputTree,featLabels,testVec):
    #找到第一个特征值\'no surfacing\'
    firstStr = list(inputTree.keys())[0]
    #取出第一个特征值的值value作为第二个字典树
    secondDict = inputTree[firstStr]
    #在所有的标签列表中找到第一个特征值\'no surfacing\'对应特征的名称的下标
    #featLabels[\'no surfacing\', \'flippers\']
    #featLabels.index(\'no surfacing\')=0
    #featIndex=0
    featIndex = featLabels.index(firstStr)
    #在测试集中找到该下标对应的特征属性
    #testVec[0]=1
    key = testVec[featIndex]
    # {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}中找到对应的value
    #在字典中根据这个属性扎到对应的值
    #secondDict[1]={\'flippers\': {0: \'no\', 1: \'yes\'}}
    valueOfFeat = secondDict[key]
    #如果该值为一个字典
    if isinstance(valueOfFeat, dict):
        #继续执行分类函数(递归分类)
        classLabel = classify(valueOfFeat, featLabels, testVec)
        #否则将该值复制给分类标签标签,返回当前的分类标签
    else: classLabel = valueOfFeat
    return classLabel





命令行测试:
>>> reload(tree)
<module \'tree\' from \'E:\\Python36\\tree.py\'>
>>> dataSet,labels = tree.createDataSet()
>>> labels (分类的标签列表)
[\'no surfacing\', \'flippers\']
>>> myTree = treeplot.retrieveTree(0)
>>> myTree(构建的决策树)
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}}}
>>> tree.classify(myTree,labels,[1,0])
\'no\'
>>> tree.classify(myTree,labels,[1,1])
\'yes\'
>>>

构建好一颗可以进行分类的决策树之后需要进行序列化,存储在硬盘上,可以根据需要随便调用任一个对象‘

#存储决策树到硬盘
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, \'w\')
    pickle.dump(inputTree, fw)
    fw.close()

#从硬盘加载决策树
def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)
命令行运行:
>>> reload(tree)
<module \'tree\' from \'E:\\Python36\\tree.py\'>
>>> fr = open(\'E:\Python36\lenses.txt\')
>>> lenses = [inst.strip().split(\'\t\') for inst in fr.readlines()](读取每行,并且根据tab符号进行分隔)
>>> lensesLabels =[\'age\',\'prescript\',\'astigmatic\',\'tearRate\']
>>> lensesTree = tree.createTree(lenses,lensesLabels)
>>> lensesTree
{\'tearRate\': {\'reduced\': \'no lenses\', \'normal\': {\'astigmatic\': {\'yes\': {\'prescript\': {\'myope\': \'hard\', \'hyper\': {\'age\': {\'pre\': \'no lenses\', \'young\': \'hard\', \'presbyopic\': \'no lenses\'}}}}, \'no\': {\'age\': {\'pre\': \'soft\', \'young\': \'soft\', \'presbyopic\': {\'prescript\': {\'myope\': \'no lenses\', \'hyper\': \'soft\'}}}}}}}}
>>> treeplot.createPlot(lensesTree)

最后画出决策树