Python实现决策树算法 C4.5和ID3算法

时间:2022-08-30 03:16:16

本文用 Python 实现了 C4.5 和 ID3 算法，默认使用 C4.5。若要改用 ID3，只需修改函数 entropy() 末尾的返回语句：注释掉标注为 C4.5 的那一行，并启用标注为 ID3 的那一行。

将源代码保存为 Python 文件并命名为 c45.py。命令行的最后一个参数是数据文件的路径，可自由设置（省略时默认读取 c45_data.txt）。运行方式参考如下：

python c45.py data.txt

特别感谢:

点击打开链接

源代码如下:

#!/usr/bin/python
# -*- coding: UTF-8 -*-
__author__ = 'Administrator'
######## C4.5 ID3 finished!! ######
################# (tm_year=2016, tm_mon=3, tm_mday=15, tm_hour=22, tm_min=56, tm_sec=56, tm_wday=1, tm_yday=75, tm_isdst=0) ################
import re
import math
import sys

# Nodes holding at most this many rows become leaves and are not split
# further, even if their rows are not all of one class.
mini_size = 1
# Fixed capacity used to size the per-attribute / per-node lists below.
DataLength = 100
used = [0] * DataLength    # per attribute: 1 once the attribute has been split on
ended = [0] * DataLength   # per node: whether the node will be split further
tp = [-1] * DataLength     # per node class label: 1 - yes, 0 - no (-1 initially)

class node:
    """One node of the decision tree.

    Nodes are kept in a flat list and refer to their parent by list index
    rather than by object reference.
    """

    def __init__(self):
        # value:  the attribute value that routes rows from the parent to this
        #         node ('' until one is assigned).
        # father: index of the parent node in the tree list (the script sets
        #         the root's father to -1).
        # com:    column index of the attribute the parent split on
        #         ("comes from which attribute").
        self.value, self.father, self.com = '', 0, 0
        # items: indices of the data rows that reach this node.
        self.items = set()

# Number of data rows (lg) and of columns per row (wd); the script entry
# point assigns the real values after reading the data file.
lg = 0
wd = 0
def entropy(dt,values,node,i):
    """Score attribute column ``i`` as a split for ``node``; lower is better.

    The rows of ``node`` are counted per value of attribute ``i`` as positive
    (class label, i.e. the row's last element, equal to ``'Yes'``) or
    negative.  From those counts, using natural logarithms:

    * ``s``  -- the conditional class entropy of the node given attribute i;
    * ``sp`` -- the split information, i.e. the entropy of attribute i's
      value distribution over the node's rows.

    C4.5 mode (default) returns ``s / sp``; ID3 mode returns ``s``.  To
    switch, comment out the C4.5 return line and enable the ID3 line.

    Args:
        dt: list of rows; each row is a list of strings ending in the label.
        values: list of sets; ``values[i]`` must contain every value that
            column ``i`` takes among the node's rows.
        node: object whose ``items`` set holds row indices into ``dt``.
        i: index of the attribute column to score.

    Returns:
        float.  In C4.5 mode, ``float('inf')`` when ``sp`` is zero (the node
        is empty, or all of its rows share one value of attribute ``i``).
        Such an attribute cannot split the node, and ``s / sp`` would
        otherwise raise ZeroDivisionError.
    """
    # Map each attribute value to its slot in the pos/neg count arrays
    # (a dict lookup instead of a linear list.index per row).
    slot = {v: k for k, v in enumerate(values[i])}
    pos = [0] * len(slot)
    neg = [0] * len(slot)
    for j in node.items:
        k = slot[dt[j][i]]
        if dt[j][-1] == 'Yes':
            pos[k] += 1
        else:
            neg[k] += 1

    total = float(sum(pos) + sum(neg))
    sp = 0.0
    s = 0.0
    for p, q in zip(pos, neg):
        cnt = p + q
        if cnt == 0:
            continue
        w = cnt / total  # share of the node's rows having this value
        sp -= w * math.log(w)
        if p == 0 or q == 0:
            continue  # pure branch: contributes zero class entropy
        fp = float(p) / cnt
        fq = float(q) / cnt
        s -= w * (fp * math.log(fp) + fq * math.log(fq))

    if sp == 0:
        # Attribute takes a single value here: it cannot split the node.
        return float('inf')
    return s/sp ### C4.5
    #return s ### ID3

def ens(dt,values,node):
    """Return the negated class entropy of ``node``: p*ln(p) + q*ln(q).

    ``p`` and ``q`` are the fractions of the node's rows whose class label
    (the last element of the row) is ``'Yes'`` and anything else,
    respectively.  The result is always <= 0; the caller negates it to
    obtain the entropy.

    Args:
        dt: list of rows; each row is a list of strings ending in the label.
        values: not used; accepted so existing call sites keep working.
        node: object whose ``items`` set holds row indices into ``dt``.

    Returns:
        float.  0.0 for an empty or pure node, whose entropy is zero,
        whichever class it holds.
    """
    ps = sum(1 for j in node.items if dt[j][-1] == 'Yes')
    ng = len(node.items) - ps
    if ps == 0 or ng == 0:
        return 0.0
    total = float(ps + ng)
    p = ps / total
    q = ng / total
    return p * math.log(p) + q * math.log(q)

def _load_rows(path):
    """Return the non-blank lines of ``path``, each split on whitespace.

    Each row is a list of attribute values whose last element is the class
    label.  Blank lines (e.g. a trailing newline) are skipped rather than
    becoming empty rows that later break the column indexing.
    """
    with open(path, "r") as fp:
        return [line.split() for line in fp if line.strip()]


def _count_classes(dt, items):
    """Return ``(ps, ng)``: rows in ``items`` labelled 'Yes' / anything else."""
    ps = sum(1 for j in items if dt[j][-1] == 'Yes')
    return ps, len(items) - ps


def _attributes_on_path(tree, k):
    """Return the attribute columns split on between the root and node ``k``."""
    on_path = set()
    while tree[k].father != -1:
        on_path.add(tree[k].com)
        k = tree[k].father
    return on_path


def _label_leaf(dt, tree, k):
    """Write the predicted class of leaf ``k`` to ``tp`` (1 - yes, 0 - no).

    A leaf predicts its majority class (ties go to 'yes').  A leaf that no
    training row reaches takes its parent's majority class instead.
    """
    items = tree[k].items
    if not items and tree[k].father != -1:
        items = tree[tree[k].father].items
    ps, ng = _count_classes(dt, items)
    tp[k] = 1 if ps >= ng else 0


def _build_tree(dt, values, wd):
    """Grow the decision tree breadth-first and return its flat node list.

    ``tree[0]`` is the root.  Every other node records:

    * ``father`` -- the index of its parent;
    * ``com`` -- the column its parent split on;
    * ``value`` -- the value of that column which leads to this node.

    Leaf labels are written to the module-level ``tp`` list, which is
    extended as the tree grows.  Internal nodes keep ``tp == -1``.
    """
    root = node()
    root.father = -1
    root.com = -1
    root.items = set(range(len(dt)))
    tree = [root]

    order = 0
    # The tree grows while it is being scanned, so nodes are visited in BFS order.
    while order < len(tree):
        if len(tp) < len(tree):
            tp.extend([-1] * (len(tree) - len(tp)))
        current = tree[order]
        ps, ng = _count_classes(dt, current.items)

        best = None
        if len(current.items) > mini_size and ps and ng:
            # An attribute is excluded only along its own root-to-node path,
            # so sibling subtrees may still split on it.
            on_path = _attributes_on_path(tree, order)
            candidates = [a for a in range(wd - 1) if a not in on_path]
            if candidates:
                scores = dict((a, entropy(dt, values, current, a)) for a in candidates)
                # min() keeps the first (lowest-index) attribute on ties.
                best = min(candidates, key=lambda a: scores[a])
                # Split only if the best score beats the node's class entropy.
                if scores[best] >= -ens(dt, values, current):
                    best = None

        if best is None:
            _label_leaf(dt, tree, order)
        else:
            # Sorted so that child order (and the printed tree) is deterministic.
            branch_values = sorted(values[best])
            slot = dict((v, len(tree) + k) for k, v in enumerate(branch_values))
            for v in branch_values:
                child = node()
                child.father = order
                child.com = best
                child.value = v
                tree.append(child)
            for j in current.items:
                tree[slot[dt[j][best]]].items.add(j)
        order += 1
    return tree


def _print_tree(tree):
    """Print each node's index, parent, split attribute, branch value, rows and label."""
    print("%d nodes in all" % len(tree))
    for i, nd in enumerate(tree):
        print("%d \tfather: %d \tattribute:  %d \tvalue: %s" % (i, nd.father, nd.com, nd.value))
        print("%s %d \n" % (sorted(nd.items), tp[i]))


if __name__ == '__main__':
    # Usage: python c45.py [data file]; the data file defaults to c45_data.txt.
    path = "c45_data.txt"
    if len(sys.argv) > 1:
        path = sys.argv[1]
    dt = _load_rows(path)
    if not dt:
        sys.exit("no data rows in %s" % path)
    lg = len(dt)
    wd = len(dt[0])
    if any(len(row) != wd for row in dt):
        sys.exit("every row of %s must have %d columns" % (path, wd))
    # values[j]: every distinct value that column j takes in the data.
    values = [set(row[j] for row in dt) for j in range(wd)]
    tree = _build_tree(dt, values, wd)
    _print_tree(tree)