Data Analysis

의사결정나무모델 실습 - 타이타닉 생존 여부 예측

김심슨 2025. 6. 25. 08:00

https://www.kaggle.com/competitions/titanic/data

 

Titanic

Kaggle profile for Titanic

www.kaggle.com

1. 목적, 역할 

타이타닉 탑승객 데이터를 활용 

의사결정나무 분류 모델을 만들어 탑승객의 생존 여부 예측 

모델 학습 -> 예측 -> 성능평가 -> 시각화의 전 과정 이해해보기 

 

2. 필수 라이브러리 설명 

* pandas : 데이터(csv 파일)를 읽고 관리하기 위해 사용 (DataFrame 활용)

* numpy 데이터 전처리 과정에서 결측치를 처리, 수치계산 간편, 수학적 연산 제공 

* matplotlib : 데이터를 시각적으로 확인하여 모델의 성능과 예측결과를 직관적 이해 

* sklearn : 머신러닝 모델 구현, 데이터 분할, 예측, 성능평가 등 머신러닝의 모든 과정을 쉽게 할 수 있도록 지원하는 핵심 라이브러리 

 

3. 주요 흐름 

1. 데이터 로딩 

2. 전처리 (결측치 처리, 숫자형 변환 등)

3. 데이터 셋 분할 

4. 의사결정나무 모델 생성 및 학습 

5. 예측 및 성능 평가

6. 모델 결과 시각화 (트리 그래프 및 혼동 행렬)

 

4. 코드 한 줄씩 뜯어먹기 

#데이터 로드, 데이터 확인 
df = pd.read_csv('assets/train.csv')

# 3. 데이터 확인
df.info()
print(df.head())

-- 출력 --
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
   PassengerId  Survived  Pclass  ...     Fare Cabin  Embarked
0            1         0       3  ...   7.2500   NaN         S
1            2         1       1  ...  71.2833   C85         C
2            3         1       3  ...   7.9250   NaN         S
3            4         1       1  ...  53.1000  C123         S
4            5         0       3  ...   8.0500   NaN         S

[5 rows x 12 columns]

 

#필요한 열 선택 + 결측치 처리 (데이터 전처리)
df = df[['Survived', 'Pclass', 'Sex', 'Age', 'Fare']] # 필요한 변수만 선택
df['Age'] = df['Age'].fillna(df['Age'].median()) # Age 없는 경우 중간값으로
df['Sex'] = df['Sex'].map({'male': 0, 'female': 1}) # 성별을 숫자로

 

# 입력 (x)과 출력(y) 나누기 
X = df.drop('Survived', axis=1)
y = df['Survived'] # 출력값 : 생존여부 (0 또는 1)

 

# 훈련 / 테스트 데이터 분리 
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size = 0.2, random_state = 42
)
# 모델 생성 및 학습 

clf = DecisionTreeClassifier(max_depth = 3, random_state = 42)
clf.fit(x_train, y_train)

- 과적합 방지를 위해서 최대 깊이 3개로 제한 

- 훈련 데이터(train)를 통해 모델 학습 진행 

 

# 예측 및 정확도 평가 

y_pred = clf.predict(X_test) # 예측값
accuracy = accuracy_score(y_test, y_pred) # 정확도
print(f'정확도 : {accuracy:.2f}') #0.80

 

# 의사결정나무 모델 시각화 (구조 확인)

plt.rc('font', family='Malgun Gothic')
plt.figure(figsize=(8, 5))
plot_tree(
    clf,
    feature_names = X.columns, # 변수명
    class_names = ['Died', 'Survived'], # 클래스명 2개
    filled = True # 노드 색상으로 클래스 구분
)
plt.title('의사결정나무 시각화')
plt.show()

< 의사결정나무 시각화 해석 >

1. 맨 위 (루트 노드)

  • **성별(Sex)**을 기준으로 분리
  • 총 712명 중, 사망(Died) 444명, 생존(Survived) 268명
  • Gini 계수 0.469 → 다소 높은 불순도

2. 왼쪽 (여성으로 추정: Sex ≤ 0.5 → True)

  • **나이(Age ≤ 6.5세)**로 추가 구분
  • 어린 나이(6.5세 이하)는 생존율이 높음 (대부분 생존)
  • 나이가 많을수록 사망이 증가
  • 하위 노드에서 다시 **객실등급(Pclass)**으로 구분
    • 객실등급이 좋을수록(Pclass 값이 낮을수록) 생존 확률이 증가

3. 오른쪽 (남성으로 추정: Sex ≤ 0.5 → False)

  • 남성은 기본적으로 사망률이 더 높음.
  • **객실등급(Pclass ≤ 2.5)**로 구분됨
    • 상위 클래스(1,2등급)의 남성은 생존 가능성이 높아짐
    • 다시 나이로 구분했을 때 아주 어린 아이는 대부분 생존
  • 중하위 클래스(3등급)의 경우, 생존 가능성이 떨어지며, 운임료(Fare)가 23.35 이하로 낮으면 생존 가능성 더욱 낮아짐.

요약 해석:

  • 성별이 가장 결정적: 여성이 생존율 높음.
  • 여성 중에서도 나이가 어릴수록 생존율이 높고, 객실 등급이 높을수록 생존율 증가.
  • 남성은 기본적으로 생존율 낮지만, 어린아이와 상위 클래스일 때 생존율이 올라감.

+) 지니계수?

노드의 불순도를 측정하는 값 

지니계수가 있는 시각화는 노드의 신뢰성(클래스의 순도)을 판단하는 데 유용 

값이 낮을수록 노드 내 데이터의 순도가 높다는 의미. 예측 정확성이 높아지는 경향이 있음 (하나의 클래스에 집중됨) 


#모델 성능 지표 출력 (정밀도, 재현율, f1-score)
print(classification_report(y_test, y_pred, target_names=['Died', 'Survived']))

--- 출력 ---
              precision    recall  f1-score   support (각 클래스의 실제 데이터 수입)

        Died       0.80      0.88      0.84       105
    Survived       0.80      0.69      0.74        74

    accuracy                           0.80       179
   macro avg       0.80      0.78      0.79       179
weighted avg       0.80      0.80      0.80       179

=> classification_report 함수 호출로 실제 값 (y_test)과 예측값 (y_pred)을 받아 성능 평가표를 출력

=> target_names 옵션을 사용하여 클래스 레이블(0과 1) 대신 이해하기 쉽게 'Died', 'Survived'  로 표현 

 

< 해석 > 

* 정밀도 (Precision)

- 사망으로 예측한 사람들 중 80%가 실제로 사망

- 생존으로 예측한 사람 중 80%가 실제로 생존

* 재현율 (Recall)

- 실제로 사망한 사람 중 88%를 정확히 찾아냄

- 실제 생존한 사람 중 69%만을 정확히 찾아냈습니다. 즉, 생존한 사람 중 31%는 사망으로 잘못 예측

* F1-score

사망 : 0.84 (안정적) 

생존 : 0.74 (개선의 여지 있음 

 

 

< 혼동 행렬 >