본문 바로가기
Study/혼자 공부하는 머신러닝

혼자 공부하는 머신러닝 + 딥러닝 - 결정 트리

by Wanooky 2022. 3. 14.

먼저 저번에 배운 로지스틱 회귀로 와인을 분류해보자.

 

import pandas as pd
wine = pd.read_csv('https://bit.ly/wine_csv_data

wine.info()



wine.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6497 entries, 0 to 6496
Data columns (total 4 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   alcohol  6497 non-null   float64
 1   sugar    6497 non-null   float64
 2   pH       6497 non-null   float64
 3   class    6497 non-null   float64
dtypes: float64(4)
memory usage: 203.2 KB

alcohol과 sugar, pH를 데이터 셋, class를 타겟 셋으로 설정하자.

data = wine[['alcohol','sugar','pH']].to_numpy()
target = wine['class'].to_numpy()

from sklearn.model_selection import train_test_split
train_input, test_input, train_target, test_target = train_test_split(data, target, 
													test_size = 0.2, random_state = 42)
                                                    

from sklearn.preprocessing import StandardScaler
ss = StandardScaler()
ss.fit(train_input)
train_scaled = ss.transform(train_input)
test_scaled = ss.transform(test_input)

from sklearn.linear_model import LinearRegression
lr = LogisticRegression()
lr.fit(train_scaled, train_target)
print(lr.score(train_scaled, train_target)
print(lr.score(test_scaled, test_target)

0.7808350971714451
0.7776923076923077

로지스틱 회귀를 이용하여 수행하면 결과가 썩 좋지만은 않다.

그렇다면 결정 트리를 사용해서 분류를 해보자.

 

사이킷런에서 DecisionTreeClassifier 클래스를 사용하여 결정 트리 모델을 훈련해보자.

from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(random_state = 42)
dt.fit(train_scaled, train_target)
print(dt.score(train_scaled, train_target))
print(dt.score(test_scaled, test_target))


0.996921300750433
0.8592307692307692

성능이 그리 나쁘지만은 않으나, 과대적합된 모델이라 볼 수 있다.

이 모델을 그림으로도 그릴 수 있는데, 이는 plot_tree()라는 매서드를 사용하여 그릴 수 있다.

 

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize = (10, 7))
plot_tree(dt)
plt.show()

결정 트리는 위에서부터 아래 방향으로 커가는데, 맨 위의 노드를 루트 노드, 맨 아래 노드를 리프 노드라고 한다.

 

위의 모델은 보다시피 길이가 너무 길다. 그러면 길이를 확장해서 출력할 수 있는데, 이는 max_depth 매개변수로 지정할 수 있다. 만약 1을 저장한다면 루트 노드 다음 노드까지만 확장한다.

filled 매개변수에는 클래스에 맞게 노드의 색을 칠할 수 있으며, True 와 False 값으로 저장한다.

feature_names 매개변수에는 특성의 이름을 전달 할 수 있다.

 

plt.figure(figsize=(10,7))
plot_tree(dt, max_depth=1, filled=True, feature_names=['alcohol', 'sugar', 'pH'])
plt.show()

읽는 순서는 다음과 같다.

테스트 조건

불순도

총 샘플 수

클래스별 샘플 수

 

그리고 왼쪽으로 가면 Yes, 오른쪽으로 가면 No 이다.

그리고 filled 매개변수로 인해 어떤 클래스의 비율이 높아지면 색이 진해진다.

 

결정 트리에서 예측하는 방법은 리프 노드에서 가장 많은 클래스가 예측 클래스가 된다.

 

 

불순도는 지니계수를 이용해서 측정하는데, 이는 criterion 매개변수에 저장되어있으며 기본 값이다.

지니 계수는 1에서 각각 클래스의 확률의 제곱을 해준 값을 빼준다.

만약 두 클래스의 비율이 같아서 (1/2) 0.5가 되는 경우에 지니 계수는 최악이 된다.

만약 노드에 하나의 클래스만 있다면 지니계수는 0이 되는데 이 노드를 순수 노드라고도 부른다.

 

결정 트리 모델은 부모 노드와 자식 노드의 불순도 차이가 가능한 크도록 트리를 성장시킨다.

 

자식 노드의 지니 계수를 샘플 개수에 비례하여 모두 더한다. 그리고 부모 노드의 지니 계수에서 빼주면 된다.

부모와 자식 노드 사이의 불순도 차이를 정보이득이라고 부른다. 

 

또 다른 불순도로는 엔트로피 계수가 있다. criterion='entropy'로 지정하면 엔트로피 계수를 사용할 수 있다.

엔트로피 계수는 밑이 2인 로그를 사용하여 곱한다.

 

결국 결정 트리 알고리즘은 불순도 기준을 사용해 정보 이득이 최대가 되도록 노드를 분할하고, 노드를 순수하게 나눌수록 정보 이득이 커진다.

 

결정 트리 알고리즘은 가지치기를 해주어야 한다. 가지치기를 하는 방법은 트리의 최대 깊이를 지정하는 것이다.

dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(train_scaled, train_target)
print(dt.score(train_scaled, train_target))
print(dt.score(test_scaled, test_target))

0.8454877814123533
0.8415384615384616

plt.figure(figsize=(10, 7))
plot_tree(dt, filled=True, feature_names=['alcohol','sugar','pH'])
plt.show()

트레이닝 셋의 성능은 낮아졌지만 테스트 셋의 성능은 거의 그대로이다.

 

결정 트리에서는 표준화 처리를 굳이 할 필요가 없다.

표준화 하기 전 자료로 결정 트리 알고리즘에 대입해도 결과는 동일하다.

dt = DecisionTreeClassifier(max_depth=3, random_state=42)
dt.fit(train_input, train_target)
print(dt.score(train_input, train_target))
print(dt.score(test_input, test_target))

0.8454877814123533
0.8415384615384616

plt.figure(figsize=(10,7))
plot_tree(dt, filled=True, feature_names = ['alcohol','sugar','pH'])
plt.show()

그리고 어떤 특성이 분류에 가장 유용한지 나타내는 특성 중요도도 나타낼 수 있는데, 이는 feature_importances_

속성에 저장되어 있다.

 

print(dt.feature_importances_)

[0.12345626 0.86862934 0.0079144 ]