Tải bản đầy đủ

Mô hình cây phân loại sử dụng CARET

Mô hình cây phân loại sử dụng CARET
Bs. Lê Ngọc Khả Nhi


1

1. Giới thiệu
caret là viết tắt của Classification And REgression Training
Đây là một công cụ đặc biệt với 2 ứng dụng chính: Mô hình dự báo và Machine learning (Máy học)

caret giống như một cỗ máy lớn tích hợp bên trong nó hàng trăm cơ phận nhỏ, chính là những
package chuyên dụng cho mô hình hồi quy và phân loại.
Caret hỗ trợ tới 217 kiểu mô hình khác nhau, bao gồm
92 mô hình phân loại (bao gồm logistic, naive Bayes, kNN, decision tree…. )
52 mô hình hồi quy (bao gồm robust linear, quantile, neural network, fuzzy, …)
73 mô hình lưỡng dụng (bao gồm glm, gam, random forest, pls, vector machine…)
Caret cung cấp một giao thức chung, tổng quát cho toàn bộ các mô hình này
Caret cho phép thực hiện: Huấn luyện (training), Tinh chỉnh (tuning), Kiểm định (testing) các kiểu
mô hình mà nó hỗ trợ
Ngoài ra caret còn cung cấp một số hàm đặc biệt hỗ trợ việc dựng mô hình dự báo, ví dụ chuẩn bị ,
thăm dò và xử lý số liệu trước khi dựng mô hình.



1

2. Mục tiêu chung của series caret
Các bài hướng dẫn trong series CARET sẽ giúp các bạn thực hiện được:
1. Các thao tác cơ bản trong caret
2. Quy trình huấn luyện (thí dụ: kiểm chứng chéo)
3. Kiểm định phâm chất mô hình, hoặc so sánh
Cho 1 kiểu mô hình bất kì nào đó thuộc 1 trong 2 dạng: hồi quy và phân loại
Những gì serie này sẽ không đề cập:
1. Lý thuyết và nguyên tắc của mỗi loại mô hình
2. Cách diễn giải mô hình
Series này dành cho các bạn đã có kiến thức và kỹ năng nhất định về mô hình dự báo
và/hoặc mô hình phân loại; nhưng chưa sử dụng qua caret.
Ghi chú: Mô hình trong bài chỉ có mục đích minh họa.


1

3. Mục tiêu riêng của bài
Bài này đề cập về 1 dạng Mô hình CÂY (CART) dựa theo phương pháp của Breiman, Friedman,
Olshen and Stone (1984).
Package được caret triệu hồi là rpart của tác giả Terry Therneau (2015).
CART là một kiểu mô hình cây lưỡng dụng (cả hồi quy và phân loại),
nhưng bài này chỉ áp dụng vào mục đích phân loại 1 biến kết quả nhị phân.
Nội dung của bài bao gồm:
1.
2.
3.
4.

Chuẩn bị dữ liệu
Huấn luyện mô hình bằng kiểm chứng chéo
Trích xuất nội dung mô hình
Kiểm định khả năng phân loại của mô hình

Mục tiêu quan trọng nhất của bài là:
tạo cảm hứng và sự tò mò cho tất cả các bạn chưa biết đến package caret.


2

1. Lộ trình tổng quát cho kiểm chứng chéo / caret

1
Chuẩn bị số liệu
Pha trộn và phân chia

2

3

Thiết lập chế độ
huấn luyện và tinh
chỉnh

4
Huấn luyện và tinh chỉnh
mô hình
Hàm train

Hàm trainControl
Hàm expand.grid

5
Kiểm định mô hình
Hàm predict,
confusionMatrix
multiClassSummary


2

2. Sơ đồ Kiểm chứng chéo
Mẫu nguyên thủy (N)

20%

80%
Phân dùng để huấn luyện (Train)

1 Phân chia
2 Huấn luyện

Phần dùng để kiểm định (Test)

TRAIN

TEST

Phân loại thực tế
(Xác suất = 1)

Kiểm chứng chéo lặp lại (k=5,n=10)

Confusion matrix
Kappa coefficient (tương hợp)
Mc-Nemar test
Sensitivity (độ nhạy)
Specificity (độ đặc hiệu)
Accuracy (độ chính xác chung)


Phân chia ngẫu nhiên mẫu Huấn luyện thành 5 khối bằng nhau

4
3

4 khối để dựng mô hình

Kiểm định mô hình
Trên mẫu « Test »

1 khối kiểm chứng

Tinh
chỉnh

Kiểm chứng

Log-Loss (sai biệt dự báo)
ROC

Mô hình

Dự báo xác suất
Lặp lại quy trình này 10 lần, ngẫu nhiên

Phân loại dự báo

Mô hình sau cùng
Kiểm tra ROC, Kappa, độ chính xác

Kết luận

Khả năng phân loại
của mô hình


2

3. Chuẩn bị số liệu

data=read.csv("http://vincentarelbundock.github.io/Rdatasets/csv/MASS/biopsy.csv")
data=data[,c(3:12)]
V1: clump thickness.
V2: uniformity of cell size.
V3 : uniformity of cell shape.
V4: marginal adhesion.
V5: single epithelial cell size.
V6:bare nuclei (16 values are missing).
V7:bland chromatin.
V8:normal nucleoli.
V9:mitoses.
class: "benign" or "malignant".

Dataset biopsy bao gồm :
Class là biến kết quả: phân loại khối u vú: lành tính hay ác tính
V1 tới V9 là thang điểm tế bào học trên mẫu sinh thiết theo 9 tiêu chí, đều là
biến kiểu số, thứ hạng
Mục tiêu của chúng ta là xây dựng một mô hình CART cho phép phân loại khối u
vú dựa vào giá trị của một hay nhiều tiêu chí tế bào học


2

3. Chuẩn bị số liệu

splitdata=function(dataframe, seed=NULL,ratio=NULL) {
if (!is.null(seed)) set.seed(seed)
dataframe2<-dataframe[sample(1:nrow(data),nrow(data),replace=F),]
index <- 1:nrow(dataframe2)
trainid <- sample(index, trunc(length(index)*ratio))
trainsubset <- dataframe2[trainid, ]
testsubset <- dataframe2[-trainid, ]
list(training=trainsubset,testing=testsubset)
}
set.seed(123)
split=splitdata(data, seed=123, 0.80)
train=split$training
test=split$testing
train=na.omit(train)
test=na.omit(test)

Tạo hàm splitdata với công dụng:
1. Trộn ngẫu nhiên dataset nguyên thủy
2. Cắt ngẫu nhiên dataset này thành 2 subset: Train
và Test

Áp dụng hàm này để tạo 2 subset Train và Test


2

3. Chuẩn bị số liệu

library(ggplot2)
ggplot(train,aes(class,..count..))+geom_bar(aes(fill=class),position = "identity")
ggplot(test,aes(class,..count..))+geom_bar(aes(fill=class),position = "identity")
ggplot(data,aes(class,..count..))+geom_bar(aes(fill=class),position = "identity")

Mẫu gốc

Mẫu TRAIN

Mẫu TEST


3

1. Huấn luyện mô hình
# Thiết lập chế độ tinh chỉnh
Grid=expand.grid(maxdepth=c(1,2,3,4,5,6,7))
# Thiết lập chế độ huấn luyện
Control=trainControl(method= "repeatedcv", number = 5, repeats = 10,classProbs = TRUE, summaryFunction = twoClassSummary)
set.seed(333)
cart1=train(class~.,data=train,method = "rpart2",trControl = Control,metric="ROC",tuneLength=20)
set.seed(333)
cart2=train(class~.,data=train,method = "rpart2",trControl = Control,metric="ROC",tuneGrid =Grid)
Ghi chú:
Nếu muốn dùng set.seed thì bạn phải đặt nó ngay trước khi thực hiện hàm train.
Model CART cho phép tinh chỉnh maxdepth tương ứng với method = rpart2
Thông số có thể tinh chỉnh của mô hình CART là maxdepth, việc tinh chỉnh có thể thực hiện tự động (tùy chỉnh tuneLength=20) hoặc thủ công
(bằng hàm expand.grid)
Chế độ huấn luyện phải được thiết lập trước khi tiến hành huấn luyện (hàm trainControl).
Phương pháp huấn luyện ở đây là K-fold cross validation, k=5 (5 block) và lặp lại 10 lần, classProbs=TRUE và summaryFunction=twoClassSummary
dùng cho mô hình phân loại . Tiêu chí kiểm định phổ biến của mô hình phân loại là ROC


3

2. Xuất kết quả huấn luyện

cart2
CART

cart1
559 samples
9 predictor
2 classes: 'benign', 'malignant'

CART
559 samples
9 predictor
2 classes: 'benign', 'malignant'
No pre-processing
Resampling: Cross-Validated (5 fold, repeated 10 times)
Summary of sample sizes: 435, 436, 434, 436, 435, 436, ...
Resampling results across tuning parameters:
maxdepth ROC
1
0.9054096
2
0.9431093
3
0.9471832
7
0.9478149

Sens
Spec
0.9495372 0.8612821
0.9557706 0.9261943
0.9580563 0.9283131
0.9566479 0.9293387

ROC was used to select the optimal model using the largest value.
The final value used for the model was maxdepth = 7.
Nhận xét: Tinh chỉnh thủ công cho ra kết quả phong phú hơn
Tuy nhiên nội dung của mô hình thì như nhau

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 10 times)
Summary of sample sizes: 435, 436, 434, 436, 435, 436, ...
Resampling results across tuning parameters:

maxdepth ROC
1
0.9054096
2
0.9431093
3
0.9471832
4
0.9478149
5
0.9478149
6
0.9478149
7
0.9478149

Sens
Spec
0.9495372 0.8612821
0.9557706 0.9261943
0.9580563 0.9283131
0.9566479 0.9293387
0.9566479 0.9293387
0.9566479 0.9293387
0.9566479 0.9293387

ROC was used to select the optimal model using the largest value.
The final value used for the model was maxdepth = 4.


1

2. Xuất kết quả huấn luyện
ggplot(cart1)

plot(cart1)

Mô hình tối ưu


3

2. Xuất kết quả huấn luyện
summary(cart2)

CP nsplit rel error
1 0.80104712 0 1.0000000
2 0.07853403 1 0.1989529
3 0.01000000 2 0.1204188
Variable importance
V2 V3 V6 V7 V4 V5 V1
22 18 17 15 14 14 1
Node number 1: 544 observations, complexity param=0.8010471
predicted class=benign expected loss=0.3511029 P(node) =1
class counts: 353 191
probabilities: 0.649 0.351
left son=2 (371 obs) right son=3 (173 obs)
Primary splits:
V2 < 3.5 to the left, improve=177.2612, (0 missing)
V3 < 3.5 to the left, improve=165.5581, (0 missing)
V6 < 2.5 to the left, improve=154.0274, (0 missing)
V7 < 3.5 to the left, improve=149.0006, (0 missing)
V5 < 2.5 to the left, improve=148.4639, (0 missing)
Surrogate splits:
V3 < 3.5 to the left, agree=0.936, adj=0.798, (0 split)
V7 < 3.5 to the left, agree=0.892, adj=0.659, (0 split)
V5 < 3.5 to the left, agree=0.884, adj=0.636, (0 split)
V6 < 3.5 to the left, agree=0.881, adj=0.624, (0 split)
V4 < 3.5 to the left, agree=0.877, adj=0.613, (0 split)

Node number 2: 371 observations, complexity param=0.07853403
predicted class=benign expected loss=0.0754717 P(node) =0.6819853
class counts: 343 28
probabilities: 0.925 0.075
left son=4 (354 obs) right son=5 (17 obs)
Primary splits:
V6 < 5.5 to the left, improve=26.70479, (0 missing)
V1 < 6.5 to the left, improve=21.37781, (0 missing)
V8 < 2.5 to the left, improve=19.01168, (0 missing)
V3 < 2.5 to the left, improve=15.40157, (0 missing)
V7 < 3.5 to the left, improve=14.89895, (0 missing)
Surrogate splits:
V1 < 7.5 to the left, agree=0.965, adj=0.235, (0 split)
V4 < 4.5 to the left, agree=0.965, adj=0.235, (0 split)
V3 < 4.5 to the left, agree=0.957, adj=0.059, (0 split)
V7 < 6.5 to the left, agree=0.957, adj=0.059, (0 split)
Node number 3: 173 observations
predicted class=malignant expected loss=0.05780347 P(node) =0.3180147
class counts: 10 163
probabilities: 0.058 0.942

Node number 4: 354 observations
predicted class=benign expected loss=0.03389831 P(node) =0.6507353
class counts: 342 12
probabilities: 0.966 0.034
Node number 5: 17 observations
predicted class=malignant expected loss=0.05882353 P(node) =0.03125
class counts: 1 16
probabilities: 0.059 0.941


3

2. Xuất kết quả huấn luyện

cart2$finalModel
n= 544
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 544 191 benign (0.64889706 0.35110294)
2) V2< 3.5 371 28 benign (0.92452830 0.07547170)
4) V6< 5.5 354 12 benign (0.96610169 0.03389831) *
5) V6>=5.5 17 1 malignant (0.05882353 0.94117647) *
3) V2>=3.5 173 10 malignant (0.05780347 0.94219653) *

library(partykit)
library(party)
cartp=as.party(cart2$finalModel)
cartp

Model formula:
.outcome ~ V1 + V2 + V3 + V4 + V5 + V6 + V7 + V8 + V9
Fitted party:
[1] root
| [2] V2 < 3.5
| | [3] V6 < 5.5: benign (n = 354, err = 3.4%)
| | [4] V6 >= 5.5: malignant (n = 17, err = 5.9%)
| [5] V2 >= 3.5: malignant (n = 173, err = 5.8%)
Number of inner nodes: 2
Number of terminal nodes: 3

Ghi chú:
Sau khi kích hoạt package party và party kit, bạn có thể vẽ mô hình cây,
như trong slide tiếp theo


3

2. Xuất kết quả huấn luyện

plot(cartp)


3

2. Xuất kết quả huấn luyện

plot(cartp,type="simple")


4

1. Kiểm định mô hình

confusionMatrix(cart2)
Cross-Validated (5 fold, repeated 10 times) Confusion Matrix
(entries are percentual average cell counts across resamples)
Reference
Prediction benign malignant
benign
62.1
2.5
malignant
2.8
32.6
Accuracy (average) : 0.9471
Ghi chú:
hàm confussionMatrix khi áp dụng cho chính đối tượng train của caret, nó sẽ khảo sát mức độ chính xác trung bình của
phân loại trong suốt quá trình kiểm định chéo.
Đây mới chỉ là bước đầu, tiếp theo chúng ta sẽ kiểm định mô hình tối ưu trên subset TEST độc lập


4

1. Kiểm định mô hình
test=na.omit(test)
test$pred=predict(cart2,newdata=test)
confusionMatrix(test$pred,test$class)

Confusion Matrix and Statistics
Reference
Prediction benign malignant
benign
89
3
malignant
2
45
Accuracy
95% CI
No Information Rate
P-Value [Acc > NIR]

:
:
:
:

0.964
(0.9181, 0.9882)
0.6547
<2e-16

Kappa : 0.9201
Mcnemar's Test P-Value : 1

Quy trình trên đây nhằm lượng giá khả năng phân loai của mô hình cây, trên subset
TEST.
Lưu ý rất quan trọng: caret mặc định level 1 của class là Dương tính. Cách xếp loại này
có thể mâu thuẫn với vấn đề thực tế. Thí dụ ở đây thực ra magnignant mới là positive.
Nên bạn phải diễn dịch đảo ngược giữa Sensitivity và Specificity…

Sensitivity
Specificity
Pos Pred Value
Neg Pred Value
Prevalence
Detection Rate
Detection Prevalence
Balanced Accuracy

:
:
:
:
:
:
:
:

0.9780
0.9375
0.9674
0.9574
0.6547
0.6403
0.6619
0.9578

'Positive' Class : benign


4

1. Kiểm định mô hình

testresult=predict(cart2, test, type = "prob")
testresult$obs=test$class
testresult$pred=predict(cart2,test)
multiClassSummary(testresult,lev=levels(testresult$obs))

logLoss
0.1553397
Specificity
0.9375000

ROC
0.9615385
Pos_Pred_Value
0.9673913

Accuracy
0.9640288
Neg_Pred_Value
0.9574468

Kappa
Sensitivity
0.9200506
0.9780220
Detection_Rate Balanced_Accuracy
0.6402878
0.9577610

Quy trình trên đây cho ra kết quả đầy đủ nhất
Phẩm chất mô hình thể hiện qua:

logLoss = sai biệt giữa xác suất dự báo và xác suất thực tế (=1). LogLoss càng thấp càng tốt
ROC= diện tích dưới đường cong ROC (tương quan giữa sensitivity và Specificity)
Những chỉ số còn lại thì dễ dàng suy ra từ cunfussion matrix. Bạn cũng có thể tính thủ công những giá trị này và 95%CI của
chúng dựa vào bảng 2x2


4

1. Kiểm định mô hình

library(pROC)
cartpred=predict(cart2,newdata = test , type="prob")
cartROC=roc(test$class,cartpred[,"malignant"], levels = rev(test$class))

plot(smooth(cartROCpos), col="blue", print.auc=TRUE,
auc.polygon=TRUE, grid=c(0.1,
0.2),auc.polygon.col="greenyellow",print.thres=TRUE, identity=F)
smooth.roc(roc = cartROCpos)
Data: cartpred[, "malignant"] in 48 controls (test$class 2) > 91 cases
(test$class 1).
Smoothing: binormal
Area under the curve: 0.993

Bước cuối cùng là vẽ đường cong ROC
Lưu ý: Đây là hình thức mở rộng của phương pháp ROC, lúc này nó áp
dụng cho cả 1 mô hình phân loại, chứ không phải chỉ cho 1 biến số định
lượng theo phân tích ROC truyền thống. Đường cong ROC thực chất chỉ là
một trò chơi về xác suất.


Hẹn gặp lại bạn
vào bài tiếp theo



Tài liệu bạn tìm kiếm đã sẵn sàng tải về

Tải bản đầy đủ ngay

×

×