This post records some notes on GNNs (graph neural networks).

MiniBatch for GCN

Graph-structured data differs from images and videos in an important way: a graph is defined as G = (Nodes, Edges), and different graphs usually have different numbers of nodes and edges. This complicates training, because the model itself has a fixed size, e.g. an MLP with input dimension I and output dimension O. So how do we batch graphs?

Merge small graphs into one large graph

Concatenate the node features and offset the edge indices so that the whole batch becomes a single large, disconnected graph (its adjacency matrix is block-diagonal). After the forward pass, map the results back to the individual graphs.

docs: batch, docs: mini-batch

Using the torch_geometric library

torch_geometric is quite good and already ships with many Graph Neural Network models.

The code I use looks like this:

from torch_geometric.nn.models import GCN, GAT, GraphSAGE, PNA
 
if graph_type == 'gcn':
    self.graph_layer = GCN(
        in_channels=token_dim,
        hidden_channels=graph_hidden_dim,
        out_channels=output_dim,
        num_layers=graph_layers,
        dropout=graph_dropout
    )
elif graph_type == "gat":
    gat_heads = config.gat_heads
    self.graph_layer = GAT(
        in_channels=token_dim, 
        hidden_channels=graph_hidden_dim,
        out_channels=output_dim, 
        num_layers=graph_layers, 
        heads=gat_heads, 
        dropout=graph_dropout, 
        negative_slope=0.2,   # LeakyReLU slope; left at the library default
        add_self_loops=True, 
        bias=True
    )
elif graph_type == "graphsage":
    self.graph_layer = GraphSAGE(
        in_channels=token_dim,
        hidden_channels=graph_hidden_dim,
        num_layers=graph_layers,
        out_channels=output_dim,
        dropout=graph_dropout
    )
elif graph_type == "pna":
    # PNA additionally requires aggregators, scalers, and a degree
    # histogram `deg` computed from the training graphs; without them
    # the constructor fails
    self.graph_layer = PNA(
        in_channels=token_dim,
        hidden_channels=graph_hidden_dim,
        out_channels=output_dim,
        num_layers=graph_layers,
        aggregators=["mean", "max", "min", "std"],
        scalers=["identity", "amplification", "attenuation"],
        deg=deg,  # degree histogram tensor, precomputed from the dataset
        dropout=graph_dropout
    )
else:
    raise ValueError("Unknown graph type: {}".format(graph_type))

scatter

Segment-wise averaging over a graph:

from torch_geometric.utils import scatter
 
seg_mean = scatter(gcn_output, scatter_idx, dim=0, dim_size=self.num_virtual_tokens*bsz, reduce='mean')   # shape (bsz * num_virtual_tokens, dim)
 

DiffPool

DiffPool comes from the 2018 paper Hierarchical Graph Representation Learning with Differentiable Pooling, which proposes a simple and effective pooling scheme for graph neural networks: each layer learns a soft cluster-assignment matrix S and uses it to coarsen the node features and the adjacency matrix.

(figure: DiffPool architecture)