Skip to main content
QUICK REVIEW

[论文解读] Stiffness: A New Perspective on Generalization in Neural Networks

Stanislav Fort, Paweł Krzysztof Nowak|arXiv (Cornell University)|Jan 28, 2019
Neural Networks and Applications参考文献 28被引用 60
一句话总结

本文提出 stiffness 作为衡量来自一个样本的梯度更新如何影响其它样本损失的量度,将梯度对齐与泛化联系起来,并在跨数据集、架构和学习率的分析中进行研究。

ABSTRACT

In this paper we develop a new perspective on generalization of neural networks by proposing and investigating the concept of a neural network stiffness. We measure how stiff a network is by looking at how a small gradient step in the network's parameters on one example affects the loss on another example. Higher stiffness suggests that a network is learning features that generalize. In particular, we study how stiffness depends on 1) class membership, 2) distance between data points in the input space, 3) training iteration, and 4) learning rate. We present experiments on MNIST, FASHION MNIST, and CIFAR-10/100 using fully-connected and convolutional neural networks, as well as on a transformer-based NLP model. We demonstrate the connection between stiffness and generalization, and observe its dependence on learning rate. When training on CIFAR-100, the stiffness matrix exhibits a coarse-grained behavior indicative of the model's awareness of super-class membership. In addition, we measure how stiffness between two data points depends on their mutual input-space distance, and establish the concept of a dynamical critical length -- a distance below which a parameter update based on a data point influences its neighbors.

研究动机与目标

  • 提出并形式化 stiffness 概念,作为神经网络泛化的探针。
  • 研究 stiffness 如何依赖于类别成员关系、输入空间中的数据点距离、训练轮数和学习率。
  • 展示 stiffness 在视觉模型(MNIST、FASHION-MNIST、CIFAR-10/100)和基于变换器的 NLP 模型上的行为。
  • 检查由 stiffness 显现的动态临界长度和语义分组结构(超类)。

提出的方法

  • 通过两种基于梯度的测量来定义 stiffness:sign stiffness(g1·g2 的符号)和 cosine stiffness(g1 与 g2 的余弦相似度)。
  • 计算来自一个输入 X1、梯度为 g1 的小更新如何改变对另一个输入 X2 的损失。
  • 构造 class stiffness 矩阵 C(ca, cb) 并分析类间与类内 stiffness。
  • 在 train-train、train-val 和 val-val 设置下评估 stiffness,以便将其与泛化联系起来。
  • 使用动态临界长度 xi 将 stiffness 表征为输入空间距离的函数。
  • 在不同学习率和训练轮次下评估 stiffness,以观察更高的学习率如何使 stiffness 趋向较低、更加局部。

实验结果

研究问题

  • RQ1神经网络 stiffness 如何定义?它对泛化揭示了什么?
  • RQ2在不同数据集中,stiffness 如何随类别成员关系和语义分组(包括超类)变化?
  • RQ3stiffness 如何依赖于数据点之间的输入空间距离?
  • RQ4训练轮次和学习率对 stiffness 及 dynamical critical length xi 有何影响?
  • RQ5stiffness 的行为是否在视觉和语言模型(包括 CNN、ResNet 以及 BERT)中具有普适性?

主要发现

  • Stiffness 与泛化相关:在学习过程中,类内和类间的 stiffness 更高,但随着过拟合而下降。
  • 类内 stiffness 在早期和学习过程中保持较高,而类间 stiffness 随着模型学习而增加;两者在开始过拟合时衰减。
  • Stiffness 展示出语义上有意义的分组结构:在超类甚至更高层的类内 stiffness 高于随机基线。
  • 存在动态临界长度 xi:随着输入空间距离的增大,stiffness 衰减至零;xi 随训练和更高学习率而减小。
  • 更高的学习率产生具有更小 xi 的函数,即更新更局部且更易被弯曲,表明对学习到的函数具有正则化作用。
  • stiffness 概念扩展到 NLP(在 MNLI 上微调的 BERT),呈现与视觉模型相似的类内和类间动力学。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。