R語言模擬:Cross Validation

歡迎關注天善智能,我們是專注於商業智能BI,人工智能AI,大數據分析與挖掘領域的垂直社區,學習,問答、求職一站式搞定!

對商業智能BI、大數據分析挖掘、機器學習,python,R等數據領域感興趣的同學加微信:tstoutiao,邀請你進入數據愛好者交流群,數據愛好者們都在這兒。

作者:量化小白一枚,上財研究生在讀,偏向數據分析與量化投資

前兩篇

R語言模擬:Bias Variance Trade-Off

R語言模擬:Bias Variance Decomposition

在理論推導和模擬的基礎上,對於誤差分析中的偏差方差進行了分析。本文在前文的基礎上,分析一種常用的估計預測誤差進而可以參數優化的方法:交叉驗證,並通過R語言進行模擬。

K-FOLD CV

交叉驗證是數據建模中一種常用方法,通過交叉驗證估計預測誤差並有效避免過擬合現象。簡要說明CV(CROSS VALIDATION

)的邏輯,最常用的是K-FOLD CV,以K = 5為例。

R語言模擬:Cross Validation

將整個樣本集分為K份,每次取其中一份作為Validation Set,剩餘四份為Trainset,用Trainset訓練模型,然後計算模型在Validation set上的誤差,循環k次得到k個誤差後求平均,作為預測誤差的估計量。

R語言模擬:Cross Validation

除此之外,比較常用的還有LOOCV,每次只留出一項做Validation,剩餘項做Trainset。

參數優化

對於含有參數的模型,可以分析模型在不同參數值下的CV的誤差,選取誤差最小的參數值。

R語言模擬:Cross Validation

誤區

ESL 7.10.2中提到應用CV的兩種方法,比如對於一個包含多個自變量的分類模型,建模中包括兩方面,一個是篩選出預測能力強的變量,一個是估計最佳的參數。因此有兩種使用CV的方法(以下內容摘自ESL 7.10.2)

1

1.Screen the predictors: find a subset of “good” predictors that show fairly strong (univariate) correlation with the class labels

2.Using just this subset of predictors, build a multivariate classifier.

3.Use cross-validation to estimate the unknown tuning parameters and to estimate the prediction error of the final model.

2

1. Divide the samples into K CV folds(group) at random.

2. For each fold k = 1,2,...,K

(a) Find a subset of "good" predictors that show fairly strong correlation with the class labels, using all of the samples except those in fold k.

(b) Using just this subset of predictors,build a multivariate classifier,using all of the samples except those in fold K.

(c) Use the classifier to predict rhe class labels for the samples in fold k.

簡單來說,第一種方法是先使用全樣本篩出預測能力強的變量,僅使用這部分變量進行建模,然後用這部分變量建立的模型通過CV優化參數;第二種方法是對全樣本CV後,在CV過程中進行篩選變量,然後用篩選出來的變量優化參數,這樣CV中每個循環裡得到的預測能力強的變量有可能是不一樣的。

我們經常使用的是第一種方法,但事實上第一種方法是錯誤的,直接通過全樣本得到的預測能力強的變量,再進行CV,計算出來的誤差一定是偏低的。因為篩選過程使用了全樣本,再使用CV,用其中K-1份預測1份,使用的預測能力強的變量中實際上包含了這一份的信息,因此存在前視誤差。

(The problem is that the predictors have an unfair advantage, as they were chosen in step (1) on the basis of all of the samples. Leaving samples out after the variables have been selected does not correctly mimic the application of the classifier to a completely independent test set, since these predictors “have already seen” the left out samples.)

作者給出一個可以驗證的例子:

考慮包含50個樣本的二分類集合,兩種類別的樣本數相同,5000個服從標準正態分佈的自變量且均與類別標籤獨立。這種情況下,理論上使用任何模型得到的誤差都應該是50%

如果此時我們使用上述方法1找出100個與類別標籤相關性最強的變量,然後僅對這100個變量使用KNN算法,並令K=1,CV得到的誤差僅有3%,遠遠低於真實誤差50%。

作者使用了5-FOLD CV並且計算了CV中每次Validation set 中10個樣本的自變量與類別的相關係數,發現此時相關係數平均值為0.28,遠大於0。

R語言模擬:Cross Validation

而使用第二種方法計算的相關係數遠低於第一種方法。

模擬

我們通過R語言模擬給出一個通過CV估計最優參數的例子,例子為

上一篇

右下圖的延伸。

樣本: 80個樣本,20個自變量,自變量均服從[0,1]均勻分佈,因變量定義為:

Y = ifelse(X1+X2+...+X10>5,1,0)

使用Best Subset Regression建模(就是在線性迴歸的基礎上加一步篩選最優自變量組合),模型唯一待定參數為Subset Size p,即最優的自變量個數。

通過10-Fold CV計算不同參數p下的預測誤差並給出最優的p值,Best Subset Regression可以通過函數regsubsets實現,最終結果如下:

R語言模擬:Cross Validation

對比教材中的結果

R語言模擬:Cross Validation

其中,紅色線為真實的預測誤差,藍色線為10-FOLD CV計算出的誤差,bar為1標準誤。可以得出以下結論:

  1. CV估計的誤差與實際預測誤差的數值大小、趨勢變化基本相同,並且真實預測誤差基本都在CV1標準誤以內。
  2. p = 10附近誤差最小,表明參數p應設定為10,這也與樣本因變量定義相符。

可以直接運行的R代碼

setwd('xxxx')
library(leaps)
library(DAAG)
library(caret)
lm.BestSubSet lm.sub summary(lm.sub)
coef_lm strings_coef_lm x formulas return(formulas)
}
# ==================== get error ===============================
getError set.seed(seeds)
testset
Allfx_hat Ally Allfx
# 模擬 num次
for (i in 1:num){
trainset fx_train trainset[,6] +trainset[,7] +trainset[,8] +trainset[,9] +trainset[,10]>5,1,0)
trainset[,21] fx_test testset[,6] +testset[,7] +testset[,8] +testset[,9] +testset[,10]>5,1,0)

testset[,21]

# best subset
lm.sub probs

Allfx_hat[,i] Ally[,i] Allfx[,i]
}
# 計算方差、偏差等

# irreducible
irreducible SquareBais Variance
# 迴歸或分類兩種情況
if (modeltype == 'reg'){
PredictError }else{
PredictError =0.5,1,0)!=Allfx)
}
result return(result)
}
# -------------------- classification -------------------
modeltype num n_test seeds all_p result for (i in all_p){
result }

# ==================== CV =========================
fun_cv set.seed(seed_num)
folds misrate
for (i in(1:kfold)){

train_cv test_cv
model result
# result
misrate[i] =0.5,1,0) != test_cv$V21)
}
sderror misrate result return(result)
}
plot_error len arrows(x0 = x, y0 = y, x1 = x, y1 = y - sd, col = col, angle = 90, length = len)
arrows(x0 = x, y0 = y, x1 = x, y1 = y + sd, col = col, angle = 90, length = len)
}
# ================================ draw ==============================
# seed = 9,10, 92,65,114, 10912
seed_num = 9
trainset trainset[,21] trainset[,6] +trainset[,7] +trainset[,8] +trainset[,9] +trainset[,10]>5,1,0)
resultcv for (p in 2:20){
resultcv }
png(file = "Cross Validation_large_testset.png")
plot(result$k,result$PredictError,type = 'o',col = 'red',
xlim = c(0,20),ylim = c(0,0.6),xlab = '', ylab ='', lwd = 2)
par(new = T)
plot(resultcv$p,resultcv$misrate,type='o',lwd=2,col='blue',ylim = c(0,0.6),xlim = c(0,20),
xlab = 'Subset Size p', ylab = 'Misclassification Error')
plot_error(resultcv$p,resultcv$misrate,resultcv$sderror,col = 'blue',len = 1)
dev.off()

參考文獻

Ruppert D. The Elements of Statistical Learning: Data Mining, Inference, and Prediction[J]. Journal of the Royal Statistical Society, 2010, 99(466):567-567.

R語言模擬:Cross Validation

R語言模擬:Cross Validation

R語言模擬:Cross Validation

回覆 爬蟲 爬蟲三大案例實戰

回覆 Python 1小時破冰入門

回覆 數據挖掘 R語言入門及數據挖掘

回覆 人工智能 三個月入門人工智能

回覆 數據分析師 數據分析師成長之路

回覆 機器學習 機器學習的商業應用

回覆 數據科學 數據科學實戰

回覆 常用算法 常用數據挖掘算法


分享到:


相關文章: