決定木を描く
Text Update: 09/16, 2019 (JST)

rpart.plotパッケージを用いると決定木をカラフルに可視化することができますが、読みにくい印象があります。そこで、partykitパッケージを用いた可視化を紹介します。

Packages and Datasets

本ページではR version 3.6.1 (2019-07-05)の標準パッケージ以外に以下の追加パッケージを用いています。
 

Package Version Description
knitr 1.24 A General-Purpose Package for Dynamic Report Generation in R
partykit 1.2.5 A Toolkit for Recursive Partytioning
rpart 4.1.15 Recursive Partitioning and Regression Trees
rpart.plot 3.0.8 Plot ‘rpart’ Models: An Enhanced Version of ‘plot.rpart’
tidyverse 1.2.1 Easily Install and Load the ‘Tidyverse’

 
また、本ページでは以下のデータセットを用いています。
 

Dataset Package Version Description
TitanicSurvival carData 3.0.2 Survival of Passengers on the Titanic

 

決定木を作成する

可視化対象となるTitanicSurvivalデータセットは以下のようなデータです。

carData::TitanicSurvival
survived sex age passengerClass
Allen, Miss. Elisabeth Walton yes female 29 1st
Allison, Master. Hudson Trevor yes male 0.92 1st
Allison, Miss. Helen Loraine no female 2 1st
NA NA NA
Zakarian, Mr. Mapriededer no male 26.5 3rd
Zakarian, Mr. Ortin no male 27 3rd
Zimmerman, Mr. Leo no male 29 3rd

 
生存者(survived)の人数と比率は以下のようになっています。

carData::TitanicSurvival$survived %>% 
  table() %>% print() %>% 
  prop.table()
## .
##  no yes 
## 809 500
## .
##       no      yes 
## 0.618029 0.381971

 
survivedをキーに決定木を作成します。

dt <- carData::TitanicSurvival %>% 
  rpart::rpart(survived ~ ., data = .)

dt
## n= 1309 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 1309 500 no (0.6180290 0.3819710)  
##    2) sex=male 843 161 no (0.8090154 0.1909846)  
##      4) age>=9.5 800 136 no (0.8300000 0.1700000) *
##      5) age< 9.5 43  18 yes (0.4186047 0.5813953)  
##       10) passengerClass=3rd 29  11 no (0.6206897 0.3793103) *
##       11) passengerClass=1st,2nd 14   0 yes (0.0000000 1.0000000) *
##    3) sex=female 466 127 yes (0.2725322 0.7274678) *

 

決定木を可視化する

作成した決定木をpartykit関数を用いて可視化してみます。各ノードにノード番号が振られると共に最後のノードにはバーチャートで比率が表示されるので解釈しやすいと思います。。

dt %>% 
  partykit::as.party() %>% 
  plot()

 

参考)rpart.plotパッケージによる可視化

dt %>% 
  rpart.plot::rpart.plot(type = 5, extra = 101)

 
Enjoy!  

本blogに対するアドバイス、ご指摘等は データ分析勉強会 または GitHub まで。

CC BY-NC-SA 4.0 , Sampo Suzuki