# 消息传递范式
# 消息传递范式定义
边上计算: .
点上计算:
# 消息函数
我的理解其实这个 说白了就是在使用一个聚合方式,这个就是边聚合,也就是消息函数,包括 add、sub、mul、div、dot 等。
DGL 的消息函数定义,是 dgl.function.u_add_v (' 源 ',' 目标 ',' 边 ')
def message_func(edges): | |
return {'he': edges.src['hu'] + edges.dst['hv']} |
# 聚合函数
然后点上的计算就是聚合函数,有 sum、max、min 和 mean。
DGL 的聚合定义是 dgl.function.sum ('mailbox', 'h')
import dgl.function as fn | |
graph.apply_edges(fn.u_add_v('el', 'er', 'e')) |
# 更新函数
更新函数是一个可选择的函数。
# 消息传递函数
update_all
是 dgl 定义的 api,用法:
def updata_all_example(graph): | |
# 在 graph.ndata ['ft'] 中存储结果 | |
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), | |
fn.sum('m', 'ft')) | |
# 在 update_all 外调用更新函数 | |
final_ft = graph.ndata['ft'] * 2 | |
return final_ft |
# 编写消息传递代码
对于一些情况下,必须保存边上的信息,所以可以用 apply_edges
的方法。
import torch | |
import torch.nn as nn | |
linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim))) | |
def concat_message_function(edges): | |
return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])} | |
g.apply_edges(concat_message_function) | |
g.edata['out'] = g.edata['cat_feat'] @ linear |
通过一个 linear 线性层,实现一个简单的降维。建议是不要保留 concat 的 feat,分开来计算。
import dgl.function as fn | |
linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim))) | |
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim))) | |
out_src = g.ndata['feat'] @ linear_src | |
out_dst = g.ndata['feat'] @ linear_dst | |
g.srcdata.update({'out_src': out_src}) | |
g.dstdata.update({'out_dst': out_dst}) | |
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out')) |
# 图上的一部分进行消息
用节点编号创建一个子图,然后在子图上调用 update_all () 方法
nid = [0, 2, 3, 6, 7, 9] | |
sg = g.subgraph(nid) | |
sg.update_all(message_func, reduce_func, apply_node_func) |
# 消息传递中使用边的权重
GAT(图注意力网络)和 GCN 的变种需要使用边的权重。GAT 好理解,就是需要 attention 来聚合。
import dgl.function as fn | |
# 假定 eweight 是一个形状为 (E, *) 的张量,E 是边的数量。 | |
graph.edata['a'] = eweight | |
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), | |
fn.sum('m', 'ft')) |
u_mul_e
: u 是节点、e 是边,于是对应的意思就是节点 ft
的 ndata 乘上 edata
这个 api 的命名思路值得学习。
# 在异构图上进行消息传递
异构图的消息传递分为两个部分:
- 对每个关系计算和聚合消息
- 对每个节点聚合来自不同关系的消息
异构图的信息传递 api 是 multi_update_all
,示例代码:
import dgl.function as fn | |
for c_etype in G.canonical_etypes: | |
srctype, etype, dsttype = c_etype | |
Wh = self.weight[etype](feat_dict[srctype]) | |
# 把它存在图中用来做消息传递 | |
G.nodes[srctype].data['Wh_%s' % etype] = Wh | |
# 指定每个关系的消息传递函数:(message_func, reduce_func). | |
# 注意结果保存在同一个目标特征 “h”,说明聚合是逐类进行的。 | |
funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h')) | |
# 将每个类型消息聚合的结果相加。 | |
G.multi_update_all(funcs, 'sum') | |
# 返回更新过的节点特征字典 | |
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes} |