# 图数据处理管道

# DGLDataset 类

image-20211216095935092

官方给的一个流程图。

# 框架代码

from dgl.data import DGLDataset
class MyDataset(DGLDataset):
    """ 用于在DGL中自定义图数据集的模板:
    Parameters
    ----------
    url : str
        下载原始数据集的url。
    raw_dir : str
        指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/
    save_dir : str
        处理完成的数据集的保存目录。默认:raw_dir指定的值
    force_reload : bool
        是否重新导入数据集。默认:False
    verbose : bool
        是否打印进度信息。
    """
    def __init__(self,
                 url=None,
                 raw_dir=None,
                 save_dir=None,
                 force_reload=False,
                 verbose=False):
        super(MyDataset, self).__init__(name='dataset_name',
                                        url=url,
                                        raw_dir=raw_dir,
                                        save_dir=save_dir,
                                        force_reload=force_reload,
                                        verbose=verbose)
    def download(self):
        # 将原始数据下载到本地磁盘
        pass
    def process(self):
        # 将原始数据处理为图、标签和数据集划分的掩码
        pass
    def __getitem__(self, idx):
        # 通过 idx 得到与之对应的一个样本
        pass
    def __len__(self):
        # 数据样本的数量
        pass
    def save(self):
        # 将处理后的数据保存至 `self.save_path`
        pass
    def load(self):
        # 从 `self.save_path` 导入处理后的数据
        pass
    def has_cache(self):
        # 检查在 `self.save_path` 中是否存有处理后的数据
        pass

继承 DGLDataset 的子类必须实现的函数: process__getitem____len__

DGLDataset 做的事情:存储有关数据集的图、特征、标签、掩码,以及诸如类别数、标签数等基本信息

不应该做的事情: 诸如采样、划分或特征归一化等操作建议在 DGLDataset 子类之外完成。

# 下载数据

# 直接下载:
import os
from dgl.data.utils import download
def download(self):
    # 存储文件的路径
    file_path = os.path.join(self.raw_dir, self.name + '.mat')
    # 下载文件
    download(self.url, path=file_path)
# 下载并解压:
from dgl.data.utils import download, check_sha1
def download(self):
    # 存储文件的路径,请确保使用与原始文件名相同的后缀
    gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
    # 下载文件
    download(self.url, path=gz_file_path)
    # 检查 SHA-1
    if not check_sha1(gz_file_path, self._sha1_str):
        raise UserWarning('File {} is downloaded but the content hash does not match.'
                          'The repo may be outdated or download may be incomplete. '
                          'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
    # 将文件解压缩到目录 self.raw_dir 下的 self.name 目录中
    self._extract_gz(gz_file_path, self.raw_path)

解压文件会被解压缩到 self.raw_dir 下的目录 self.name 中。

# 处理数据

# 处理整图分类数据集

QM7bDataset

from dgl.data import DGLDataset
class QM7bDataset(DGLDataset):
    _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
           'datasets/qm7b.mat'
    _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
    def __init__(self, raw_dir=None, force_reload=False, verbose=False):
        super(QM7bDataset, self).__init__(name='qm7b',
                                          url=self._url,
                                          raw_dir=raw_dir,
                                          force_reload=force_reload,
                                          verbose=verbose)
    def process(self):
        mat_path = self.raw_path + '.mat'
        # 将数据处理为图列表和标签列表
        self.graphs, self.label = self._load_graph(mat_path)
    def __getitem__(self, idx):
        """ 通过idx获取对应的图和标签
        Parameters
        ----------
        idx : int
            Item index
        Returns
        -------
        (dgl.DGLGraph, Tensor)
        """
        return self.graphs[idx], self.label[idx]
    def __len__(self):
        """数据集中图的数量"""
        return len(self.graphs)
# 使用数据集
import dgl
import torch
from dgl.dataloading import GraphDataLoader
# 数据导入
dataset = QM7bDataset()
num_labels = dataset.num_labels
# 创建 dataloaders
dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True)
# 训练
for epoch in range(100):
    for g, labels in dataloader:
        # 用户自己的训练代码
        pass
# 节点分类任务
# 使用节点分类数据集:
# 导入数据
dataset = CiteseerGraphDataset(raw_dir='')
graph = dataset[0]
# 获取划分的掩码
train_mask = graph.ndata['train_mask']
val_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
# 获取节点特征
feats = graph.ndata['feat']
# 获取标签
labels = graph.ndata['label']
# 可用于节点分类的数据集:
  • Citation network dataset
  • CoraFull dataset
  • Amazon Co-Purchase dataset
  • Coauthor dataset
  • Karate club dataset
  • Protein-Protein Interaction dataset
  • Reddit dataset
  • Symmetric Stochastic Block Model Mixture dataset
  • Stanford sentiment treebank dataset
  • RDF datasets
# 链接预测数据集
# 实现代码
# 创建链接预测数据集示例
class KnowledgeGraphDataset(DGLBuiltinDataset):
    def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
        self._name = name
        self.reverse = reverse
        url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
        super(KnowledgeGraphDataset, self).__init__(name,
                                                    url=url,
                                                    raw_dir=raw_dir,
                                                    force_reload=force_reload,
                                                    verbose=verbose)
    def process(self):
        # 跳过一些处理的代码
        # === 跳过数据处理 ===
        # 划分掩码
        g.edata['train_mask'] = train_mask
        g.edata['val_mask'] = val_mask
        g.edata['test_mask'] = test_mask
        # 边类型
        g.edata['etype'] = etype
        # 节点类型
        g.ndata['ntype'] = ntype
        self._g = g
    def __getitem__(self, idx):
        assert idx == 0, "这个数据集只有一个图"
        return self._g
    def __len__(self):
        return 1
# 使用
from dgl.data import FB15k237Dataset
# 导入数据
dataset = FB15k237Dataset()
graph = dataset[0]
# 获取训练集掩码
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
src, dst = graph.edges(train_idx)
# 获取训练集中的边类型
rel = graph.edata['etype'][train_idx]

# 保存和加载数据

import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import makedirs, save_info, load_info
def save(self):
    # 保存图和标签
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    save_graphs(graph_path, self.graphs, {'labels': self.labels})
    # 在 Python 字典里保存其他信息
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    save_info(info_path, {'num_classes': self.num_classes})
def load(self):
    # 从目录 `self.save_path` 里读取处理过的数据
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    self.graphs, label_dict = load_graphs(graph_path)
    self.labels = label_dict['labels']
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    self.num_classes = load_info(info_path)['num_classes']
def has_cache(self):
    # 检查在 `self.save_path` 里是否有处理过的数据文件
    graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
    info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
    return os.path.exists(graph_path) and os.path.exists(info_path)

# OBG

Open Graph Benchmark (OGB) 是一个图深度学习的基准数据集。使用 pip 安装 ogb 就可以使用了。

# 图分类任务
# 载入 OGB 的 Graph Property Prediction 数据集
import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from dgl.dataloading import GraphDataLoader
def _collate_fn(batch):
    # 小批次是一个元组 (graph, label) 列表
    graphs = [e[0] for e in batch]
    g = dgl.batch(graphs)
    labels = [e[1] for e in batch]
    labels = torch.stack(labels, 0)
    return g, labels
# 载入数据集
dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
# dataloader
train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
# 节点分类任务
# 载入 OGB 的 Node Property Prediction 数据集
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset(name='ogbn-proteins')
split_idx = dataset.get_idx_split()
# there is only one graph in Node Property Prediction datasets
# 在 Node Property Prediction 数据集里只有一个图
g, labels = dataset[0]
# 获取划分的标签
train_label = dataset.labels[split_idx['train']]
valid_label = dataset.labels[split_idx['valid']]
test_label = dataset.labels[split_idx['test']]
# 链接预测任务
# 载入 OGB 的 Link Property Prediction 数据集
from ogb.linkproppred import DglLinkPropPredDataset
dataset = DglLinkPropPredDataset(name='ogbl-ppa')
split_edge = dataset.get_edge_split()
graph = dataset[0]
print(split_edge['train'].keys())
print(split_edge['valid'].keys())
print(split_edge['test'].keys())