Ah! 的博客

当然是选择原谅她

0%

[机器学习实战] 决策树

机器学习实战 系列参考于互联网资料与 人民邮电出版社 《机器学习实战》,编写目的在于学习交流,如有侵权,请联系删除

决策树概述

决策树

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据

缺点:可能会产生过度匹配的问题

适用数据类型:数值型和标称型

有这么一个游戏叫做二十个问题的游戏,游戏的规则很简单:参与游戏的一方在脑海里想某个事物,其他参与者向他提问题,只允许提20个问题,问题的答案也只能用对或错回答。问问题的人通过推断分解,逐步缩小代猜测事物的范围。决策树 的工作原理与 20 个问题类似,用户输入一系列数据,然后给出游戏的答案

决策树经常用作处理分类的问题,近来的调查研究表明决策树也是最经常使用的数据挖掘算法,它之所以这么流行,一个很重要的原因就是不需要了解机器学习的知识,就能搞明白决策树是如何工作的

如下的流程图就是一个决策树

在此图中 菱形代表 判断模块 (decision block),椭圆代表 终止模块 (terminating block),表示已经得出结论,可以终止运行。从判断模块引出的箭头叫做 分支 (branch),它可以到达另一个判断模块或者终止模块。

如上构造了一个假想的邮件分类系统,它首先检测发送邮件域名地址。如果地址是 myEmployer.com,则将其分类 “无聊时需要阅读的邮件” 中。如果邮件不是来自这个域名,则检查邮件内容里是否包含单词 曲棍球,如果包含则将邮件归类到 “需要及时处理的朋友邮件”,如果不包含则将邮件归类到 “无需阅读的垃圾邮件”

决策树的一个重要任务就是为了理解数据中蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列规则,这些机器根据数据集创建规则的过程就是机器学习的过程

之前的 k-近邻算法 可以完成很多分类任务,但最大的缺点就是无法给出数据的内在含义,决策树的优势就在于数据形式非常容易理解

关于决策树的画法

长方形代表 判断模块 (decision block),椭圆代表 终止模块 (terminating block),表示已经得出结论,可以终止运行。从判断模块引出的箭头叫做 分支 (branch)

—— 引用自 《机器学习实战》 一书

一个决策树包含三种类型的节点:

决策节点:通常用矩形框来表示

机会节点:通常用圆圈来表示

终结点:通常用三角形来表示

维基百科决策树示例

—— 引用自 https://zh.wikipedia.org/wiki/决策树#简介 维基百科

决策树的构造

首先需要了解数学上如何使用 信息论 划分数据集,然后编写代码将理论应用到具体的数据集上,最后编写代码构建决策树

在构建决策树时,需要解决的第一个问题就是,当前数据集那个特征在划分数据分类时起决定作用。为了找到决定性的特征,划分出最好的结果,必须评估每一个特征。完成测试之后,原始数据集就会被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则当前无需阅读的垃圾邮件已经正确的划分数据分类,无需进一步对数据集进行分割。如果数据子集内的数据不属于同一类型,则无需重复划分数据子集的过程。划分数据子集的算法和划分原始数据集的方法相同,直到所有具有相同类型的数据均在一个数据子集内

创建分支的伪代码函数 createBranch() 如下:

1
2
3
4
5
6
7
8
9
检测数据集中的每一个子项是否属于同一分类:
If so return 类标签
Else
寻找划分数据集的最好特征
划分数据集
创建分支点
for 每个划分的子集
调用函数 `createBrach` 并增加返回结果到分支节点中
return 分支节点

上面伪代码 createBrach 是一个递归函数,在倒数第二行直接调用了它自己。

决策树的一般流程

  • 收集数据
    • 可以使用任何方法
  • 准备数据
    • 树构造算法只适用于标称型数据,因此数值型数据必须离散化
  • 分析数据
    • 可以使用任何方法,构造树完成之后,需要检查图形是否符合预期
  • 训练算法
    • 构造书的数据结构
  • 测试算法
    • 使用经验树计算错误率
  • 使用算法
    • 此步骤可适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义

一些决策树算法采用二分法划分数据,如果依照某个属性划分数据将会产生4个可能的值。这里将会把数据划分成四块,并创建四个不同的分支,使用 ID3 算法划分数据集,该算法处理如何划分数据集,何时停止划分数据集 (wikiPedia ID3算法)。每次划分数据集时,只选取一个特征属性,如果训练集中存在20个特征,第一次选择哪个特征作为划分的参考属性呢

下表中包含 5 个海洋生物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。可以将这些动物分为两类:鱼类和非鱼类。现在需要决定依据第一个特征还是第二个特征划分数据。在回答这个问题之前,需要采用量化的方法判断如何划分数据

海洋生物数据

不浮出水面是否可以生存 是否有脚蹼 属于鱼类

信息增益

划分数据集的大原则是:将无序的数据变得更加有序。可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。组织杂乱无章的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。可以在划分数据之前或之后,使用信息论量化度量信息的内容。

在划分数据集之前之后信息发生的变化称为 信息增益,知道如何计算信息增益,就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高大的特征就是最好的选择

在可以评测哪种数据划分方式是最好大的数据划分之前,必须学习如何计算信息增益。集合信息的度量方式称为香浓熵或者简称为熵(shāng),这个名字来源于信息论之父克劳德·香农

如果不明白什么是 信息增益 (information gain)熵 (entropy), 请不要着急——它们从诞生的那一天起,就注定会让世人十分费解

熵定义为信息的期望值,在明晰这个概念之前,必须知道信息的定义。如果待分类的事务可能划分在多个分类之中,则符号 的信息定义为:

$l(\chi_i) = -\log_2 p(\chi_i)$

其中 是选择该分类的概率

为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下面公式得到:

$H = -\sum_{i=1}^np(\chi_i)\log_2 p(\chi_i)$

其中 是分类的数目

下面将使用 Python 计算信息熵,创建名为 trees.py 的文件,下面代码是计算给定的数据集熵

1
2
3
4
5
6
7
8
9
10
11
12
13
def calc_shannon_ent(data_set):
num_entries = len(data_set)
label_counts = {}
for featVec in data_set: # 为所有可能分类创建字典
current_label = featVec[-1]
if current_label not in label_counts.keys():
label_counts[current_label] = 0
label_counts[current_label] += 1
shannon_ent = 0.0
for key in label_counts:
prob = float(label_counts[key]) / num_entries
shannon_ent -= prob * log(prob, 2) # 以 2 为底求对数
return shannon_ent

首先计算数据集中实列的总数。也可以在需要时再计算这个值,但是由于代码中多次用到这个值,为了提高代码效率,显式地生命一个变量保存实列总数。然后创建一个字典,它的键值是最后一列的数值。如果当前键值不存在,则扩展字典并将当前键加入字典。每个键值都记录了当前类别出现的次数。最后,使用所有类标签的发生频率计算类别出现的概率,将用这个概率计算香浓熵,统计所有类标签发生的次数。下面使用熵划分数据集

trees.py 中,可以利用 create_data_set 函数所有的简单鱼鉴定数据集

1
2
3
4
5
6
7
8
9
def create_data_set():
data_set = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
# change to discrete values
return data_set, labels

在 Python 命令提示符下输入下列命令:

1
2
3
4
5
myDat, labels = trees.create_data_set()
myDat
Out[4]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
trees.calc_shannon_ent(myDat)
Out[5]: 0.9709505944546686

熵越高,则混合的数据也越多,可以在数据集中添加更多的分类,观察熵是如何变化的。这里增加第三个名为 maybe 的分类,测试熵的变化:

1
2
3
4
5
myDat[0][-1] = 'maybe'
myDat
Out[7]: [[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
trees.calc_shannon_ent(myDat)
Out[8]: 1.3709505944546687

得到熵之后,就可以按照获取最大信息增益的方法划分数据集,后面将介绍如何划分数据集以及如何度量信息增益

另一个度量集合无序程序的方法是 基尼不纯度 (Gini impurity),简单地说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。

划分数据

上面介绍了如何度量数据集的无序程度,分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以便判断当前是否正确地划分了数据集。随后将对每个特征划分数据集的结果计算一次信息熵,然后判断按照那个特征划分数据集是最好的划分方式。想象一个分布在二维空间的数据散点图,需要在数据之间划条线,将它们分成两部分,应该按照 轴还是 轴划线呢,下面将进行阐述

要划分数据集,打开文本编辑器,在 trees.py 中加入

1
2
3
4
5
6
7
8
def split_data_set(data_set, axis, value):
ret_data_set = []
for featVec in data_set:
if featVec[axis] == value:
reduced_feat_vec = featVec[:axis] # 抽取
reduced_feat_vec.extend(featVec[axis + 1:])
ret_data_set.append(reduced_feat_vec)
return ret_data_set

上述代码使用了三个输入参数:待划分的数据集,划分数据集的特征,需要返回的特征的值。需要注意的是,Python 不需要考虑内存分配的问题。 Python 语言在函数中传递的是列表的引用,在函数内部对列表对象的修改,将会影响该列表对象的整个生存周期。为了消除这个不良影响,需要在函数的开始声明一个新列表对象。因为该函数代码在同一数据集上被调用多次,为了不修改原始数据集,创建一个新的列表对象。数据集这个列表中的各个元素也是列表,要遍历数据集中的每个元素,一旦发现符合要求的值,则将其添加到新创建的列表中。在 if 语句中,程序将符合特征的数据抽取出来。这里可以这样理解:当按照某个特征划分数据集时,就需要将所有符合要求的元素抽取出来。代码中使用了 Python 语言列表类型自带的extendappend方法,这两个方法的处理结果是完全不同的

假定存在两个列表, a和b:

1
2
3
4
5
a = [1,2,3]
b = [4,5,6]
a.append(b)
a
Out[5]: [1, 2, 3, [4, 5, 6]]

如果执行 a.append(b),则列表得到了第四个元素,而且第四个元素也是一个列表,然而如果使用 extend 方法:

1
2
3
a.extend(b)
a
Out[8]: [1, 2, 3, 4, 5, 6]

则得到一个包含 a 和 b 所有元素的列表

可以在前面的简单样本数据上测试函数 split_data_set(),在 Python 命令提示符中输入:

1
2
3
4
5
6
7
myDat, labels = trees.create_data_set()
myDat
Out[4]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
trees.split_data_set(myDat, 0, 1)
Out[5]: [[1, 'yes'], [1, 'yes'], [0, 'no']]
trees.split_data_set(myDat, 0, 0)
Out[6]: [[1, 'no'], [1, 'no']]

接下来将遍历整个数据集,循环计算香农熵和 split_data_set() 函数,找到最好的特征划分方式。熵计算将会告诉如何划分数据集是最好的数据组织方式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def choose_best_feature_to_split(data_set):
num_features = len(data_set[0]) - 1 # 选最后一列为标签
base_entropy = calc_shannon_ent(data_set)
best_info_gain = 0.0
best_feature = -1
for i in range(num_features): # 遍历所有列
feat_list = [example[i] for example in data_set] # 创建唯一的分类标签列表
unique_val = set(feat_list) # 获取唯一的值
new_entropy = 0.0
for value in unique_val:
sub_data_set = split_data_set(data_set, i, value)
prob = len(sub_data_set) / float(len(data_set))
new_entropy += prob * calc_shannon_ent(sub_data_set)
info_gain = base_entropy - new_entropy # 计算每种划分方式的信息熵
if info_gain > best_info_gain: # 比较信息熵
best_info_gain = info_gain # 如果比当前的优则设置
best_feature = i
return best_feature # 返回最佳的信息熵

choose_best_feature_to_split() 实现选取特征、划分数据集,计算得出最好的划分数据集的特征。在函数中调用数据需要满足一定的要求:第一个要求是,数据必须是一种由列表元素组成的列表,而且所有的列表元素都要具有相同的数据长度;第二个要求是,数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。数据集一旦满足上述要求,就可以在函数的第一行判定当前数据集包含多少特征属性。无需限定 list 中的数据类型,它们既可以是数字也可以是字符串,并不影响实际计算

在开始划分数据之前,上述代码的第三行计算了整个数据集的原始香农熵,保存最初的无序度量值,用于划分完之后的数据集计算的熵值进行比较。第 1 个 for 循环遍历数据集中的所有特征。使用列表推导 (List Comprehension) 来创建新的列表,将数据集中所有第 i 个特征值或者所有可能存在的值写入这个新 list 中。然后使用 Python 语言原生的集合 (set) 数据类型。集合数据类型与列表类型相似,不同之处仅在于集合类型中的每个值互不相同。从列表中创建集合是 Python 语言得到列表中唯一元素值的最快方法

遍历当前特征中的所有唯一属性值,对每个唯一属性值划分一次数据集,然后计算数据集的新熵值,并对所有唯一的特征值得到的熵求和。信息增益是熵的减少或者是数据无序度的减少,对于将熵用于度量数据无序度的减少更容易理解。最后,比较所有特征中的信息增益,返回最好特征划分的索引值

现在进行测试:

1
2
3
4
5
myDat, labels = trees.create_data_set()
trees.choose_best_feature_to_split(myDat)
Out[4]: 0
myDat
Out[5]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

代码运行结果显示,第 0 个特征是最好的用于划分数据集的特征。结果是否正确呢?这个结果又有什么意义呢?数据集中的数据来源于

海洋生物数据

不浮出水面是否可以生存 是否有脚蹼 属于鱼类

或者变量 myDat 中的数据。如果按照第一个特征属性划分数据,也就是说第一个特征是 1 的放在一个组,第一个特征是 0 的放在另一个组,数据一致性如何?按照上述的方法划分数据集,第一个特征为 1 的海洋生物分组将有两个属于鱼类,一个属于非鱼类;另一个分组则全部属于非鱼类。如果按照第二个特征分组,结果又是怎么样呢?第一个海洋生物分组将有两个属于鱼类,两个属于非鱼类;另一个分组则只有一个非鱼类。第一种划分很好地处理了相关数据。如果不相信目测结果,可以使用 calc_shannon_ent() 测试不同特征分组的输出结果

下面将介绍如何将函数功能放在一起,构建决策树

递归构建决策树

目前已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,可以再次划分数据。因此可采用递归的原则处理数据集

递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果实列具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类

划分数据集时的数据路径

第一个结束条件使得算法可以终止,甚至可以设置算法可以划分的最大分组数目。后面将会介绍其他决策树算法,如C4.5和CART,这些算法在运行时并不总是在每次划分分组时都会消耗特征。由于特征数目并不是在每次划分数据分组时都减少,因此这些算法在实际使用时可能引起一定的问题。目前并不需要考虑这个问题,只需要在算法开始运行前计算列的数目,查看算法是否使用了所有属性即可。如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时需要决定如何定义该叶子节点,在这种情况下,通常采用多数表决的方法决定该叶子节点的分类

1
2
3
4
5
6
7
8
def majority_cnt(class_list):
class_count = {}
for vote in class_list:
if vote not in class_count.keys():
class_count[vote] = 0
class_count[vote] += 1
sorted_class_count = sorted(class_count.iteritems(), key=operator.itemgetter(1), reverse=True)
return sorted_class_count[0][0]

上面的代码与 kNN 的 classify0 部分的投票表决代码非常类似,该函数使用分类名称的列表,然后创建键值为 classList 中唯一值的数据字典,字典对象储存了 classList 中每个类标签出现的频率,最后利用 operator 操作键值排序字典,并返回出现次数最多的分类名称

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def create_tree(data_set, label):
class_list = [example[-1] for example in data_set]
if class_list.count(class_list[0]) == len(class_list):
return class_list[0] # 类别完全相同则停止划分
if len(data_set[0]) == 1: # 遍历完所有特征时返回出现次数最多的类别
return majority_cnt(class_list)
best_feat = choose_best_feature_to_split(data_set)
best_feat_label = label[best_feat]
my_tree = {best_feat_label: {}}
del (label[best_feat])
feat_values = [example[best_feat] for example in data_set]
unique_val = set(feat_values)
for value in unique_val:
sub_labels = label[:] # 复制所有标签,这样树就不会弄乱现有的标签
my_tree[best_feat_label][value] = create_tree(split_data_set(data_set, best_feat, value), sub_labels)
return my_tree

上述代码使用两个输入参数:数据集和标签列表。标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但是为了给出数据明确的含义,将它作为一个输入参数提供。此外,前面提到的对数据集的要求这里依然需要满足。上述代码首先创建了名为 classList 的列表变量,其中包含了数据集的所有类标签。递归函数的第一个停止条件是所有的类标签完全相同,则直接返回该类标签。递归函数的第二个停止条件是使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组。由于第二个条件无法简单地返回唯一的类标签,这里使用前面介绍的 majority_cnt() 函数挑选出现次数最多的类别作为返回值

下一步程序开始创建树,这里使用 Python 语言的字典类型储存树的信息,当然也可以声明特殊的数据类型储存树,但是这里完全没有必要。字典变量 my_tree 储存了树的所有信息,这对于其后绘制树形图非常重要。当前数据集选取的最好特征存储在变量 best_feat 中,得到列表包含的所有属性值。

最后代码遍历当前选择特征包含的所有属性种子,在每个数据集划分上递归调用函数 create_tree(),得到的返回值将被插入到字典变量 my_tree 中,因此函数终止执行时,字典中将会嵌套很多代表叶子节点信息的字典数据。在解释这个嵌套数据之前,先看一下循环的第一行 sub_labels = label[:] ,这行复制了类标签,并将其储存在新列表变量 sub_labels 中。之所以这样做,是因为 Python 语言中函数参数是列表类型时,参数是按照引用方式传递的。为了保证每次调用函数 create_tree() 时不改变原始列表的内容,使用新变量 sub_lables 代替原始列表

现在进行测试,在 Python 命令行中输入:

1
2
3
4
myDat, labels = trees.create_data_set()
myTree = trees.create_tree(myDat, labels)
myTree
Out[6]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

变量 myTree 包含了很多代表树结构信息的嵌套字典,从左边开始,第一个关键字 no surfacing 是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典。第二个关键字是 no surfacing 特征划分的数据集,这些关键字的值是 no surfacing 节点的子节点。这些值可能是类标签,也可能是另一个数据字典。如果值是类标签,则该子节点是叶子节点;如果值是另一个数据字典,则子节点是一个判断节点,这种格式结构不断重复就构成了整棵树。例子中的树包含了 3 个叶子节点以及 2 个判断节点

使用 Matplotlib 注解绘制树形图

使用 Matplotlib 创建树形图,决策树的主要优点就是直观易于理解,如果不能将其直观地显示出来,就无法发挥其优势。虽然前面使用的图形库已经非常强大,但是 Python 没有提供绘制树的工具,因此必须自己绘制树形图

Matplotlib 注解

Matplotlib 提供了一个非常有用的注解工具 annotations,可以在数据图形上添加文本注解。注解通常用于解释数据的内容。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支持带箭头的划线工具,使得可以在其他恰当的地方指出数据位置,并在此添加描述信息,解释数据内容

创建 treePlotter.py,添加函数 plot_node

在应用数学中,一系列由边连接在一起的对象或者节点称为图。节点间的任意联系都可以通过边来表示。在计算机科学中,图是一种数据结构,用于表示数学上的概念。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 定制文本框和箭头格式
leafNode = dict(boxstyle="round4", fc="0.8") # 定制文本框和箭头格式
arrow_args = dict(arrowstyle="<-") # 定制文本框和箭头格式


def plot_node(node_txt, center_pt, parent_pt, node_type):
create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction',
xytext=center_pt, textcoords='axes fraction',
va="center", ha="center", bbox=node_type, arrowprops=arrow_args) # 绘制带箭头的注解


def create_plot():
fig = plt.figure(1, facecolor='white')
fig.clf()
create_plot.ax1 = plt.subplot(111, frameon=False) # ticks for demo puropses
plot_node('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plot_node('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()

这是第一个 create_plot() 函数与接下来的 create_plot() 有些不同,随着内容的深入,将逐步添加缺失的代码。代码定义了描述树节点格式的常量。然后定义 plot_node() 函数执行了实际的绘图功能,该函数需要一个绘图区,该区域由全局变量 create_plot.ax1 定义。Python 语言中所有的变量默认都是全局有效的,只要搞清楚知道当前代码的主要功能,并不会引入太大的麻烦,最后定义 create_plot() 函数,它是这段代码的核心。create_plot() 函数首先创建了一个新图形并清空绘图区,然后在绘图区上绘制两个代表不同类型的树节点,后面将用着两个节点绘制树形图

下面是代码的输出结果

构造注解树

绘制一颗完整的树需要一些技巧。虽然有 x、y 坐标,但是如何放置所有的树节点却是个问题。首先必须知道有多少个叶节点,以便可以正确确定 x 轴的长度;其次还需要知道树有多少层,以便正确确定 y 轴的高度。这里定义两个新函数 get_num_leafs()get_tree_depth(),来获取叶节点的数目和树的层数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def get_num_leafs(my_tree):
num_leafs = 0
first_str = list(my_tree.keys())[0]
second_dict = my_tree[first_str]
for key in second_dict.keys():
if type(second_dict[
key]).__name__ == 'dict': # 测试节点的数据类型是否为字典,不是就是叶节点
num_leafs += get_num_leafs(second_dict[key])
else:
num_leafs += 1
return num_leafs


def get_tree_depth(my_tree):
max_depth = 0
first_str = list(my_tree.keys())[0]
second_dict = my_tree[first_str]
for key in second_dict.keys():
if type(second_dict[
key]).__name__ == 'dict': # 测试节点的数据类型是否为字典,不是就是叶节点
this_depth = 1 + get_tree_depth(second_dict[key])
else:
this_depth = 1
if this_depth > max_depth:
max_depth = this_depth
return max_depth

上述的两个函数具有相同的结构,后面也将使用到这两个函数。这里使用的数据结构说明了如何在 Python 字典类型中储存树信息。第一个关键字是第一次划分数据集的类别标签,附带的数值表示子节点的取值。从第一个关键字出发,可以遍历整棵树的所有子节点。使用 Python 提供的 type() 函数可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用 get_num_leafs() 函数。get_num_leafs() 函数遍历整棵树,累计叶子节点的个数,并返回该数值。第二个函数 get_tree_depth() 计算遍历过程中遇到判断节点的个数。该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回。并将计算树深度的变量加一。为了节省时间,函数 retrieve_tree() 输出预先储存的树信息,避免了每次测试代码时都要从数据中创建树的麻烦

1
2
3
4
5
def retrieve_tree(i):
list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return list_of_trees[i]

进行测试,在 Python 命令提示符中输入:

1
2
3
4
5
6
7
8
9
treePlotter.retrieve_tree(1)
Out[4]:
{'no surfacing': {0: 'no',
1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
myTree = treePlotter.retrieve_tree(0)
treePlotter.get_num_leafs(myTree)
Out[6]: 3
treePlotter.get_tree_depth(myTree)
Out[7]: 2

函数 retrieve_tree() 主要用于测试,返回预定义的树结构。上述命令中调用 get_num_leafs() 函数返回值 3,等于树 0 的叶子节点数;调用 get_tree_depth() 函数也能够正确返回树的层数

注意,前文已经在文件中定义了函数 create_plot(),此处需要更新这部分代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def plot_mid_text(ctr_pt, parent_pt, txt_string):
x_mid = (parent_pt[0] - ctr_pt[0]) / 2.0 + ctr_pt[0] # 在父子节点中填充文本信息
y_mid = (parent_pt[1] - ctr_pt[1]) / 2.0 + ctr_pt[1] # 在父子节点中填充文本信息
create_plot.ax1.text(x_mid, y_mid, txt_string, va="center", ha="center", rotation=30) # 在父子节点中填充文本信息


def plot_tree(my_tree, parent_pt, node_txt):
num_leafs = get_num_leafs(my_tree) # 计算树 x 轴的长度
depth = get_tree_depth(my_tree)
first_str = list(my_tree.keys())[0] # 节点的标签
ctr_pt = (plot_tree.xOff + (1.0 + float(num_leafs)) / 2.0 / plot_tree.totalW, plot_tree.yOff)
plot_mid_text(ctr_pt, parent_pt, node_txt) # 标记子节点属性值
plot_node(first_str, ctr_pt, parent_pt, decisionNode)
second_dict = my_tree[first_str]
plot_tree.yOff = plot_tree.yOff - 1.0 / plot_tree.totalD # 减少 y 的偏移
for key in second_dict.keys():
if type(second_dict[
key]).__name__ == 'dict': # 测试节点的数据类型是否为字典,不是就是叶节点
plot_tree(second_dict[key], ctr_pt, str(key)) # 递归
else: # 是叶节点则打印叶节点
plot_tree.xOff = plot_tree.xOff + 1.0 / plot_tree.totalW
plot_node(second_dict[key], (plot_tree.xOff, plot_tree.yOff), ctr_pt, leafNode)
plot_mid_text((plot_tree.xOff, plot_tree.yOff), ctr_pt, str(key))
plot_tree.yOff = plot_tree.yOff + 1.0 / plot_tree.totalD


def create_plot(in_tree):
fig = plt.figure(1, facecolor='white')
fig.clf()
apropos = dict(xticks=[], yticks=[])
create_plot.ax1 = plt.subplot(111, frameon=False, **apropos) # no ticks
# createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plot_tree.totalW = float(get_num_leafs(in_tree))
plot_tree.totalD = float(get_num_leafs(in_tree))
plot_tree.xOff = -0.5 / plot_tree.totalW
plot_tree.yOff = 1.0;
plot_tree(in_tree, (0.5, 1.0), '')
plt.show()

函数 create_plot() 是使用的主函数,它调用了 plot_tree(),函数 plot_tree 又依次调用了前面介绍的函数和 plot_mid_text() 。绘制树形图的很多工作都是在函数 plot_tree() 中完成的,函数 plot_tree() 首先计算树的宽和高。全局变量 plot_tree.totalW 储存树的宽度,全局变量 plot_tree.totalD 储存树的深度,使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。函数 plot_tree() 也是一个递归函数。树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间,而不仅仅是它子节点的中间。同时使用两个全局变量 plot_tree.xOffplot_tree.yOff 追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。另一个需要说明的问题是,绘制图形的x轴有效范围是0.0到1.0,y轴有效范围也是0.0 ~ 1.0。通过计算树包含的所有叶子节点数,划分图形的宽度,从而计算得到当前节点的中心位置,也就是,按照叶子节点的数目将 x 轴划分成若干部分。按照图形比例绘制树形图的最大好处是无需关心实际输出图形的大小,一旦图形大小发生了变化,函数会自动按照图形大小重新绘制,如果以像素为单位绘制图形,则缩放图形就不是一件简单的工作

接着,绘出子节点具有的特征值,或者沿此分支向下的数据实例必须具有的特征值。使用函数 plot_mid_text() 计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息

然后,按比列减少全局变量 plot_tree.yOff,并标注此处将要绘制子节点,这些节点既可以是叶子节点也可以是判断节点,此处需要只保存绘制图形的轨迹。因为是自顶向下绘制图形,因此需要依次递减 y 坐标值,而不是递增 y 坐标值。然后程序采用函数 get_num_leafs()get_tree_depth() 以相同的方式递归遍历整棵树,如果节点是叶子节点则在图形上画出叶子节点,如果不是叶子节点则递归调用 plot_tree()函数。在绘制了所有子节点之后,增加全局变量Y的偏移

create_plot() 是最后一个函数,它创建绘图区,计算树形图的全局尺寸,并调用递归函数 plot_tree()

现在进行测试,在 Python 命令提示符下输入:

1
2
myTree = treePlotter.retrieve_tree(0)
treePlotter.create_plot(myTree)

得到下面的图,但是没有坐标轴标签

按照下面命令变更字典,重新绘制树形图:

1
2
3
4
myTree['no surfacing'][3] = 'maybe'
myTree
Out[7]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
treePlotter.create_plot(myTree)

得到下面的图,有点像一个无头的简笔画。也可以在树字典中随意添加一些数据,并重新绘制树型图观察输出结果的变化

超过两个分支的树形图

测试和存储分类器

下面将使用决策树构建分类器,并介绍实际应用中如何储存分类器

测试算法:使用决策树执行分类

依靠训练数据构造了决策树之后,可以将它用于实际数据的分类。在执行数据分类时,需要使用决策树以及用于构造决策树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类型

1
2
3
4
5
6
7
8
9
10
11
def classify(input_tree, feat_labels, test_vec):
first_str = list(input_tree.keys())[0]
second_dict = input_tree[first_str]
feat_index = feat_labels.index(first_str)
key = test_vec[feat_index]
value_of_feat = second_dict[key]
if isinstance(value_of_feat, dict):
class_label = classify(value_of_feat, feat_labels, test_vec)
else:
class_label = value_of_feat
return class_label

将上述代码添加到 trees.py 中,这也是一个递归函数,在储存带有特征的数据会面临一个问题:程序无法确定特征在数据集中的位置,例如前面例子的第一个用于划分数据集的特征是 no surfacing 属性,但是实际数据集中该属性储存在哪个位置?是第一个属性还是第二个属性?特征标签列表将帮助程序处理这个问题。使用 index 方法查找当前列表中第一个匹配 first_str 变量的元素。然后代码递归遍历整棵树,比较 test_vec 变量中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签

在 Python 命令行中输入下列命令:

1
2
3
4
5
6
7
8
9
10
myDat, labels = trees.create_data_set()
labels
Out[6]: ['no surfacing', 'flippers']
myTree = treePlotter.retrieve_tree(0)
myTree
Out[8]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
trees.classify(myTree, labels, [1, 0])
Out[9]: 'no'
trees.classify(myTree, labels, [1, 1])
Out[10]: 'yes'

与下图比较

第一个节点名为: no surfacing,它有两个子节点:一个是名字为 0 的叶子节点,类标签为 no;另一个是名为 flippers 的判断节点,此处进入递归调用,flippers 节点有两个子节点。以前绘制的树形图和此处代表树的数据结构完全相同

现在已经创建了使用决策树的分类器,但是每次使用分类器时,必须重新构造决策树,下一节将介绍如何在硬盘上储存决策树分类器

使用算法:决策树的储存

构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。然而用创建好的决策树解决分类问题,则可以很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树,为了解决这个问题,需要使用 Python 模块 pickle 序列化对象。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外

1
2
3
4
5
6
7
8
9
10
11
def store_tree(input_tree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(input_tree, fw)
fw.close()


def grab_tree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)

通过上面的代码,可以将分类器储存在硬盘上,而不用每次对数据分类时重新学习一遍,这也是决策树的优点之一,之前介绍的 k-近邻算法就无法持久化分类器。可以预先提炼并储存数据集中包含的知识信息,在需要对事物进行分类时再使用这些知识

示例:使用决策树预测隐形眼镜类型

下面将通过一个例子讲解决策树如何预测患者需要佩戴的隐形眼镜类型。使用小数据集,就可以利用决策树学到很多知识:眼科医生时如何判断患者需要佩戴的镜片类型的;一旦理解了决策树的工作原理,甚至也可以帮助人们判断需要佩戴的镜片类型

使用决策树预测隐形眼镜类型:

  • 收集数据
    • 提供的文本文件
  • 准备数据
    • 解析 tab 键分割的数据行
  • 分析数据
    • 快速检查数据,确保正确地解析数据内容,使用 create_plot() 函数绘制最终的树形图
  • 训练算法
    • 使用 create_tree() 函数
  • 测试算法
    • 编写测试函数验证决策树可以正确分类给定的数据实例
  • 使用算法
    • 存储树的数据结构,以便下次使用时无需重新构造树

隐形眼镜数据集是非常著名的数据集,它包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。数据来源于 UCI 数据库,为了更容易显示数据,数据做了简单的更改,数据储存在源代码下载路径的文本文件中

在 Python 命令提示符输入:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.create_tree(lenses, lensesLabels)
lensesTree
Out[9]:
{'tearRate': {'reduced': 'no lenses',
'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses',
'young': 'hard',
'presbyopic': 'no lenses'}},
'myope': 'hard'}},
'no': {'age': {'pre': 'soft',
'young': 'soft',
'presbyopic': {'prescript': {'hyper': 'soft',
'myope': 'no lenses'}}}}}}}}
treePlotter.create_plot(lensesTree)

采用文本方式很难分辨出决策树的模样,最后一行命令调用 create_plot() 函数绘制了下图

由 ID3 算法产生的决策树

沿着决策树的不同分支,可以得到不同患者需要佩戴的隐形眼镜类型。从图中发现,医生最多需要问 4 个问题就能确定患者需要佩戴哪种类型的隐形眼镜

上图的决策树非常好的匹配了实现数据,然而这些匹配选项可能太多了。这样的情况称之为 过度匹配(overfitting) 。为了减少过度匹配,可以裁剪决策树,去掉一些不必要的叶子节点。如果叶子节点只能增加少许信息,则可以删除该节点,将它并入到其他叶子节点中,后面将会进行阐述

在之后的博客中,会介绍另一个决策树构造算法 CART,本章使用的算法称为 ID3,它是一个好的算法但并不完美。ID3算法无法直接处理数值型数据,尽管可以通过量化的方法将数值型数据转化为标称型数值,但是如果存在太多的特征划分,ID3算法仍然会面临其他问题

小结

决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,首先需要测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。ID3 算法可以用于划分标称型数据集。构建决策树时,通常采用递归的方法将数据集转化为决策树。一般并不构造新的数据结构,而是使用 Python 语言内嵌的数据结构字典储存节点信息

Matplotlib 的注解功能,可以将储存的树结构转化为容易理解的图形。Python 语言的 pickle 模块可用于储存决策树的结构。隐形眼镜的例子表明决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题

现为止,讨论的是结果确定的分类算法,数据实例最终会被明确划分到某个分类中。之后,会阐述分类算法将不能完全确定数据实例应该划分到某个分类,或者只能给出数据实例属于给定分类的概率

源码示例

GitHub 下载