# Building GNN Models with DGL
# init
```python
import torch.nn as nn
from dgl.utils import expand_as_pair  # if the input is not already a pair, duplicate it into one

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
```
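A quick sketch of what `expand_as_pair` does here (my own example, not from the guide): a single value is duplicated into a `(src, dst)` pair, while an existing pair is returned unchanged.

```python
from dgl.utils import expand_as_pair

# a single feature size is duplicated into a (src, dst) pair
assert expand_as_pair(16) == (16, 16)
# a pair (e.g. for bipartite graphs) is passed through as-is
assert expand_as_pair((16, 32)) == (16, 32)
```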
# SAGEConv
The aggregator can be mean, LSTM, or pooling; anything that performs aggregation works, and it can even be a GCN, since a GCN effectively acts as an aggregator too. The aggregated result is then passed on to the next step, and repeating this process gathers information from nodes that are further away.
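To illustrate the "repeat the process" idea, here is a small sketch (using DGL's built-in `dgl.nn.SAGEConv` rather than the class re-implemented below; the graph and feature sizes are made up) where stacking two layers gives each node a 2-hop receptive field:

```python
import torch
import dgl
import dgl.nn as dglnn

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))              # a toy 4-node cycle
feat = torch.randn(g.num_nodes(), 16)

conv1 = dglnn.SAGEConv(16, 32, aggregator_type='mean')   # aggregates 1-hop neighbors
conv2 = dglnn.SAGEConv(32, 32, aggregator_type='mean')   # one more hop -> 2-hop information

h = torch.relu(conv1(g, feat))
h = conv2(g, h)
print(h.shape)  # torch.Size([4, 32])
```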
# Constructor parameters
```python
        # aggregator type: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize the learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
```
Here `fc_self` aggregates the node's own information, and `fc_neigh` aggregates the information from its neighbors.
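In other words, for the non-gcn aggregators the per-node update is roughly `fc_self(h_v) + fc_neigh(aggregate(neighbors of v))`. A minimal sketch with plain PyTorch (the shapes are made up):

```python
import torch
import torch.nn as nn

in_feats, out_feats, num_nodes = 16, 32, 4
fc_self = nn.Linear(in_feats, out_feats)    # transforms each node's own feature
fc_neigh = nn.Linear(in_feats, out_feats)   # transforms the aggregated neighbor feature

h_self = torch.randn(num_nodes, in_feats)   # the nodes' own features
h_neigh = torch.randn(num_nodes, in_feats)  # e.g. the mean of each node's neighbor features

rst = fc_self(h_self) + fc_neigh(h_neigh)   # the GraphSAGE update for non-gcn aggregators
print(rst.shape)  # torch.Size([4, 32])
```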
# Message passing
```python
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

# feat_src, feat_dst and h_self (= feat_dst) are prepared earlier in forward()
if self._aggre_type == 'mean':
    graph.srcdata['h'] = feat_src
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
elif self._aggre_type == 'gcn':
    check_eq_shape(feat)
    graph.srcdata['h'] = feat_src
    graph.dstdata['h'] = feat_dst
    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
    # divide by the in-degree
    degs = graph.in_degrees().to(feat_dst)
    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
elif self._aggre_type == 'max_pool':
    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
    h_neigh = graph.dstdata['neigh']
else:
    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

# the gcn aggregator in GraphSAGE does not need fc_self
if self._aggre_type == 'gcn':
    rst = self.fc_neigh(h_neigh)
else:
    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
```
This code is really easy to follow. `update_all` drives the message passing: `fn.copy_u` copies the node feature into the message (mailbox), `mean` then aggregates the neighbors' messages, and the result is stored in `dstdata`. The node representation is then updated through `fc_self` and `fc_neigh`. Because the `gcn` aggregator already mixes the node's own feature in when aggregating its neighbors, there is no need to add `fc_self` in the update step.
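The `update_all(fn.copy_u, fn.mean)` pattern can also be tried on its own; a tiny sketch on a made-up 3-node graph:

```python
import torch
import dgl
import dgl.function as fn

# edges 0->2 and 1->2, so node 2 has two in-neighbors
g = dgl.graph(([0, 1], [2, 2]), num_nodes=3)
g.srcdata['h'] = torch.tensor([[1.0], [3.0], [0.0]])

# copy each source feature into message 'm', then average the messages per destination node
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
print(g.dstdata['neigh'])  # node 2 gets (1.0 + 3.0) / 2 = 2.0; nodes without in-edges get 0.0
```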
```python
# activation function
if self.activation is not None:
    rst = self.activation(rst)
# normalization
if self.norm is not None:
    rst = self.norm(rst)
return rst
```
Finally, the result goes through the activation function and the normalization layer.
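Both hooks are just callables passed in at construction time; a usage sketch with the built-in `dgl.nn.SAGEConv` (the L2 normalization here is an arbitrary choice for illustration):

```python
import torch
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
feat = torch.randn(g.num_nodes(), 8)

l2_norm = lambda x: x / (x.norm(dim=-1, keepdim=True) + 1e-12)  # any callable works as norm
conv = dglnn.SAGEConv(8, 16, aggregator_type='mean',
                      activation=F.relu, norm=l2_norm)
out = conv(g, feat)  # activation and norm are applied inside forward()
```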
# GraphConv on heterogeneous graphs
Heterogeneous graphs have always left me a bit confused when reading about them; I still can't quite work out what is going on...
Let me just note it down for now:
```python
import torch.nn as nn

class HeteroGraphConv(nn.Module):
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        if isinstance(aggregate, str):
            # look up the built-in aggregation function by name
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty: [] for nty in g.dsttypes}
        if g.is_block:
            src_inputs = inputs
            dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            src_inputs = dst_inputs = inputs
        for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            if rel_graph.num_edges() == 0:
                continue
            if stype not in src_inputs or dtype not in dst_inputs:
                continue
            dstdata = self.mods[etype](
                rel_graph,
                (src_inputs[stype], dst_inputs[dtype]),
                *mod_args.get(etype, ()),
                **mod_kwargs.get(etype, {}))
            outputs[dtype].append(dstdata)
```
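The snippet stops before the final step where the per-relation `outputs` are merged with `self.agg_fn`. Below is a minimal usage sketch of the built-in `dgl.nn.HeteroGraphConv`; the node/edge types and sizes are made up:

```python
import torch
import dgl
import dgl.nn as dglnn

# a toy heterograph with two edge types
g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1, 2], [1, 2, 0]),
    ('user', 'plays', 'game'): ([0, 2], [0, 1]),
})

# one sub-module per edge type; results per destination node type are merged with 'sum'
conv = dglnn.HeteroGraphConv({
    'follows': dglnn.GraphConv(16, 32),
    'plays': dglnn.GraphConv(16, 32),
}, aggregate='sum')

inputs = {'user': torch.randn(g.num_nodes('user'), 16),
          'game': torch.randn(g.num_nodes('game'), 16)}
outputs = conv(g, inputs)
print(outputs['user'].shape, outputs['game'].shape)  # torch.Size([3, 32]) torch.Size([2, 32])
```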