Few Shot Learning - Siamese Network

CW Lin
11 min readJun 6, 2021

透過few shot learning 來打造生物識別(人臉、聲紋、手寫…)模型!

Photo by Green Chameleon on Unsplash

由於疫情,近期都很少出門宅在家,就來記錄一下最近有用到的技術吧~

一般的分類問題往往都是屬於類別不多且每個類別資料量很多任務,比如MNIST 手寫資料集或 imagenet 的影像分類問題。但在生物識別的task 上我們往往沒辦法收集到那麼多資料,比如說我要建立一個人臉識別的模型,我應該不太可能跟每個我要識別的人都收集大量的照片,況且世界上的人那麽多,我也不大可能收集所有人的照片來建立分類模型。

Few shot learning 算是 meta-learning 的其中一塊,核心概念是讓模型學會學習(learn to learn)。這樣說有點懸,我們可以把它理解成: few shot learning 是要讓模型學會區分事物的差異。一個學會區分事物差異的模型,我們可以把它用在訓練集從未見過的新類別,並且可以只透過很少的樣本(few shot) 就學會區別此事物。

Siamese Network

Siamese 這個詞是孿生、連體嬰的意思,表示兩個人身體相連且共享部分的器官。而siamese network 是只有兩個架構權重都相同的類神經網路組合在一起(如下右圖)

可以看到這個網路的input 是一個image pairs,而我們的目標是要訓練一個能夠區分事物差異的網路。想必聰明的你已經想到要如何使用這個網路結構了!

首先我們要準備很多positive samples 以及 negative samples,分別表示相同類別的 image pairs 以及不同類別的 image pairs:

https://www.youtube.com/watch?v=UkQ2FVpDxHg

而 siamese network 就是要預測 input 的 image pairs 是否為相同類別,所以說就是一個binary classification 的問題!

實際要搭建siamese network 也非常的簡單,以pytorch 做一個範例:

class siameseNet(nn.Module):
def __init__(self, embedding_net):
super(siameseNet, self).__init__()
self.embedding_net = embedding_net
def forward(self, x1, x2):
output1 = self.embedding_net(x1)
output2 = self.embedding_net(x2)
return output1, output2
def get_embedding(self, x):
return self.embedding_net(x)

其中 embedding_net 是任何你自己搭建的CNN 網路,最後再把output1, output2 喂給 loss function backward 即可。

預測時可以調用 get_embedding() 就不用每次都 forward 兩張image了~

Contrastive loss

目標: 使相同類別的embedding 越接近越好,不同類別的embedding 越遠越好,用這個觀點來看下面的式子就會非常直觀了

Dw 表示兩embedding 之距離(歐式距離)
class ContrastiveLoss(nn.Module):   def __init__(self, margin):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.eps = 1e-9
def forward(self, output1, output2, target, size_average=True):
distances = (output2 - output1).pow(2).sum(1) # squared distances
losses = 0.5 * (target.float() * distances +
(1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
return losses.mean() if size_average else losses.sum()

Triplet loss

Triplet Loss是Google 在 2015 年發表的 FaceNet 論文中提出。可視為Contrastive loss 的改良。

triplet loss 必須建構在三元的image pair 下才能計算,搭配的網路架構如下

可以看到triplet loss 的做法直接喂給模型一個positive 以及一個 negative sample 來訓練,目標一樣是期望positive 能越接近anchor 而 negative 能越遠離anchor:

triplet loss
class TripletLoss(nn.Module):   def __init__(self, margin):
super(TripletLoss, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative, size_average=True):
distance_positive = (anchor - positive).pow(2).sum(1)
distance_negative = (anchor - negative).pow(2).sum(1)
losses = F.relu(distance_positive - distance_negative + self.margin)
return losses.mean() if size_average else losses.sum()

要搭建triplet 的 network 其實和siamese 幾乎一模一樣,只是input, output 變成三個而已

class TripletNet(nn.Module):
def __init__(self, embedding_net):
super(TripletNet, self).__init__()
self.embedding_net = embedding_net
def forward(self, x1, x2, x3):
output1 = self.embedding_net(x1)
output2 = self.embedding_net(x2)
output3 = self.embedding_net(x3)
return output1, output2, output3
def get_embedding(self, x):
return self.embedding_net(x)

透過contrastive loss or triplet loss 訓練出來的 NN 已經學會了區分事物的差異,我們可以透過這個 NN 將影像embedding 到一個很具鑑別力的空間,在這個空間計算embedding 向量的距離來判斷是否為相同類別的事物。

所以不論人臉辨識,聲紋辨識,手寫辨識…等,我們首先需要大量資料來訓練這個embedding 的類神經網路;將來部署預測時每個類別只需要一筆資料把它轉換到這個embedding 的高維度空間中,便完成該類別的建模了。

以人臉為例,我們將要辨識的人臉forward到 embedding 空間,然後計算與之前已 ”註冊” 的人臉看看是否夠接近,若有超過門檻值則判為同一人:

以 MNIST 手寫資料及為例來比較看看相同的網路架構以不同訓練方式所得到的embedding 空間:

# network structure
class embedding(nn.Module):
def __init__(self, input_size, embed_dim=2):
super(embedding, self).__init__()
self.embed_dim = embed_dim
self.conv = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.LeakyReLU(0.1,inplace=True),
nn.Conv2d(32, 32, 3, padding=1, stride=2), nn.BatchNorm2d(32), nn.LeakyReLU(0.1,inplace=True),

nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.1,inplace=True),
nn.Conv2d(64, 64, 3, padding=1,stride=2), nn.BatchNorm2d(64), nn.LeakyReLU(0.1,inplace=True),
)
self.fc = nn.Linear(64*7*7, self.embed_dim)

def forward(self, x):
x = self.conv(x)
x = x.view(x.size()[0], -1)
x = self.fc(x)
return x

用上面的 siameseNet, TripletNet 訓練後使用get_embedding() 得到以下的結果:

Margin Based Classification

以上的siamese network 算是一種 metric based learning的訓練方式,我們的目標其實是要讓類神經網路的embedding 空間更具鑑別力。

Margin Based Classficiation 不像在 feature 層直接計算損失的 Metric Learning 那樣,對 feature 加直觀的強限制,而是依然把人臉識別當 classification 任務進行訓練,通過對 softmax 公式的改進,實現了對 feature 層施加 margin 的限制,使網絡最後得到的 feature 更具鑑別力。訓練好的網路再把最後的classification layer 拿掉,只用前面的CNN做feature extraction

這一塊有些一系列的做法 sphereface, cosineface, arcface,分別在不同的地方對 softmax 加上margin 限制,具體細節有在 人臉辨識, Face recognition (ArcFace) 這篇文章裡描述。

Conclusion

metric learning 的概念直觀,直接針對verification的目的進行訓練,但其缺點是有時訓練很慢很難收斂,一方面是窮舉每種 image pair(or triplet pair) 會讓資料量變很多,另一方面是訓練的目標比單純分類問題還要困難,導致比較難收斂。因此有時候有人會先用分類模型先把模型訓練到一定程度後再用 triplet loss 來 fine tune。

Reference

Comparing images for similarity using siamese networks

如何走近深度學習人臉識別?你需要這篇超長綜述

Few-Shot Learning

SigNet: Convolutional Siamese Network for Writer Independent Offline Signature Verification

--

--