代码注释:机器学习实战第3章 决策树

时间:2022-12-20 11:47:36

写在开头的话:在学习《机器学习实战》的过程中发现书中很多代码并没有注释,这对新入门的同学是一个挑战,特此贴出我对代码做出的注释,仅供参考,欢迎指正。

1、trees.py

#coding:gbk
from math import log
import operator


#作用:建立数据集
#输出:数据集,标签名称
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


#作用:计算香农熵
#输入:数据集列表,最后一列为类
#输出:数据集香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)#数据集中实例总数
    labelCounts = {}#创建字典,表示类标签出现次数
    for featVec in dataSet:
        currentLabel = featVec[-1]#最后一列,即类标签
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0#香农熵
    for key in labelCounts:#对每个类标签来说
        prob = float(labelCounts[key]) / numEntries#标签出现概率
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


#作用:按照给定特征划分数据集,去除axis对应列特征值等于value的值
#输入:待划分的数据集,划分数据集的特征即列数,需要返回的特征的值
#输入:划分后的数据集
def splitDataSet(dataSet, axis, value):
    retDataSet = []#返回列表
    for featVec in dataSet:#对数据集中每一行
        if featVec[axis] == value:#如果相等
            reducedFeatVec = featVec[:axis]#该行和下一行的作用是得到去除featVec[axis]的列表
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)#将去除featVec[axis]的列表添加到返回列表中
    return retDataSet


#作用:得到最好的数据集划分方式
#输入:数据集列表,最后一列为类
#输出:最好的数据集划分方式对应的特征值
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1#dataSet特征数,-1表示最后一列为类别标签
    baseEntropy = calcShannonEnt(dataSet)#dataSet的香农熵
    bestInfoGain = 0.0;#最大信息熵
    bestFeature = -1;#最佳特征值
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]#列表推导式,找到第i个特征对应的属性值,注意是列表,会有多个相同的属性值
        uniqueVals = set(featList)#转换为集合,消除相同的属性值,集合里只能存在不相同的属性值
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))#注意float,不能两个int值相除,只能得int值
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):#如果新特征值拥有更大的信息熵
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


#作用:返回出现最多的分类名称
#输入:分类名称的列表
#输出:出现最多的分类名称
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1#出现频率加1
    sortedClassCount = sorted(classCount.iteritems,#iteritems()表示将classCount以一个迭代器对象返回
                              key = operator.itemgetter(1), reverse = true)#operator.itemgetter(1)表示第2维数据即值,reverse = True表示从大大小排列
    return sortedClassCount


#作用:创建树
#输入:数据集,标签名称
#输出:树的字典形式
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]#列表推导式,得类别标签列表
    #类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):#classList.count(classList[0])表示将计算第一个类别出现的次数
        return classList[0]
    #遍历完所有特征时返回出现次数最多的类别
    #该程序用到了递归,此为递归退出条件
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)#最好的数据集划分方式对应的特征值
    bestFeatLabel = labels[bestFeat]#最好的数据集划分方式对应的特征值对应的类别名称
    myTree = {bestFeatLabel:{}}#建立字典,键为根节点类别名称
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]#列表推导式,得特征值对应值的列表
    uniqueVals = set(featValues)#建立集合
    for value in uniqueVals:
        subLabels = labels[:]#使用新变量代替原始列表
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)#创建子节点
    return myTree


#作用:使用决策树的分类函数
#输入:树的字典形式,分类标签,待分类矢量
#输出:分类标签
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)#得firstStr在分类标签中的索引
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


#作用:储存决策树
#输入:树的字典形式,文件名字
#输出:无
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


#作用:提取决策树
#输入:文件名字
#输出:树的字典形式
def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

2、treePlotter.py

#coding:gbk
import matplotlib.pyplot as plt

#定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

#作用:绘制带箭头的注解
#输入:注解,注解坐标,起点坐标,注解类型
#输出:无
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', xytext = centerPt, textcoords = 'axes fraction',
va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

#作用:绘制图像
#输入:
#输出:无
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white')
fig.clf()
axprops = dict(xticks = [], yticks = [])
createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))#树的宽度
plotTree.totalD = float(getTreeDepth(inTree))#数的高度
plotTree.x0ff = -0.5/plotTree.totalW;#根节点x值?
plotTree.y0ff = 1.0;#根节点y值,为1.0表示放在最高点
plotTree(inTree, (0.5, 1.0), '')#绘制根节点,0.5表示在x方向的中间,1.0表示在y方向的最上面,''表示为根节点,不用标记子节点属性值
#plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()

#作用:获取叶节点的数目
#输入:树的字典形式
#输出:叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0#叶节点数目
firstStr = myTree.keys()[0]#根节点键
secondDict = myTree[firstStr]#根节点值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs

#作用:获取树的层数
#输入:树的字典形式
#输出:树的层数
def getTreeDepth(myTree):
maxDepth = 0#数的层数
firstStr = myTree.keys()[0]#根节点键
secondDict = myTree[firstStr]#根节点值
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth

#作用:输出树的字典形式
#输入:需要的数
#输出:树的字典形式
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]

#作用:在父子节点间填充文本信息
#输入:子节点位置,父节点位置,文本信息
#输出:无
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)#叶节点数目
depth = getTreeDepth(myTree)#树的层数
firstStr = myTree.keys()[0]#根节点键
cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 /plotTree.totalW, plotTree.y0ff)
plotMidText(cntrPt, parentPt, nodeTxt)#绘制文字
plotNode(firstStr, cntrPt, parentPt, decisionNode)#绘制根节点
secondDict = myTree[firstStr]
plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD