# DGL 搭建 GNN 模型 (Building a GNN model with DGL)
import torch.nn as nn
from dgl.utils import expand_as_pair  # if the argument is not already a pair, duplicate it into a pair
class SAGEConv(nn.Module):
    """Graph convolution module named ``SAGEConv`` (constructor step only).

    The code shown here implements only ``__init__``. It records the
    input/output feature sizes and the configuration options on the instance.
    It creates no learnable parameters or sub-layers, and defines no
    ``forward``.

    NOTE(review): presumably a GraphSAGE-style layer (inferred from the
    name); confirm against the rest of the implementation.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        """Store the layer's dimensions and configuration.

        Args:
            in_feats: Input feature size. It is passed through
                ``expand_as_pair``, which duplicates a non-pair value into a
                pair. A single value therefore sets the same size for source
                and destination nodes. A pair gives them separately as
                ``(src, dst)``.
            out_feats: Output feature size, stored as ``self._out_feats``.
            aggregator_type: Aggregator type identifier, stored as
                ``self._aggre_type``. It is not validated in this constructor.
            bias: Accepted but not used by the code shown here.
                NOTE(review): presumably meant to control a bias term when
                layers are created; confirm.
            norm: Stored as ``self.norm``. Presumably an optional callable
                applied to outputs (``None`` disables it); confirm.
            activation: Stored as ``self.activation``. Presumably an
                optional callable activation (``None`` disables it); confirm.
        """
        super().__init__()
        # Split the input size into source-node and destination-node sizes;
        # a non-pair ``in_feats`` is duplicated so both sides match.
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation