膠囊圖神經(jīng)網(wǎng)絡(CapsGNN)是在GNN啟發(fā)下誕生了基于圖片分類的新框架。CapsGNN在10個數(shù)據(jù)集中的6個的表現(xiàn)排名位居前兩名。與所有其他端到端架構(gòu)相比,CapsGNN在所有社交數(shù)據(jù)集中均名列首位。
本日Reddit上熱議的一個話題是名為“膠囊圖神經(jīng)網(wǎng)絡”(CapsGNN)的新框架。從名字不難看出,它是受圖神經(jīng)網(wǎng)絡(GNN)的啟發(fā),在其基礎上改進而來的成果。
CapsGNN框架的作者為新加坡南洋理工大學電氣與電子工程學院的Zhang Xinyi和Lihui Chen,該研究的論文將在ICLR 2019上發(fā)表。
目前,從圖神經(jīng)網(wǎng)絡(GNN)中學到的高質(zhì)量節(jié)點嵌入已經(jīng)應用于各種基于節(jié)點的應用程序中,其中一些程序已經(jīng)實現(xiàn)了最先進的性能。不過,當應用程序用GNN學習的節(jié)點嵌入來生成圖形嵌入時,標量節(jié)點表示可能不足以有效地保留節(jié)點或圖形的完整屬性,從而導致圖形嵌入的性能達不到最優(yōu)。
膠囊圖神經(jīng)網(wǎng)絡(CapsGNN)受到了膠囊神經(jīng)網(wǎng)絡的啟發(fā),利用膠囊的概念來解決現(xiàn)有基于GNN的圖嵌入算法的缺點。CapsGNN以膠囊形式對節(jié)點特征進行提取,利用路由機制來捕獲圖形級別的重要信息。因此,模型會為每個圖生成多個嵌入,從多個不同方面捕獲圖的屬性。
CapsGNN中包含的注意力模塊可用于處理各種尺寸的圖,讓模型能夠?qū)W⑻幚韴D的關鍵部分。通過對10個圖結(jié)構(gòu)數(shù)據(jù)集的廣泛評估表明,CapsGNN具有強大的機制,可通過數(shù)據(jù)驅(qū)動捕獲整個圖的宏觀屬性。在幾個圖分類任務上的性能優(yōu)于其他SOTA技術(shù)。
膠囊圖神經(jīng)網(wǎng)絡基本架構(gòu)
上圖所示為CapsGNN的簡化版本。它由三個關鍵模塊組成:1)基本節(jié)點膠囊提取模塊:GNN用于提取具有不同感受野的局部頂點特征,然后在該模塊中構(gòu)建主節(jié)點膠囊。 2)高級圖膠囊提取模塊:融合了注意力模塊和動態(tài)路由,以生成多個圖膠囊。 3)圖分類模塊:再次利用動態(tài)路由,生成用于圖分類的類膠囊。
注意力模塊
在CapsGNN中,基于每個節(jié)點提取主膠囊,即主膠囊的數(shù)量取決于輸入圖的大小。在這種情況下,如果直接應用路由機制,則生成的高級別的膠囊的值將高度依賴于主膠囊的數(shù)量(圖大?。?,這種情況并不理想。因此,實驗引入一個注意力模塊來解決這個問題。
注意力模塊架構(gòu)。首先壓平主膠囊,利用兩層全連接神經(jīng)網(wǎng)絡產(chǎn)生每個膠囊的注意力值。利用基于節(jié)點的歸一化(對每行進行歸一化)來生成最終注意力值。 將標準化值與主膠囊相乘來計算標度膠囊。
實驗設置與結(jié)果
我們驗證了從CapsGNN中提取的圖嵌入與大量SOTA方法的性能,與一些經(jīng)典方法的最優(yōu)性能做了對比。此外還進行了實驗研究,評估膠囊對圖編碼特征效率的影響。我們對生成的圖/類膠囊進行了簡要分析。實驗結(jié)果和分析如下所示。
表1為生物數(shù)據(jù)集的實驗結(jié)果,表2為社會數(shù)據(jù)集的實驗結(jié)果。對于每個數(shù)據(jù)集,以粗體突出顯示前2個準確度。
與所有其他算法相比,CapsGNN在10個數(shù)據(jù)集中的6個的表現(xiàn)排名位居前兩名,并且在其他數(shù)據(jù)集上也實現(xiàn)了基本相當?shù)慕Y(jié)果。與所有其他端到端架構(gòu)相比,CapsGNN在所有社交數(shù)據(jù)集中均名列首位。
表1:生物數(shù)據(jù)集的實驗結(jié)果
表2:社交數(shù)據(jù)集的實驗結(jié)果
膠囊的效率
在膠囊的效率測試實驗中,GNN的層數(shù)設置為L = 3,每層的通道數(shù)都設置為Cl = 2。通過調(diào)整節(jié)點的維度(dn)、圖(dg)、膠囊和圖形、膠囊的數(shù)量(P)來構(gòu)造不同的CapsGNN。
表3:膠囊效率評估實驗中經(jīng)過測試的體系結(jié)構(gòu)詳細信息
圖3:特征表示效率的比較。橫軸表示測試架構(gòu)的設置,縱軸表示NCI1的分類精度。
圖膠囊的可視化
分類膠囊的可視化
膠囊圖網(wǎng)絡:基于GNN的高效快捷的新框架
CapsGNN是一個新框架,將膠囊理論融合到GNN中,來實現(xiàn)更高效的圖表示學習。該框架受CapsNet的啟發(fā),在原體系結(jié)構(gòu)中引入了膠囊的概念,在從GNN提取的節(jié)點特征的基礎上,以向量的形式提取特征。
利用CapsGNN,一個圖可以表示為多個嵌入,每個嵌入都可以捕獲不同方面的圖屬性。生成的圖形和類封裝不僅可以保留與分類相關的信息,還可以保留關于圖屬性的其他信息,這些信息可能在后續(xù)流程中用到。CapsGNN是一種新穎、高效且強大的數(shù)據(jù)驅(qū)動方法,可以表示圖形等高維數(shù)據(jù)。
與其他SOTA算法相比,CapsGNN模型在10個圖表分類任務中有6個成功實現(xiàn)了更好或相當?shù)男阅?,在社交?shù)據(jù)集上的表現(xiàn)尤其顯眼。與其他類似的基于標量的體系結(jié)構(gòu)相比,CapsGNN在編碼特征方面更有效,這對于處理大型數(shù)據(jù)集非常有用。
關于開源代碼和模型的一些補充信息
運行環(huán)境
代碼庫在Python 3.5.2中實現(xiàn)。用于開發(fā)的軟件包版本如下:
networkx 1.11tqdm 4.28.1numpy 1.15.4pandas 0.23.4texttable 1.5.0scipy 1.1.0argparse 1.1.0torch 0.4.1torch-scatter 1.1.2torch-sparse 0.2.2torch-cluster 1.2.4torch-geometric 1.0.3torchvision 0.2.1
數(shù)據(jù)集
代碼會從input文件夾中獲取訓練圖,圖存儲形式為JSON。用于測試的圖也存儲為JSON文件。每個節(jié)點id和節(jié)點標簽必須從0開始索引。字典的鍵是存儲的字符串,以使JSON能夠序列化排布。
每個JSON文件都具有以下的鍵值結(jié)構(gòu):
{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]], "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"}, "target": 1}
邊緣鍵(edgeskey)具有邊緣列表值,用于描述連接結(jié)構(gòu)。標簽鍵具有每個節(jié)點的標簽,這些標簽存儲為字典- 在此嵌套字典中,標簽是值,節(jié)點標識符是鍵。目標鍵具有整數(shù)值,該值代表了類成員資格。
輸出
預測結(jié)果保存在output目錄中。每個嵌入都有一個標題和一個帶有圖標識符的列。最后,預測會按標識符列排序。
訓練CapsGNN模型由src /main.py腳本處理,該腳本提供以下命令行參數(shù)。
輸入和輸出選項
--training-graphs STR Training graphs folder. Default is `dataset/train/`. --testing-graphs STR Testing graphs folder. Default is `dataset/test/`. --prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
模型選項
--epochs INT Number of epochs. Default is 10. --batch-size INT Number fo graphs per batch. Default is 32. --gcn-filters INT Number of filters in GCNs. Default is 2. --gcn-layers INT Number of GCNs chained together. Default is 5. --inner-attention-dimension INT Number of neurons in attention. Default is 20. --capsule-dimensions INT Number of capsule neurons. Default is 8. --number-of-capsules INT Number of capsules in layer. Default is 8. --weight-decay FLOAT Weight decay of Adam. Defatuls is 10^-6. --lambd FLOAT Regularization parameter. Default is 1.0. --learning-rate FLOAT Adam learning rate. Default is 0.01.
-
神經(jīng)網(wǎng)絡
+關注
關注
42文章
4814瀏覽量
103708 -
數(shù)據(jù)集
+關注
關注
4文章
1224瀏覽量
25462 -
pytorch
+關注
關注
2文章
809瀏覽量
13976 -
GNN
+關注
關注
1文章
31瀏覽量
6567
原文標題:基于GNN,強于GNN:膠囊圖神經(jīng)網(wǎng)絡的PyTorch實現(xiàn) | ICLR 2019
文章出處:【微信號:AI_era,微信公眾號:新智元】歡迎添加關注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
人工神經(jīng)網(wǎng)絡實現(xiàn)方法有哪些?
matlab實現(xiàn)神經(jīng)網(wǎng)絡 精選資料分享
一種新型神經(jīng)網(wǎng)絡結(jié)構(gòu):膠囊網(wǎng)絡
基于PyTorch的深度學習入門教程之使用PyTorch構(gòu)建一個神經(jīng)網(wǎng)絡
PyTorch教程8.1之深度卷積神經(jīng)網(wǎng)絡(AlexNet)

PyTorch教程之循環(huán)神經(jīng)網(wǎng)絡

PyTorch教程之從零開始的遞歸神經(jīng)網(wǎng)絡實現(xiàn)

PyTorch教程9.6之遞歸神經(jīng)網(wǎng)絡的簡潔實現(xiàn)

PyTorch教程10.3之深度遞歸神經(jīng)網(wǎng)絡

PyTorch教程10.4之雙向遞歸神經(jīng)網(wǎng)絡

PyTorch教程16.2之情感分析:使用遞歸神經(jīng)網(wǎng)絡

評論