扫二维码与项目经理沟通
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流
决策树之ID3算法及其Python实现
创新互联公司专业为企业提供清远网站建设、清远做网站、清远网站设计、清远网站制作等企业网站建设、网页设计与制作、清远企业网站模板建站服务,10年清远做网站经验,不只是建网站,更提供有价值的思路和整体网络服务。
1. 决策树背景知识
??决策树是数据挖掘中最重要且最常用的方法之一,主要应用于数据挖掘中的分类和预测。决策树是知识的一种呈现方式,决策树中从顶点到每个结点的路径都是一条分类规则。决策树算法最先基于信息论发展起来,经过几十年发展,目前常用的算法有:ID3、C4.5、CART算法等。
2. 决策树一般构建过程
??构建决策树是一个自顶向下的过程。树的生长过程是一个不断把数据进行切分细分的过程,每一次切分都会产生一个数据子集对应的节点。从包含所有数据的根节点开始,根据选取分裂属性的属性值把训练集划分成不同的数据子集,生成由每个训练数据子集对应新的非叶子节点。对生成的非叶子节点再重复以上过程,直到满足特定的终止条件,停止对数据子集划分,生成数据子集对应的叶子节点,即所需类别。测试集在决策树构建完成后检验其性能。如果性能不达标,我们需要对决策树算法进行改善,直到达到预期的性能指标。
??注:分裂属性的选取是决策树生产过程中的关键,它决定了生成的决策树的性能、结构。分裂属性选择的评判标准是决策树算法之间的根本区别。
3. ID3算法分裂属性的选择——信息增益
??属性的选择是决策树算法中的核心。是对决策树的结构、性能起到决定性的作用。ID3算法基于信息增益的分裂属性选择。基于信息增益的属性选择是指以信息熵的下降速度作为选择属性的方法。它以的信息论为基础,选择具有最高信息增益的属性作为当前节点的分裂属性。选择该属性作为分裂属性后,使得分裂后的样本的信息量最大,不确定性最小,即熵最小。
??信息增益的定义为变化前后熵的差值,而熵的定义为信息的期望值,因此在了解熵和信息增益之前,我们需要了解信息的定义。
??信息:分类标签xi 在样本集 S 中出现的频率记为 p(xi),则 xi 的信息定义为:?log2p(xi) 。
??分裂之前样本集的熵:E(S)=?∑Ni=1p(xi)log2p(xi),其中 N 为分类标签的个数。
??通过属性A分裂之后样本集的熵:EA(S)=?∑mj=1|Sj||S|E(Sj),其中 m 代表原始样本集通过属性A的属性值划分为 m 个子样本集,|Sj| 表示第j个子样本集中样本数量,|S| 表示分裂之前数据集中样本总数量。
??通过属性A分裂之后样本集的信息增益:InfoGain(S,A)=E(S)?EA(S)
??注:分裂属性的选择标准为:分裂前后信息增益越大越好,即分裂后的熵越小越好。
4. ID3算法
??ID3算法是一种基于信息增益属性选择的决策树学习方法。核心思想是:通过计算属性的信息增益来选择决策树各级节点上的分裂属性,使得在每一个非叶子节点进行测试时,获得关于被测试样本最大的类别信息。基本方法是:计算所有的属性,选择信息增益最大的属性分裂产生决策树节点,基于该属性的不同属性值建立各分支,再对各分支的子集递归调用该方法建立子节点的分支,直到所有子集仅包括同一类别或没有可分裂的属性为止。由此得到一棵决策树,可用来对新样本数据进行分类。
ID3算法流程:
(1) 创建一个初始节点。如果该节点中的样本都在同一类别,则算法终止,把该节点标记为叶节点,并用该类别标记。
(2) 否则,依据算法选取信息增益最大的属性,该属性作为该节点的分裂属性。
(3) 对该分裂属性中的每一个值,延伸相应的一个分支,并依据属性值划分样本。
(4) 使用同样的过程,自顶向下的递归,直到满足下面三个条件中的一个时就停止递归。
??A、待分裂节点的所有样本同属于一类。
??B、训练样本集中所有样本均完成分类。
??C、所有属性均被作为分裂属性执行一次。若此时,叶子结点中仍有属于不同类别的样本时,选取叶子结点中包含样本最多的类别,作为该叶子结点的分类。
ID3算法优缺点分析
优点:构建决策树的速度比较快,算法实现简单,生成的规则容易理解。
缺点:在属性选择时,倾向于选择那些拥有多个属性值的属性作为分裂属性,而这些属性不一定是最佳分裂属性;不能处理属性值连续的属性;无修剪过程,无法对决策树进行优化,生成的决策树可能存在过度拟合的情况。
sklearn中决策树分为DecisionTreeClassifier和DecisionTreeRegressor,所以用的算法是CART算法,也就是分类与回归树算法(classification and regression tree,CART),划分标准默认使用的也是Gini,ID3和C4.5用的是信息熵,为何要设置成ID3或者C4.5呢
import numpy as np11
import pandas as pd11
names=("Balance,Duration,History,Purpose,Credit amount,Savings,Employment,instPercent,sexMarried,Guarantors,Residence duration,Assets,Age,concCredit,Apartment,Credits,Occupation,Dependents,hasPhone,Foreign,lable").split(',')11
data=pd.read_csv("Desktop/sunshengyun/data/german/german.data",sep='\s+',names=names)11
data.head()11
Balance
Duration
History
Purpose
Credit amount
Savings
Employment
instPercent
sexMarried
Guarantors
…
Assets
Age
concCredit
Apartment
Credits
Occupation
Dependents
hasPhone
Foreign
lable
A11 6 A34 A43 1169 A65 A75 4 A93 A101 … A121 67 A143 A152 2 A173 1 A192 A201 1
1
A12 48 A32 A43 5951 A61 A73 2 A92 A101 … A121 22 A143 A152 1 A173 1 A191 A201 2
2
A14 12 A34 A46 2096 A61 A74 2 A93 A101 … A121 49 A143 A152 1 A172 2 A191 A201 1
3
A11 42 A32 A42 7882 A61 A74 2 A93 A103 … A122 45 A143 A153 1 A173 2 A191 A201 1
4
A11 24 A33 A40 4870 A61 A73 3 A93 A101 … A124 53 A143 A153 2 A173 2 A191 A201 2
5 rows × 21 columns
data.Balance.unique()11
array([‘A11’, ‘A12’, ‘A14’, ‘A13’], dtype=object)data.count()11
Balance 1000 Duration 1000 History 1000 Purpose 1000 Credit amount 1000 Savings 1000 Employment 1000 instPercent 1000 sexMarried 1000 Guarantors 1000 Residence duration 1000 Assets 1000 Age 1000 concCredit 1000 Apartment 1000 Credits 1000 Occupation 1000 Dependents 1000 hasPhone 1000 Foreign 1000 lable 1000 dtype: int64#部分变量描述性统计分析
data.describe()1212
Duration
Credit amount
instPercent
Residence duration
Age
Credits
Dependents
lable
count
1000.000000 1000.000000 1000.000000 1000.000000 1000.000000 1000.000000 1000.000000 1000.000000
mean
20.903000 3271.258000 2.973000 2.845000 35.546000 1.407000 1.155000 1.300000
std
12.058814 2822.736876 1.118715 1.103718 11.375469 0.577654 0.362086 0.458487
min
4.000000 250.000000 1.000000 1.000000 19.000000 1.000000 1.000000 1.000000
25%
12.000000 1365.500000 2.000000 2.000000 27.000000 1.000000 1.000000 1.000000
50%
18.000000 2319.500000 3.000000 3.000000 33.000000 1.000000 1.000000 1.000000
75%
24.000000 3972.250000 4.000000 4.000000 42.000000 2.000000 1.000000 2.000000
max
72.000000 18424.000000 4.000000 4.000000 75.000000 4.000000 2.000000 2.000000
data.Duration.unique()11
array([ 6, 48, 12, 42, 24, 36, 30, 15, 9, 10, 7, 60, 18, 45, 11, 27, 8, 54, 20, 14, 33, 21, 16, 4, 47, 13, 22, 39, 28, 5, 26, 72, 40], dtype=int64)data.History.unique()11
array([‘A34’, ‘A32’, ‘A33’, ‘A30’, ‘A31’], dtype=object)data.groupby('Balance').size().order(ascending=False)11
c:\python27\lib\site-packages\ipykernel\__main__.py:1: FutureWarning: order is deprecated, use sort_values(…) if __name__ == ‘__main__’: Balance A14 394 A11 274 A12 269 A13 63 dtype: int64data.groupby('Purpose').size().order(ascending=False)11
c:\python27\lib\site-packages\ipykernel\__main__.py:1: FutureWarning: order is deprecated, use sort_values(…) if __name__ == ‘__main__’: Purpose A43 280 A40 234 A42 181 A41 103 A49 97 A46 50 A45 22 A44 12 A410 12 A48 9 dtype: int64data.groupby('Apartment').size().order(ascending=False)11
c:\python27\lib\site-packages\ipykernel\__main__.py:1: FutureWarning: order is deprecated, use sort_values(…) if __name__ == ‘__main__’: Apartment A152 713 A151 179 A153 108 dtype: int64import matplotlib.pyplot as plt
%matplotlib inline
data.plot(x='lable', y='Age', kind='scatter',
alpha=0.02, s=50);12341234
![png](output_13_0.png)data.hist('Age', bins=15);11
![png](output_14_0.png)target=data.lable11
features_data=data.drop('lable',axis=1)11
numeric_features = [c for c in features_data if features_data[c].dtype.kind in ('i', 'f')] # 提取数值类型为整数或浮点数的变量11
numeric_features11
[‘Duration’, ‘Credit amount’, ‘instPercent’, ‘Residence duration’, ‘Age’, ‘Credits’, ‘Dependents’]numeric_data = features_data[numeric_features]11
numeric_data.head()11
Duration
Credit amount
instPercent
Residence duration
Age
Credits
Dependents
6 1169 4 4 67 2 1
1
48 5951 2 2 22 1 1
2
12 2096 2 3 49 1 2
3
42 7882 2 4 45 1 2
4
24 4870 3 4 53 2 2
categorical_data = features_data.drop(numeric_features, axis=1)11
categorical_data.head()11
Balance
History
Purpose
Savings
Employment
sexMarried
Guarantors
Assets
concCredit
Apartment
Occupation
hasPhone
Foreign
A11 A34 A43 A65 A75 A93 A101 A121 A143 A152 A173 A192 A201
1
A12 A32 A43 A61 A73 A92 A101 A121 A143 A152 A173 A191 A201
2
A14 A34 A46 A61 A74 A93 A101 A121 A143 A152 A172 A191 A201
3
A11 A32 A42 A61 A74 A93 A103 A122 A143 A153 A173 A191 A201
4
A11 A33 A40 A61 A73 A93 A101 A124 A143 A153 A173 A191 A201
categorical_data_encoded = categorical_data.apply(lambda x: pd.factorize(x)[0]) # pd.factorize即可将分类变量转换为数值表示
# apply运算将转换函数应用到每一个变量维度
categorical_data_encoded.head(5)123123
Balance
History
Purpose
Savings
Employment
sexMarried
Guarantors
Assets
concCredit
Apartment
Occupation
hasPhone
Foreign
0 0 0 0 0 0 0 0 0 0 0 0 0
1
1 1 0 1 1 1 0 0 0 0 0 1 0
2
2 0 1 1 2 0 0 0 0 0 1 1 0
3
0 1 2 1 2 0 1 1 0 1 0 1 0
4
0 2 3 1 1 0 0 2 0 1 0 1 0
features = pd.concat([numeric_data, categorical_data_encoded], axis=1)#进行数据的合并
features.head()
# 此处也可以选用one-hot编码来表示分类变量,相应的程序如下:
# features = pd.get_dummies(features_data)
# features.head()1234512345
Duration
Credit amount
instPercent
Residence duration
Age
Credits
Dependents
Balance
History
Purpose
Savings
Employment
sexMarried
Guarantors
Assets
concCredit
Apartment
Occupation
hasPhone
Foreign
6 1169 4 4 67 2 1 0 0 0 0 0 0 0 0 0 0 0 0 0
1
48 5951 2 2 22 1 1 1 1 0 1 1 1 0 0 0 0 0 1 0
2
12 2096 2 3 49 1 2 2 0 1 1 2 0 0 0 0 0 1 1 0
3
42 7882 2 4 45 1 2 0 1 2 1 2 0 1 1 0 1 0 1 0
4
24 4870 3 4 53 2 2 0 2 3 1 1 0 0 2 0 1 0 1 0
X = features.values.astype(np.float32) # 转换数据类型
y = (target.values == 1).astype(np.int32) # 1:good,2:bad1212
from sklearn.cross_validation import train_test_split # sklearn库中train_test_split函数可实现该划分
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=0) # 参数test_size设置训练集占比
1234512345
from sklearn.tree import DecisionTreeClassifier
from sklearn.cross_validation import cross_val_score
clf = DecisionTreeClassifier(max_depth=8) # 参数max_depth设置树最大深度
# 交叉验证,评价分类器性能,此处选择的评分标准是ROC曲线下的AUC值,对应AUC更大的分类器效果更好
scores = cross_val_score(clf, X_train, y_train, cv=3, scoring='roc_auc')
print("ROC AUC Decision Tree: {:.4f} +/-{:.4f}".format(
np.mean(scores), np.std(scores)))123456789123456789
ROC AUC Decision Tree: 0.6866 +/-0.0105
#利用learning curve,以样本数为横坐标,训练和交叉验证集上的评分为纵坐标,对不同深度的决策树进行对比(判断是否存在过拟合或欠拟合)
from sklearn.learning_curve import learning_curve
def plot_learning_curve(estimator, X, y, ylim=(0, 1.1), cv=3,
n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5),
scoring=None):
plt.title("Learning curves for %s" % type(estimator).__name__)
plt.ylim(*ylim); plt.grid()
plt.xlabel("Training examples")
plt.ylabel("Score")
train_sizes, train_scores, validation_scores = learning_curve(
estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes,
scoring=scoring)
train_scores_mean = np.mean(train_scores, axis=1)
validation_scores_mean = np.mean(validation_scores, axis=1)
plt.plot(train_sizes, train_scores_mean, 'o-', color="r",
label="Training score")
plt.plot(train_sizes, validation_scores_mean, 'o-', color="g",
label="Cross-validation score")
plt.legend(loc="best")
print("Best validation score: {:.4f}".format(validation_scores_mean[-1]))12345678910111213141516171819202122231234567891011121314151617181920212223
clf = DecisionTreeClassifier(max_depth=None)
plot_learning_curve(clf, X_train, y_train, scoring='roc_auc')
# 可以注意到训练数据和交叉验证数据的得分有很大的差距,意味着可能过度拟合训练数据了123123
Best validation score: 0.6310
clf = DecisionTreeClassifier(max_depth=10)
plot_learning_curve(clf, X_train, y_train, scoring='roc_auc')1212
Best validation score: 0.6565
clf = DecisionTreeClassifier(max_depth=8)
plot_learning_curve(clf, X_train, y_train, scoring='roc_auc')1212
Best validation score: 0.6762
clf = DecisionTreeClassifier(max_depth=5)
plot_learning_curve(clf, X_train, y_train, scoring='roc_auc')1212
Best validation score: 0.7219
clf = DecisionTreeClassifier(max_depth=4)
plot_learning_curve(clf, X_train, y_train, scoring='roc_auc')1212
Best validation score: 0.7226
# 这里有一个示例,你可以看一下。
#
from IPython.display import Image
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
Image(graph.create_png())
决策树是一种非参数有监督的机器学习方法,可以用于解决回归问题和分类问题。通过学习已有的数据,计算得出一系列推断规则来预测目标变量的值,并用类似流程图的形式进行展示。决策树模型可以进行可视化,具有很强的可解释性,算法容易理解,以决策树为基础的各种集成算法在很多领域都有广泛的应用。
熵的概念最早起源于物理学,用于度量一个热力学系统的无序程度。在信息论里面,信息熵代表着一个事件或一个变量等所含有的信息量。 在信息世界,熵越高,则能传输越多的信息,熵越低,则意味着传输的信息越少。
发生概率低的事件比发生概率高的事件具有更大的不确定性,需要更多的信息去描述他们,信息熵更高。
我们可以用计算事件发生的概率来计算事件的信息,又称“香农信息”( Shannon Information )。一个离散事件x的信息可以表示为:
h(x) = -log(p(x))
p() 代表事件x发生的概率, log() 为以二为底的对数函数,即一个事件的信息量就是这个事件发生的概率的负对数。选择以二为底的对数函数代表计算信息的单位是二进制。因为概率p(x)小于1,所以负号就保证了信息熵永远不为负数。当事件的概率为1时,也就是当某事件百分之百发生时,信息为0。
熵( entropy ),又称“香农熵”( Shannon entropy ),表示一个随机变量的分布所需要的平均比特数。一个随机变量的信息熵可以表示为:
H(x) = -sum(each k in K p(k)log(p(k)))
K表示变量x所可能具有的所有状态(所有事件),将发生特定事件的概率和该事件的信息相乘,最后加和,即可得到该变量的信息熵。可以理解为,信息熵就是平均而言发生一个事件我们得到的信息量大小。所以数学上,信息熵其实是事件信息量的期望。
当组成该随机变量的一个事件的概率为1时信息熵最小,为0, 即该事件必然发生。当组成该随机变量的所有事件发生的概率相等时,信息熵最大,即完全不能判断那一个事件更容易发生,不确定性最大。
当一个事件主导时,比如偏态分布( Skewed Probability Distribution ),不确定性减小,信息熵较低(low entropy);当所有事件发生概率相同时,比如均衡分布( Balanced Probability Distribution ),不确定性极大,信息熵较高(high entropy)。
由以上的香农信息公式可知,信息熵主要有三条性质:
- 单调性 。发生概率越高的事件,其所携带的信息熵越低。比如一个真理的不确定性是极低的,那么它所携带的信息熵就极低。
- 非负性 。信息熵不能为负。单纯从逻辑层面理解,如果得知了某个信息后,却增加了不确定性,这也是不合逻辑的。
- 可加性 。即多随机事件同时发生存在的总不确定性的量度是可以表示为各事件不确定性的量度的和。
若两事件A和B同时发生,两个事件相互独立。 p(X=A,Y=B) = p(X = A)*p(Y=B) , 那么信息熵为 H(A,B) = H(A) + H(B) 。但若两事件不相互独立,那么 H(A,B) = H(A) + H(B) - I(A,B) 。其中 I(A,B) 是互信息( mutual information,MI ),即一个随机变量包含另一个随机变量信息量的度量。即已知X的情况下,Y的分布是否会改变。
可以理解为,两个随机变量的互信息度量了两个变量间相互依赖的程度。X 和 Y的互信息可以表示为:
I(X;Y) = H(X) - H(X|Y)
H(X)是X的信息熵,H(X|Y)是已知Y的情况下,X的信息熵。结果的单位是比特。
简单来说,互信息的性质为:
- I(X;Y)=0 互信息永远不可能为负
- H(X) - H(X|Y) = I(X;Y) = I (Y;X) = H(Y) - H(Y|X) 互信息是对称的
-当X,Y独立的时候, I(X;Y) = 0 互信息值越大,两变量相关性越强。
-当X,Y知道一个就能推断另一个的时候, I(X;Y) = H(Y) = H(X)
在数据科学中,互信息常用于特征筛选。在通信系统中互信息也应用广泛。在一个点到点的通信系统中,发送信号为X,通过信道后,接收端接收到的信号为Y,那么信息通过信道传递的信息量就是互信息 I(X,Y) 。根据这个概念,香农推导出信道容量(即临界通信传输速率的值)。
信息增益( Information Gain )是用来按照一定规则划分数据集后,衡量信息熵减少量的指数。
那数据集的信息熵又是怎么计算的呢?比如一个常见的0,1二分类问题,我们可以计算它的熵为:
Entropy = -(p(0) * log(P(0)) + p(1)\ * log(P(1)))
当该数据集为50/50的数据集时,它的信息熵是最大的(1bit)。而10/90的数据集将会大大减少结果的不确定性,减小数据集的信息熵(约为0.469bit)。
这样来说,信息熵可以用来表示数据集的纯度( purity )。信息熵为0就表示该数据集只含有一个类别,纯度最高。而较高的信息熵则代表较为平衡的数据集和较低的纯度。
信息增益是提供了一种可以使用信息熵计算数据集经过一定的规则(比如决策树中的一系列规则)进行数据集分割后信息熵的变化的方法。
IG(S,a) = H(S) - H(S|a)
其中,H(s) 是原数据集S的信息熵(在做任何改变之前),H(S|a)是经过变量a的一定分割规则。所以信息增益描述的是数据集S变换后所节省的比特数。
信息增益可以用做决策树的分枝判断方法。比如最常用CART树( Classification and Regression Tree )中的分枝方法,只要在python中设置参数 criterion 为 “entropy” 即可。
信息增益也可以用作建模前的特征筛选。在这种场景下,信息增益和互信息表达的含义相同,会被用来计算两变量之间的独立性。比如scikit-learn 中的函数 mutual_info_classiif()
信息增益在面对类别较少的离散数据时效果较好,但是面对取值较多的特征时效果会有 偏向性 。因为当特征的取值较多时,根据此特征划分得到的子集纯度有更大的可能性会更高(对比与取值较少的特征),因此划分之后的熵更低,由于划分前的熵是一定的,因此信息增益更大,因此信息增益比较偏向取值较多的特征。举一个极端的例子来说,如果一个特征为身份证号,当把每一个身份证号不同的样本都分到不同的子节点时,熵会变为0,意味着信息增益最大,从而该特征会被算法选择。但这种分法显然没有任何实际意义。
这种时候,信息增益率就起到了很重要的作用。
gR(D,A)=g(D,A)/HA(D)
HA(D) 又叫做特征A的内部信息,HA(D)其实像是一个衡量以特征AA的不同取值将数据集D分类后的不确定性的度量。如果特征A的取值越多,那么不确定性通常会更大,那么HA(D)的值也会越大,而1/HA(D)的值也会越小。这相当于是在信息增益的基础上乘上了一个惩罚系数。即 gR(D,A)=g(D,A)∗惩罚系数 。
在CART算法中,基尼不纯度表示一个随机选中的样本被分错类别的可能性,即这个样本被选中的概率乘以它被分错的概率。当一个节点中所有样本均为一种时(没有被分错的样本),基尼不纯度达到最低值0。
举例来说,如果有绿色和蓝色两类数据点,各占一半(蓝色50%,绿色50%)。那么我们随机分类,有以下四种情况:
-分为蓝色,但实际上是绿色(❌),概率25%
-分为蓝色,实际上也是蓝色(✔️),概率25%
-分为绿色,实际上也是绿色(✔️),概率25%
-分为绿色,但实际上是蓝色(❌),概率25%
那么将任意一个数据点分错的概率为25%+25% = 50%。基尼不纯度为0.5。
在特征选择中,我们可以选择加入后使数据不纯度减少最多的特征。
噪音数据简单来说就是会对模型造成误导的数据。分为类别噪声( class noise 或 label noise )和 变量噪声( attribute noise )。类别噪声指的的是被错误标记的错误数据,比如两个相同的样本具有不同的标签等情况。变量噪声指的是有问题的变量,比如缺失值、异常值和无关值等。
决策树其实是一种图结构,由节点和边构成。
-根节点:只有出边没有入边。包含样本全集,表示一个对样本最初的判断。
-内部节点:一个入边多个出边。表示一个特征或是属性。每个内部节点都是一个判断条件,包含数据集中从根节点到该节点所有满足条件的数据的集合。
-叶节点:一个入边无出边。表示一个类,对应于决策结果。
决策树的生成主要分为三个步骤:
1. 节点的分裂 :当一个节点不够纯(单一分类占比不够大或者说信息熵较大)时,则选择将这一节点进行分裂。
2. 决策边界的确定 :选择正确的决策边界( Decision Boundary ),使分出的节点尽量纯,信息增益(熵减少的值)尽可能大。
3. 重复及停止生长 :重复1,2步骤,直到纯度为0或树达到最大深度。为避免过拟合,决策树算法一般需要制定树分裂的最大深度。到达这一深度后,即使熵不等于0,树也不会继续进行分裂。
下面以超级知名的鸢尾花数据集举例来说明。
这个数据集含有四个特征:花瓣的长度( petal length )、花瓣的宽度( petal width )、花萼的长度( sepal length )和花萼的宽度( sepal width )。预测目标是鸢尾花的种类 iris setosa, iris versicolor 和 iris virginica 。
建立决策树模型的目标是根据特征尽可能正确地将样本划分到三个不同的“阵营”中。
根结点的选择基于全部数据集,使用了贪婪算法:遍历所有的特征,选择可以使信息熵降到最低、基尼不纯度最低的特征。
如上图,根节点的决策边界为' petal width = 0.8cm '。那么这个决策边界是怎么决定的呢?
-遍历所有可能的决策边界(需要注意的是,所有可能的决策边界代表的是该子集中该特征所有的值,不是以固定增幅遍历一个区间内的所有值!那样很没有必要的~)
-计算新建的两个子集的基尼不纯度。
-选择可以使新的子集达到最小基尼不纯度的分割阈值。这个“最小”可以指两个子集的基尼不纯度的和或平均值。
ID3是最早提出的决策树算法。ID3算法的核心是在决策树各个节点上根据 信息增益 来选择进行划分的特征,然后递归地构建决策树。
- 缺点 :
(1)没有剪枝
(2)只能用于处理离散特征
(3)采用信息增益作为选择最优划分特征的标准,然而信息增益会偏向那些取值较多的特征(例如,如果存在唯一标识属性身份证号,则ID3会选择它作为分裂属性,这样虽然使得划分充分纯净,但这种划分对分类几乎毫无用处。)
C4.5 与ID3相似,但对ID3进行了改进:
-引入“悲观剪枝”策略进行后剪枝
-信息增益率作为划分标准
-将连续特征离散化,假设 n 个样本的连续特征 A 有 m 个取值,C4.5 将其排序并取相邻两样本值的平均数共 m-1 个划分点,分别计算以该划分点作为二元分类点时的信息增益,并选择信息增益最大的点作为该连续特征的二元离散分类点;
-可以处理缺失值
对于缺失值的处理可以分为两个子问题:
(1)在特征值缺失的情况下进行划分特征的选择?(即如何计算特征的信息增益率)
C4.5 中对于具有缺失值特征,用没有缺失的样本子集所占比重来折算;
(2)选定该划分特征,对于缺失该特征值的样本如何处理?(即到底把这个样本划分到哪个结点里)
C4.5 的做法是将样本同时划分到所有子节点,不过要调整样本的权重值,其实也就是以不同概率划分到不同节点中。
(1)剪枝策略可以再优化;
(2)C4.5 用的是多叉树,用二叉树效率更高;
(3)C4.5 只能用于分类;
(4)C4.5 使用的熵模型拥有大量耗时的对数运算,连续值还有排序运算;
(5)C4.5 在构造树的过程中,对数值属性值需要按照其大小进行排序,从中选择一个分割点,所以只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时,程序无法运行。
可以用于分类,也可以用于回归问题。CART 算法使用了基尼系数取代了信息熵模型,计算复杂度更低。
CART 包含的基本过程有 分裂,剪枝和树选择 。
分裂 :分裂过程是一个二叉递归划分过程,其输入和预测特征既可以是连续型的也可以是离散型的,CART 没有停止准则,会一直生长下去;
剪枝 :采用“代价复杂度”剪枝,从最大树开始,每次选择训练数据熵对整体性能贡献最小的那个分裂节点作为下一个剪枝对象,直到只剩下根节点。CART 会产生一系列嵌套的剪枝树,需要从中选出一颗最优的决策树;
树选择 :用单独的测试集评估每棵剪枝树的预测性能(也可以用交叉验证)。
(1)C4.5 为多叉树,运算速度慢,CART 为二叉树,运算速度快;
(2)C4.5 只能分类,CART 既可以分类也可以回归;
(3)CART 使用 Gini 系数作为变量的不纯度量,减少了大量的对数运算;
(4)CART 采用代理测试来估计缺失值,而 C4.5 以不同概率划分到不同节点中;
(5)CART 采用“基于代价复杂度剪枝”方法进行剪枝,而 C4.5 采用悲观剪枝方法。
(1)决策树易于理解和解释,可以可视化分析,容易提取出规则
(2)可以同时处理分类型和数值型数据
(3)可以处理缺失值
(4)运行速度比较快(使用Gini的快于使用信息熵,因为信息熵算法有log)
(1)容易发生过拟合(集成算法如随机森林可以很大程度上减少过拟合)
(2)容易忽略数据集中属性的相互关联;
(3)对于那些各类别样本数量不一致的数据,在决策树中,进行属性划分时,不同的判定准则会带来不同的属性选择倾向。
写在后面:这个专辑主要是本小白在机器学习算法学习过程中的一些总结笔记和心得,如有不对之处还请各位大神多多指正!(关于决策树的剪枝还有很多没有搞懂,之后弄明白了会再单独出一篇总结哒)
参考资料链接:
1.
2.
3.
4.
5.
6.
7.
8.
我们在微信上24小时期待你的声音
解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流