# 消息传递范式

# 消息传递范式定义

边上计算: me(t+1)=ϕ(xv(t),xu(t),we(t)),(u,v,e)Em_{e}^{(t+1)}=\phi\left(x_{v}^{(t)}, x_{u}^{(t)}, w_{e}^{(t)}\right),(u, v, e) \in \mathcal{E}.
点上计算: xv(t+1)=ψ(xv(t),ρ({me(t+1):(u,v,e)E}))x_{v}^{(t+1)}=\psi\left(x_{v}^{(t)}, \rho\left(\left\{m_{e}^{(t+1)}:(u, v, e) \in \mathcal{E}\right\}\right)\right)

# 消息函数ϕ\phi

我的理解其实这个me(t+1)=ϕ(xv(t),xu(t),we(t)),(u,v,e)Em_{e}^{(t+1)}=\phi\left(x_{v}^{(t)}, x_{u}^{(t)}, w_{e}^{(t)}\right),(u, v, e) \in \mathcal{E} 说白了就是在使用一个聚合方式,这个就是边聚合,也就是消息函数,包括 add、sub、mul、div、dot 等。

DGL 的消息函数定义ϕ\phi,是 dgl.function.u_add_v (' 源 ',' 目标 ',' 边 ')

def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}
# 聚合函数ρ\rho

然后点上的计算就是聚合函数ρ\rho,有 sum、max、min 和 mean。

DGL 的聚合定义是 dgl.function.sum ('mailbox', 'h')

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
# 更新函数ψ\psi

更新函数是一个可选择的函数。

# 消息传递函数

update_all 是 dgl 定义的 api,用法:

def updata_all_example(graph):
    # 在 graph.ndata ['ft'] 中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在 update_all 外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

# 编写消息传递代码

对于一些情况下,必须保存边上的信息,所以可以用 apply_edges 的方法。

import torch
import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges):
     return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear

通过一个 linear 线性层,实现一个简单的降维。建议是不要保留 concat 的 feat,分开来计算。

import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

# 图上的一部分进行消息

用节点编号创建一个子图,然后在子图上调用 update_all () 方法

nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)

# 消息传递中使用边的权重

GAT(图注意力网络)和 GCN 的变种需要使用边的权重。GAT 好理解,就是需要 attention 来聚合。

import dgl.function as fn
# 假定 eweight 是一个形状为 (E, *) 的张量,E 是边的数量。
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))

u_mul_e : u 是节点、e 是边,于是对应的意思就是节点 ft 的 ndata 乘上 edata 这个 api 的命名思路值得学习。

# 在异构图上进行消息传递

异构图的消息传递分为两个部分:

  1. 对每个关系计算和聚合消息
  2. 对每个节点聚合来自不同关系的消息

异构图的信息传递 api 是 multi_update_all ,示例代码:

import dgl.function as fn
for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # 把它存在图中用来做消息传递
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # 指定每个关系的消息传递函数:(message_func, reduce_func).
    # 注意结果保存在同一个目标特征 “h”,说明聚合是逐类进行的。
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# 将每个类型消息聚合的结果相加。
G.multi_update_all(funcs, 'sum')
# 返回更新过的节点特征字典
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
更新于

请我喝[茶]~( ̄▽ ̄)~*

Kalice 微信支付

微信支付

Kalice 支付宝

支付宝