[论文解读] Towards a Neural Statistician
该论文将变分自编码器扩展为学习数据集级统计信息,通过统计网络实现,在跨数据集的无监督、数据高效和少样本学习中建模在每个数据集内共享的潜在上下文 c。
An efficient learner is one who reuses what they already know to tackle a new problem. For a machine learner, this means understanding the similarities amongst datasets. In order to do this, one must take seriously the idea of working with datasets, rather than datapoints, as the key objects to model. Towards this goal, we demonstrate an extension of a variational autoencoder that can learn a method for computing representations, or statistics, of datasets in an unsupervised fashion. The network is trained to produce statistics that encapsulate a generative model for each dataset. Hence the network enables efficient learning from new datasets for both unsupervised and supervised tasks. We show that we are able to learn statistics that can be used for: clustering datasets, transferring generative models to new datasets, selecting representative samples of datasets and classifying previously unseen classes. We refer to our model as a neural statistician, and by this we mean a neural network that can learn to compute summary statistics of datasets without supervision.
研究动机与目标
- 代表数据集(不仅是数据点)作为第一等对象,以提升迁移和学习效率。
- 开发一个无监督的神经方法来计算定义每个数据集生成模型的数据集级统计信息。
- 通过学习的统计信息实现聚类、数据集级迁移、代表性采样,以及对少样本类别的处理。
- 提供一种可扩展、参数高效的方法,将数据集数量与模型规模解耦。
提出的方法
- 在数据集内的项之间扩展一个共享上下文变量 c 的变分自编码器。
- 引入一个统计网络 q(c|D;φ),它使用一个无序集合 D 的示例,通过可交换的池化层(如平均值)聚合以产生 c。
- 使用包含潜在变量 z 的分层潜在结构,以及跳连接,形成一个灵活的生成模型 p(x|z,c;θ)。
- 定义三部分变分边界 LD = RD + CD + LD,其中 RD 是重构,CD 是上下文差异,LD 是潜在差异。
- 在数据集批次上训练(而非数据点),以最大化数据集上的期望 LD。
- 使用近似推断网络 q(z|x,c;φ) 和 q(c|D;φ) 进行前馈推断,并对梯度估计使用重参数化技巧。
实验结果
研究问题
- RQ1神经模型是否能够学习出能够概括每个数据集生成过程的有意义的数据集级统计信息?
- RQ2学习到的统计信息是否能够按分布族对数据集进行聚类、实现跨数据集的迁移,以及在少样本下对未见类别进行分类或采样?
- RQ3统计网络是否支持在数据集上进行条件化,以生成并从特定数据集的生成模型中采样?
- RQ4将数据集视为单位进行建模,如何提升样本效率并实现数据集表示的无监督学习?
- RQ5层次潜在结构和跳连接对建模复杂数据集结构有何影响?
主要发现
| 任务 | 方法 | 测试数据集 | K 次样本 | K 类别 | 孪生网络 | MANN | 匹配 | 本方法 |
|---|---|---|---|---|---|---|---|---|
| MNIST | 1 | 10 | 70 | - | 72 | 78.6 | ||
| MNIST | 5 | 10 | - | 93.2 | ||||
| OMNIGLOT | 1 | 5 | 97.3 | 82.8 | 98.1 | 98.1 | ||
| OMNIGLOT | 5 | 5 | 98.4 | 94.9 | 98.9 | 99.5 | ||
| OMNIGLOT | 1 | 20 | 88.1 | 93.8 | 93.2 | |||
| OMNIGLOT | 5 | 20 | 97.0 | 98.7 | 98.1 |
- 该模型学习的统计信息能够按其分布族将一维合成数据集聚类,在簇内实现均值与方差的正交映射。
- 在空间 MNIST 上,模型可以在给定数据集的条件下生成样本,并执行有意义的子集选择作为摘要。
- 对于 OMNIGLOT 和少样本任务,该方法在 5-way 任务上达到具有竞争力的准确率,在 1-shot/5-shot 设置中表现强劲,显示出对未见字符和数字的迁移。
- 少样本分类实验表明神经统计学家是一个强基线,特别是在 5-way 任务中,尽管有时在较大(20-way)任务上略落后于专门的匹配网络。
- 该方法展示了对 YouTube Faces 的少样本学习和数据集条件生成,在生成样本中产生一致的身份和不同的姿态。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。