【机器学习原理】决策树从原理到实践-2.代码

时间:2024-05-01 13:24:09

下面是代码实现的部分,写了一个基于CART的分类树,使用的样本就是上面提到的贷款数据,数据如下图:
在这里插入图片描述
是一个.txt文档,运行后得到了分类的结果,最终分类的几个集合都只有一个类别,也就是根据这些分类规则,可以完全将数据分开。
在这里插入图片描述
完整代码

# 基于CART的决策分类树复现(离散)
import collections
import queue
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

def load_data():
    with open('static/data.txt', mode='r', encoding='utf-8') as f:
        data=f.read().split('\n')
    title=data[0].split(' ')
    x=[]
    y=[]
    for i in range(1,len(data)):
        xy = data[i].split(' ')
        x.append(xy[:-1])
        y.append(xy[-1]) # 最后一个是标签
    x = np.array(x)
    y = np.array(y)
    return title,x,y

class Node():
    def __init__(self,node_id=0,deep=0,id_list=None,nxt_list=None,split=True):
        self.node_id=node_id
        self.deep = deep
        if id_list is not None:
            self.id_list = np.array(id_list,dtype=int)  # 当前节点的索引集合
        else:
            self.id_list=[]
        if nxt_list is not None:
            self.nxt_list = np.array(nxt_list)  # 当前节点的索引集合
        else:
            self.nxt_list=np.array([])
        self.split=split # 是否需要继续分裂

class CART():
    def fit(self,x,y,gini_thresh=0.1):
        samples = x.shape[0]
        features = x.shape[1]
        root = Node(node_id=0,deep=1,id_list=np.arange(samples),nxt_list=[],split=True)
        # 先统计y的相关信息
        y_cag = collections.Counter(y)
        # print('标签统计信息:',y_cag)
        # label_list = list(y_cag.keys()) # y的所有类别
        tree = [root] # 存储最终的树
        q = queue.Queue() # 产生一个队列
        q.put(root)
        split_cnt = 0 # 记录分裂次数
        while not q.empty(): # 取出一个节点
            node = q.get() # 移除并返回数据
            id_list = node.id_list # 得到当前集合的所有id
            label_num = collections.Counter(y[id_list]) # 当前集合样本的所有对应标签的样本数
            num_all = id_list.size # 单管集合的所有数据
            min_gini = [0,None,0x3f3f3f3f,[],None] # 记录当前集合feature索引和特征名称(分裂信息),以及对应的gini指数,还有集合id
            for i in range(features): # 对于每个feature选择
                # 求出所有特征类别和对应的id
                feat_dict = {}
                # 这个地方有问题(不应该统计所有样本,而是当前对应的,应该可以在上面的id循环里面统计掉)
                for idx in id_list:
                    if x[idx,i] not in feat_dict.keys():
                        feat_dict[x[idx,i]]=[] # 当前id(索引)
                    feat_dict[x[idx, i]].append(idx)
                # 下面枚举将当前特征特征的每个取值作为分割点
                for type in feat_dict.keys(): # type作为分割点(统计分割点内的个样本匹配数量)
                    res = {}
                    for idx in feat_dict[type]:
                        if y[idx] not in res:
                            res[y[idx]]=0
                        res[y[idx]]+=1
                    # 根据统计出的数量已经可以计算基尼系数 gini=1-∑p^2
                    num = len(feat_dict[type]) # 得到数量(为是的)
                    gini_D1 = 1
                    for key in res.keys():
                        gini_D1-=(res[key]/num)**2
                    gini_D2 = 1
                    if num_all!=num:
                        for key in label_num.keys(): # 利用集合总数来推算为否的集合gini
                            sub = 0 # 要减去的样本(在集合D1的)
                            if key in res.keys():
                                sub = res[key]
                            gini_D2-=((label_num[key]-sub)/(num_all-num))**2
                    gini = (num/num_all)*gini_D1+((num_all-num)/num_all)*gini_D2
                    if gini<min_gini[2]:
                        min_gini[0]=i
                        min_gini[1] = type # 第i个特征的类别type
                        min_gini[2]=gini
                        min_gini[3] = feat_dict[type] # 记录id
                        min_gini[4]= (gini_D1,gini_D2) # 记录两个集合的gini决定是否继续分裂
            # 找到最小的gini进行分裂
            split_cnt+=1
            # print('总样本集:',id_list)
            print('第 %d 次分裂,根据第 %d 个特征的 %s 类别'%(split_cnt,min_gini[0],min_gini[1]))
            id_D1 = min_gini[3]
            id_D2 = []
            # print(id_list)
            for id in id_list:
                if id not in id_D1:
                    id_D2.append(id)
            # 生成两个节点
            id1 = len(tree)
            id2 = len(tree)+1 # 即将插入的两个节点的id(也就是在tree中的索引)
            tree[node.node_id].nxt_list=[id1,id2]
            node1 = Node(node_id=id1,deep=node.deep+1,id_list=id_D1,nxt_list=[])
            node2 = Node(node_id=id2,deep=node.deep+1,id_list=id_D2,nxt_list=[])
            # 判断是否需要继续分裂(纯度,纯度也就是如果都是一个类别为0就不分裂,还有个用阈值计算,懒得算了)
            if min_gini[4][0]<gini_thresh:
                node1.split=False # 无需分裂
            else:
                q.put(node1)
            if min_gini[4][1]<gini_thresh:
                node2.split=False # 无需分裂
            else:
                q.put(node2)
            tree.append(node1)
            tree.append(node2)
        # print(tree)
        self.tree  = tree

    def printTree(self):
        tree = self.tree
        print('----------- CART -----------')
        print('()中表示深度,根节点为1')
        for subtree in tree:
            if subtree.split:
                print('(%d)'%(subtree.deep),subtree.id_list, end='')
                print(' -> ',end='')
                node1 = tree[subtree.nxt_list[0]]
                print('(%d)'%(node1.deep),node1.id_list,end=' + ')
                node2 = tree[subtree.nxt_list[1]]
                print('(%d)'%(node2.deep),node2.id_list)
        print('----------------------------')

if __name__ == '__main__':
    title,x,y = load_data()
    print('********* 特征 *********')
    for i in range(len(title)):
        print(i+1,title[i])
    print('***********************')
    dct_cart = CART()
    dct_cart.fit(x,y)
    dct_cart.printTree()