# 图数据处理管道
# DGLDataset 类
官方给的一个流程图。
# 框架代码
| from dgl.data import DGLDataset |
| |
| class MyDataset(DGLDataset): |
| """ 用于在DGL中自定义图数据集的模板: |
| |
| Parameters |
| ---------- |
| url : str |
| 下载原始数据集的url。 |
| raw_dir : str |
| 指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/ |
| save_dir : str |
| 处理完成的数据集的保存目录。默认:raw_dir指定的值 |
| force_reload : bool |
| 是否重新导入数据集。默认:False |
| verbose : bool |
| 是否打印进度信息。 |
| """ |
| def __init__(self, |
| url=None, |
| raw_dir=None, |
| save_dir=None, |
| force_reload=False, |
| verbose=False): |
| super(MyDataset, self).__init__(name='dataset_name', |
| url=url, |
| raw_dir=raw_dir, |
| save_dir=save_dir, |
| force_reload=force_reload, |
| verbose=verbose) |
| |
| def download(self): |
| |
| pass |
| |
| def process(self): |
| |
| pass |
| |
| def __getitem__(self, idx): |
| |
| pass |
| |
| def __len__(self): |
| |
| pass |
| |
| def save(self): |
| |
| pass |
| |
| def load(self): |
| |
| pass |
| |
| def has_cache(self): |
| |
| pass |
继承 DGLDataset 的子类必须实现的函数:process、__getitem__、__len__。
在 DGLDataset 中应该做的事情:存储有关数据集的图、特征、标签、掩码,以及诸如类别数、标签数等基本信息。
不应该做的事情:诸如采样、划分或特征归一化等操作,建议在 DGLDataset 子类之外完成。
# 下载数据
# 直接下载:
| import os |
| from dgl.data.utils import download |
| |
| def download(self): |
| |
| file_path = os.path.join(self.raw_dir, self.name + '.mat') |
| |
| download(self.url, path=file_path) |
# 下载并解压:
| from dgl.data.utils import download, check_sha1 |
| |
| def download(self): |
| |
| gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz') |
| |
| download(self.url, path=gz_file_path) |
| |
| if not check_sha1(gz_file_path, self._sha1_str): |
| raise UserWarning('File {} is downloaded but the content hash does not match. ' |
| 'The repo may be outdated or download may be incomplete. ' |
| 'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz')) |
| |
| self._extract_gz(gz_file_path, self.raw_path) |
解压后的文件会被放到 self.raw_dir 下名为 self.name 的目录中。
# 处理数据
# 处理整图分类数据集
QM7bDataset
| from dgl.data import DGLDataset |
| |
| class QM7bDataset(DGLDataset): |
| _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \ |
| 'datasets/qm7b.mat' |
| _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392' |
| |
| def __init__(self, raw_dir=None, force_reload=False, verbose=False): |
| super(QM7bDataset, self).__init__(name='qm7b', |
| url=self._url, |
| raw_dir=raw_dir, |
| force_reload=force_reload, |
| verbose=verbose) |
| |
| def process(self): |
| mat_path = self.raw_path + '.mat' |
| |
| self.graphs, self.label = self._load_graph(mat_path) |
| |
| def __getitem__(self, idx): |
| """ 通过idx获取对应的图和标签 |
| |
| Parameters |
| ---------- |
| idx : int |
| Item index |
| |
| Returns |
| ------- |
| (dgl.DGLGraph, Tensor) |
| """ |
| return self.graphs[idx], self.label[idx] |
| |
| def __len__(self): |
| """数据集中图的数量""" |
| return len(self.graphs) |
# 使用数据集
| import dgl |
| import torch |
| |
| from dgl.dataloading import GraphDataLoader |
| |
| |
| dataset = QM7bDataset() |
| num_labels = dataset.num_labels |
| |
| |
| dataloader = GraphDataLoader(dataset, batch_size=1, shuffle=True) |
| |
| |
| for epoch in range(100): |
| for g, labels in dataloader: |
| |
| pass |
# 节点分类任务
# 使用节点分类数据集:
| |
| dataset = CiteseerGraphDataset(raw_dir='') |
| graph = dataset[0] |
| |
| |
| train_mask = graph.ndata['train_mask'] |
| val_mask = graph.ndata['val_mask'] |
| test_mask = graph.ndata['test_mask'] |
| |
| |
| feats = graph.ndata['feat'] |
| |
| |
| labels = graph.ndata['label'] |
# 可用于节点分类的数据集:
- Citation network dataset
- CoraFull dataset
- Amazon Co-Purchase dataset
- Coauthor dataset
- Karate club dataset
- Protein-Protein Interaction dataset
- Reddit dataset
- Symmetric Stochastic Block Model Mixture dataset
- Stanford sentiment treebank dataset
- RDF datasets
# 链接预测数据集
# 实现代码
| |
| class KnowledgeGraphDataset(DGLBuiltinDataset): |
| def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True): |
| self._name = name |
| self.reverse = reverse |
| url = _get_dgl_url('dataset/') + '{}.tgz'.format(name) |
| super(KnowledgeGraphDataset, self).__init__(name, |
| url=url, |
| raw_dir=raw_dir, |
| force_reload=force_reload, |
| verbose=verbose) |
| |
| def process(self): |
| |
| |
| |
| |
| g.edata['train_mask'] = train_mask |
| g.edata['val_mask'] = val_mask |
| g.edata['test_mask'] = test_mask |
| |
| |
| g.edata['etype'] = etype |
| |
| |
| g.ndata['ntype'] = ntype |
| self._g = g |
| |
| def __getitem__(self, idx): |
| assert idx == 0, "这个数据集只有一个图" |
| return self._g |
| |
| def __len__(self): |
| return 1 |
# 使用
| from dgl.data import FB15k237Dataset |
| |
| |
| dataset = FB15k237Dataset() |
| graph = dataset[0] |
| |
| |
| train_mask = graph.edata['train_mask'] |
| train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() |
| src, dst = graph.find_edges(train_idx) |
| |
| |
| rel = graph.edata['etype'][train_idx] |
# 保存和加载数据
| import os |
| from dgl import save_graphs, load_graphs |
| from dgl.data.utils import makedirs, save_info, load_info |
| |
| def save(self): |
| |
| graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') |
| save_graphs(graph_path, self.graphs, {'labels': self.labels}) |
| |
| info_path = os.path.join(self.save_path, self.mode + '_info.pkl') |
| save_info(info_path, {'num_classes': self.num_classes}) |
| |
| def load(self): |
| |
| graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') |
| self.graphs, label_dict = load_graphs(graph_path) |
| self.labels = label_dict['labels'] |
| info_path = os.path.join(self.save_path, self.mode + '_info.pkl') |
| self.num_classes = load_info(info_path)['num_classes'] |
| |
| def has_cache(self): |
| |
| graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin') |
| info_path = os.path.join(self.save_path, self.mode + '_info.pkl') |
| return os.path.exists(graph_path) and os.path.exists(info_path) |
# OGB
Open Graph Benchmark (OGB) 是一个图深度学习的基准数据集。使用 pip 安装 ogb 就可以使用了。
# 图分类任务
| |
| import dgl |
| import torch |
| from ogb.graphproppred import DglGraphPropPredDataset |
| from dgl.dataloading import GraphDataLoader |
| |
| def _collate_fn(batch): |
| |
| graphs = [e[0] for e in batch] |
| g = dgl.batch(graphs) |
| labels = [e[1] for e in batch] |
| labels = torch.stack(labels, 0) |
| return g, labels |
| |
| |
| dataset = DglGraphPropPredDataset(name='ogbg-molhiv') |
| split_idx = dataset.get_idx_split() |
| |
| train_loader = GraphDataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn) |
| valid_loader = GraphDataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn) |
| test_loader = GraphDataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn) |
# 节点分类任务
| |
| from ogb.nodeproppred import DglNodePropPredDataset |
| |
| dataset = DglNodePropPredDataset(name='ogbn-proteins') |
| split_idx = dataset.get_idx_split() |
| |
| |
| |
| g, labels = dataset[0] |
| |
| train_label = dataset.labels[split_idx['train']] |
| valid_label = dataset.labels[split_idx['valid']] |
| test_label = dataset.labels[split_idx['test']] |
# 链接预测任务
| |
| from ogb.linkproppred import DglLinkPropPredDataset |
| |
| dataset = DglLinkPropPredDataset(name='ogbl-ppa') |
| split_edge = dataset.get_edge_split() |
| |
| graph = dataset[0] |
| print(split_edge['train'].keys()) |
| print(split_edge['valid'].keys()) |
| print(split_edge['test'].keys()) |