機器學習模型從regression到 decision tree, SVM 再到 gradient boosting, Neural network,模型能力越來越強但也越來越難解,只能當作一個黑盒子使用。
如何解釋機器學習模型的預測結果一直是一個值得探討的問題,工作上也常被老闆challenge為何做出這樣的預測,要是預測錯了怎麼解釋?
大概在兩年前開始聽過LIME(Local Interpretable Model-agnostic Explanations)這個概念,最近有緣再碰到,便利用機會把它看過一遍並且實做看看。
Why we need explanation?
可能有些人認為,我會一些feature importance的方法像 stepwise, random forest, lasso等,不就可以知道模型主要是根據那些特徵判斷的嗎? 其實不然,feature importance只能告訴你再建置這個模型時,整體來說哪個自變項 x 對依變項 y 比較重要,但無法解釋每一個單一樣本的預測。
比如說我的模型現在是個醫生,根據病人提供的一些症狀來判斷是否有感冒,而我在建置這個模型時可能發現最重要的特徵是發燒(若發燒則很感冒的可能性很高),但今天的樣本沒發燒但有喉嚨痛+流鼻水,模型可能依然預測為感冒。而 LIME 這篇paper就是想探討這種 by-instance 的 feature importance。
網路上還有一個 promo video 來介紹LIME 做的蠻可愛的<3
LIME(Local Interpretable Model-agnostic Explanations):
顧名思義就是一種local的解釋器,且和原模型無關的一套方法。
以下這張圖大概就是LIME的精神:
這是一個複雜分類器,紅色和藍色表示兩個類別,分類器會把落在紅色區域的點分類為 ‘+’,落在藍色的分類為 ‘o’
今天有一個點被分成’+’(圖上粗體+),LIME 用一個簡單線性模型來解釋為何這個樣本被預測成’+’。
要建置一個線性模型我們要有 y 和 x :
x: 將這個粗體’+’ 進行一些”擾動”得到多個粗體’+’附近的樣本點
y: 將那些經過擾動得到的樣本點經過原本複雜模型運算後得到的機率值
基於與原始’+’的距離給這些擾動的樣本點權重然後就能來fit線性模型了!
這個線性模型如果拿來分類藍色和紅色兩類可能很差,但若只看’粗+’這個樣本點附近,其實和原複雜模型是蠻接近的。
值得注意的是: 這邊只把原本複雜模型當作一個黑盒子,不管你是用什麼演算法,我只要你output的機率值就可以!
那麼接下來的問題就是在要如何 ”擾動” 想解釋的樣本點
- tabular data:
以下是LIME package document 裡的解釋:
For numerical features, perturb them by sampling from a Normal(0,1) and doing the inverse operation of mean-centering and scaling, according to the means and stds in the training data. For categorical features, perturb by sampling according to the training distribution, and making a binary feature that is 1 when the value is the same as the instance being explained.
→數值型的欄位以normal(0,1)的noise來擾動,若為類別型的就以training data的分配(比例)來抽樣 - image data:
先使用super pixel(一種cluster 方法,將像素將近的聚為一群型成一個更大的元素) 來將影像切塊,藉由隨機遮住其中一塊來擾動原本的樣本。
x: 每一個super pixel 的partition是否存在(0,1 的向量)
y: 被遮住部分的影像丟進模型出來的機率值
3. text data:
和影像類似,隨機將一些單字給蓋住生成擾動的文檔
其實這樣看下來讓我覺得LIME 擾動的概念和random forest 的 permutation importance 有點類似,藉由打亂feature來看準確性差多少。而 LIME 藉由擾動樣本點來 fit 模型給出的機率值。
作者定義了一個目標函數來找出這個最佳的local 線性函數:
其中x為當下要解釋的樣本點, f 為原函數,g 為簡單線性函數,G為所有簡單線性函數的集合,\pi _x 代表擾動後的樣本與x 的距離(相似)函數,而omaga 為簡單模型g的複雜度衡量,用來限制g不要太複雜。
定義loss function 及衡量擾動前後樣本離的相似函數:
z: 對於要解釋的樣本x 擾動而得到的擾動樣本點。
z’: z 在可解釋特徵上的表示式(若為影像or 文字則 z’為0,1向量;若為tabular data 則z’ 與z 為相同的向量空間)。
g(z’): 簡單線性模型對於該擾動點在原空間 f(z) 的估計值
舉個例子: 假設我們簡單模型是個degree 3 的線性迴歸模型,並定義相似函數為高斯相似度:
則目標函數就變成:
LIME 的package裡也有提供蠻完整的使用教學,接著就來實作看看~
實作
以下我以鐵達尼號資料試了tabular data,以及ImageNet的 pretrain model 試了影像的資料。
Tabular data:
首先從kaggle download 鐵達尼號的資料並稍微處理一下missing value:
接著來建模型~就選XGBoost來試試:
一切就緒,來隨便選些樣本點用LIME來解釋看看:
sample_ind = 3
y_val.iloc[sample_ind] #0
xgb_np.predict_proba(X_val.iloc[sample_ind])
#[0.5639446 , 0.43605545]
可以看到這筆資料預測值為0(死亡) ,預測也是死亡,我們再來看看這筆資料的內容
X_val.iloc[sample_ind]Pclass 3.000
Age 29.000
SibSp 0.000
Parch 4.000
Fare 21.075
Sex_male 0.000
Nulls_1 1.000
Nulls_2 0.000
Embarked_Q 0.000
Embarked_S 1.000
哦~原來是比年輕妹仔的資料,人家是年輕妹仔耶你怎麼預測人家死掉呢!!(但實際上也的確死掉了)
讓我們來用LIME解釋看看:
可以看到: 雖然你是女的,但是你 pclass = 3 以及妳 embarked= S (Southampton) ,導致你生存機率下降不少。
來稍微分析一下這個資訊:
首先看一下鐵達尼號 pclass的資訊
可以看到pclass=3 分布在底層或邊邊,且看起來佔地不大卻有最多的passengers(1100) 看起來就很難逃!!
接著來看上船的地方,我們先從training data 來看看:
train[['Embarked','Survived']].groupby(['Embarked'],as_index=False).mean()
可以看到,若你從 Cherbourg上船則存活率居然明顯高於 Queenstown和 Southampton,至於為什麼呢? 這應該是要和domain專家合作才能得到比較明確的答案,但我自己亂猜應該跟各城市的經濟狀況有關吧?!所以稍微看一下各港口上船的人他們平均個花多少錢買票(Fare)以及住的艙等:
train[['Embarked','Fare','Pclass']].groupby(['Embarked'],as_index=False).mean()
發現在Cherbourg上船的人,還真的整體都是花比較多錢並且住比較好的pclass!
接著我們來看一個分錯的case:
sample_ind = 0
y_val.iloc[sample_ind] #0
xgb_np.predict_proba(X_val.iloc[sample_ind])
#[0.01956439 0.9804356]
這筆資料實際為0(死亡) ,但卻預測有很高的機率存活(0.98)。用LIME來解釋看看:
exp_xgb = explainer.explain_instance(X_val.iloc[sample_ind].values, xgb_np.predict_proba)
print('截距項:'+str(exp_xgb.intercept)) #{1: 0.08749233993418792}
可以看到,這筆資料不論性別、Pclass、Fare、年齡都傾向預測為存活,故模型預測為1。至於這個點為何實際是死亡,則需要資料外的資訊來評斷了。
Image data:
以下我幾乎都是參考作者github 教學做的 pytorch 範例:
我先從網路上隨便找一張有狗和貓的照片:
import matplotlib.pyplot as plt
from PIL import Imagedef get_image(path):
with open(os.path.abspath(path), 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')img = get_image('./dogs-cats.jpg')
plt.imshow(img)
然後利用pytorch 在IMAGENET的pretrain model(ResNet)來分類看看:
因為LIME 需要 input是 numpy array,並且一個吃numpy array當input 然後output機率值的複雜模型,所以做一些處理,然後build LIME explainer:
解釋為何預測成各個類別並利用Skimage 裡的 mark_boundaries 來把影像區塊給畫出來:
最後得到:
其實CNN 影像分類的問題也可以透過GAP 及其進階grad-cam這一套來得知每一層layer是學到那些特徵,因為看到哪個區域而做出這樣的預測。
透過這個 repository 來做做看這張圖:
在查 LIME的一些資料時還有發現SHAP(SHapley Additive exPlanations) 這套比較新的方法,在使用上也蠻容易的,以下我直接來套用到上面的鐵達尼號資料來解釋是女生卻死掉的樣本點:
可以看到結果和 LIME 類似,但因為方法還不太了解,所以就先略過~
Conclusion
LIME這個方法的概念還蠻直觀的,只是他有一些限制:
1. 首先你的複雜模型必須要夠準(要不然根本不值得拿來參考fit簡單模型)
2. 簡單線性模型有可能under-fitting
3. 針對不同類型的資料需要客製一種擾動的資料表示式(若非數值型、影像、文字分類的問題)
但對於一些基本的ML 問題,LIME 提供了一個解釋方法,相信未來在應用ML演算法時會有用處的!
Reference
paper: “Why Should I Trust You?”: Explaining the Predictions of Any Classifier