On this page

    Creating a Custom Dataset

    Dataset 类是 Spektral 1.0的一个新特性,它标准化了 Spektral 中图数据集的表示方式。

    在本教程中,我们将通过一个简单的例子来创建一个自定义数据集。

    如果你想公开分享你的数据集或将它们包含在Spektral中,这也很有用。

    Essential information

    您可以通过继承 spektral.data.Dataset 类来创建数据集。

    datasets 的核心是 read() 方法。这在数据集的每次实例化时被调用,并且必须返回spektral.data.Graph 的列表。从文件中读取数据或动态创建数据并不重要,这是数据集在内存中加载的位置。

    所有 dataset 有一个 path 属性, 其表示数据存储的文件夹。 默认为 ~/.spektral/datasets/[ClassName]。 你可以忽略它。但是,每次实例化数据集时,它将检查路径是否存在。如果没有,将调用download()方法。

    可以使用 download() 定义将原始数据保存到磁盘所需的任何额外操作。该方法将在 read() 之前被调用。

    read()download() 都由数据集的 init() 方法调用。如果需要重写数据集的初始化,请 在实现的某个地方(通常在最后一行)确保调用 super().Init()

    Example

    这是一个简单的示例,展示了如何使用五个随机图创建自定义数据集。我们假设数据来自在线数据源,这样就可以演示如何使用 download()

    我们首先覆盖 init() 方法,以存储数据集的一些自定义参数。

    class MyDataset(Dataset):
        """
        A dataset of five random graphs.
        """
        def __init__(self, nodes, feats, **kwargs):
            self.nodes = nodes
            self.feats = feats
    
            super().__init__(**kwargs)
    

    记得在最后一行调用 super().__init__(**kwargs)

    然后,我们模拟从网上下载数据。因为如果 path 在系统上不存在,就会调用这个方法,所以现在创建相应的目录是有意义的:

    def download(self):
        data = ...  # Download from somewhere
    
        # Create the directory
        os.mkdir(self.path)
    
        # Write the data to file
        for i in range(5):
            x = rand(self.nodes, self.feats)
            a = randint(0, 2, (self.nodes, self.nodes))
            y = randint(0, 2)
    
            filename = os.path.join(self.path, f'graph_{i}')
            np.savez(filename, x=x, a=a, y=y)
    

    最后,我们实现 read() 方法来返回一个 Graph 对象列表

    def read(self):
        # We must return a list of Graph objects
        output = []
    
        for i in range(5):
            data = np.load(os.path.join(self.path, f'graph_{i}.npz'))
            output.append(
                Graph(x=data['x'], a=data['a'], y=data['y'])
            )
    
        return output
    

    我们现在可以实例化我们的数据集,这将 “download” 我们的数据并将其读入内存:

    >>> dataset = MyDataset(3, 2)
    >>> dataset
    MyDataset(n_graphs=5)
    

    我们可以看到我们的图被保存为文件

    $ ls ~/.spektral/datasets/MyDataset/
    graph_0.npz  graph_1.npz  graph_2.npz  graph_3.npz  graph_4.npz
    

    所以下一次我们创建 MyDataset 时,它将从我们已经保存的文件中读取。

    您现在可以随意使用自定义数据集。Loaders 、transforms和文档中描述的所有其他功能都可以工作。

    请记住,如果您愿意,您可以自由地按您喜欢的方式存储数据。Spektral 中的数据集只是为了简化你的工作流程,但库仍然是根据 Keras 的原则设计的,即不妨碍你的工作。如果您希望以不同的方式操作数据,GNNs 仍然可以工作。