作者:Manjunath Bhat
編譯:ronghuaiyang
導讀
對空間變換網絡STN做了一個簡單的原理性的介紹。
作為谷歌Summer of Code項目的一部分,我要實現的第一個模型是空間變壓器網絡。空間變壓器網絡(STN)是一個可學習的模塊,可以放置在卷積神經網絡(CNN)中,有效地增加空間不變性??臻g不變性是指模型對圖像的空間變換如旋轉、平移和縮放不變性。不變性是指即使輸入被變換或輕微修改,模型也能識別和識別特征的能力。空間變壓器可以放置到CNN中,以完成各種任務。圖像分類就是一個例子。假設任務是對手寫數字進行分類,每個樣本中數字的位置、大小和方向變化顯著。一個空間轉換器將提取、變換和縮放樣本中感興趣的區域。現在CNN可以完成分類的任務。
空間變壓器網絡由3個主要組成部分組成:
(i) 定位網絡:該網絡以一個batch的圖像的四維張量表示(寬度x高度x通道x Batch_Size)作為輸入。它是一個簡單的神經網絡,有幾個卷積層和幾個dense層。將變換參數預測為輸出。這些參數決定了輸入必須旋轉的角度、要完成的平移量以及聚焦于輸入特征圖中感興趣的區域所需的比例因子。
(ii) 采樣網格生成器:對batch中每幅圖像使用定位網絡預測的變換參數,其形式為大小為2×3的仿射變換矩陣。仿射變換是一種保留點、直線和平面的變換。經過仿射變換后,平行線保持平行。旋轉、縮放和平移都是仿射變換。
這里,T是這個仿射變換,A是表示仿射變換的矩陣。θ11, θ12, θ21, θ22被用來確定圖像旋轉的角度。θ13, θ23分別確定了圖像沿寬度和高度的平移量。因此,我們得到了一個轉換索引的采樣網格。
(iii) 變換后索引上的雙線性插值:現在圖像的索引和坐標軸已經進行了仿射變換。它的像素移動了。例如,一個點(1,1)在軸逆時針旋轉45度后變成(√2,0),因此要找到變換點處的像素值,我們需要使用四個最接近的像素值進行雙線性插值。
為了找到點(x, y)上的像素值,我們取4個最近的點,如上圖所示。其中,floor(x)表示最大整數函數,ceil(x)表示ceiling函數。線性插值必須在x和y兩個方向上完成。因此,這個函數返回完全轉換后的圖像,并在轉換索引處使用適當的像素值。
純Julia實現空間變壓器網絡的代碼可以在這里找到:https://github.com/thebhatman/Spatial-Transformer-Network/blob/master/src/stn.jl。我在一些圖像上測試了我的空間轉換器模塊的功能。下面是轉換函數輸出的一些示例圖像。左邊的圖像是轉換器模塊的輸入,右邊的圖像是輸出。
- 放大感興趣的區域
- 對人臉進行放大并旋轉45度。
- 對圖像沿著寬度平移,移到中心。
從上面的例子可以清楚地看出,空間轉換器模塊能夠執行任何類型的仿射變換。在實現過程中,我花了很多時間來理解數組的reshape、permutedims和concatenation是如何工作的,因為當我使用這些函數時,很難調試像素和索引是如何移動的。在STN實現過程中,調試插值和圖像索引是最耗費時間和最令人沮喪的部分。
現在,我計劃使用一個CNN來訓練這個空間轉換器模塊,以便對一個雜亂和扭曲的MNIST數據集進行手寫數字分類??臻g變壓器將能夠增加CNN的空間不變性,因此期望即使在數字被平移、旋轉或縮放時也能給出良好的分類結果。
英文原文:https://medium.com/@manjunathbhat9920/spatial-transformer-network-82666f184299