Skip to main content
QUICK REVIEW

[논문 리뷰] Orthogonal Weight Normalization: Solution to Optimization over Multiple Dependent Stiefel Manifolds in Deep Neural Networks

Lei Huang, Xianglong Liu|arXiv (Cornell University)|2017. 09. 16.
Advanced Neural Network Applications인용 수 90
한 줄 요약

본 논문은 심층 네트워크에서 직교 직사각형 가중치 행렬의 학습을 Optimization over Multiple Dependent Stiefel Manifolds (OMDSM)으로 공식화하고 프록시 파라미터를 통한 직교 가중치 정규화를 제안하여 Orthogonal Linear Module을 도출하고 프로토콜을 바꾸지 않고 CNN의 성능을 향상시킨다.

ABSTRACT

Orthogonal matrix has shown advantages in training Recurrent Neural Networks (RNNs), but such matrix is limited to be square for the hidden-to-hidden transformation in RNNs. In this paper, we generalize such square orthogonal matrix to orthogonal rectangular matrix and formulating this problem in feed-forward Neural Networks (FNNs) as Optimization over Multiple Dependent Stiefel Manifolds (OMDSM). We show that the rectangular orthogonal matrix can stabilize the distribution of network activations and regularize FNNs. We also propose a novel orthogonal weight normalization method to solve OMDSM. Particularly, it constructs orthogonal transformation over proxy parameters to ensure the weight matrix is orthogonal and back-propagates gradient information through the transformation during training. To guarantee stability, we minimize the distortions between proxy parameters and canonical weights over all tractable orthogonal transformations. In addition, we design an orthogonal linear module (OLM) to learn orthogonal filter banks in practice, which can be used as an alternative to standard linear module. Extensive experiments demonstrate that by simply substituting OLM for standard linear module without revising any experimental protocols, our method largely improves the performance of the state-of-the-art networks, including Inception and residual networks on CIFAR and ImageNet datasets. In particular, we have reduced the test error of wide residual network on CIFAR-100 from 20.04% to 18.61% with such simple substitution. Our code is available online for result reproduction.

연구 동기 및 목표

  • 깊은 네트워크에서 직교 가중치 행렬을 통한 정규화와 안정적 최적화를 동기화한다.
  • DNN에서 직교 필터 학습을 OMDSM(Optimization over Multiple Dependent Stiefel Manifolds)으로 형식화한다.
  • 직교화 변환을 역전파로 전달하는 안정적인 해법인 Orthogonal Weight Normalization을 개발한다.
  • 실무에서 표준 선형 계층을 대체하는 Orthogonal Linear Module(OLM)을 도입한다.
  • CIFAR 및 ImageNet 데이터셋에서 MLP와 CNN에 걸친 성능 향상을 보여준다.

제안 방법

  • 각 층에서 W^l을 직교로 정의하고 W^l ∈ O^{n_l × d_l}로 구성하여 OMDSM을 형성한다.
  • W^l을 W^l = φ(V^l)로 재파라미터화하고 φ가 프록시 V^l를 직교한 W^l로 매핑하도록 한다.
  • 공분산 Σ의 고유 분해를 이용하여 W^l = D Λ^{-1/2} D^T (V^l - c 1_d^T) 형태로 φ를 중심화하여 계산한다.
  • 고유 분해 도함수의 행렬 미분 calculus를 사용하여 φ를 통해 기울기를 역전파한다.
  • W = φ(V)이고 W W^T = I를 만족하도록 왜곡 tr((W - V_c)(W - V_c)^T)를 최소화하여 학습을 안정화한다(OLM).
  • n > d인 경우 그룹 단위로 가중치를 나누고 각 그룹에서 직교화를 수행하는 그룹 기반 직교화를 선택적으로 사용한다.
  • 합성곱 층은 W^C를 2D로 풀어 동일한 직교화를 적용하며, 그룹 기반 전략은 계산을 줄여준다.
  • φ 변환을 이용한 순전파/역전파를 구현하고 추론 시 W를 저장하는 Orthogonal Linear Module(OLM)을 제안한다.

실험 결과

연구 질문

  • RQ1OMDSM 하에서 딥 피드포워드 네트워크에서 직교 직사각형 가중치를 효과적으로 학습할 수 있는가?
  • RQ2프록시 파라미터 직교화를 사용한 OMDSM의 해가 리만 최적화 방법과 비교하여 안정적이고 확장 가능한 학습을 제공하는가?
  • RQ3표준 선형 모듈을 OLM으로 대체했을 때 CNN 아키텍처의 최적화 속도와 일반화에 어떤 영향을 미치는가?
  • RQ4대규모 CNN에서 OMDSM을 도입하기 위한 실용적 전략(예: 그룹 기반 직교화, BN/Adam과의 호환성)은 무엇인가?

주요 결과

  • 리만 최적화 방법은 OMDSM에 대해 불안정하거나 수렴이 느리지만, OLM은 안정적이고 빠른 최적화를 달성한다.
  • OLM은 활성 분포를 안정시키고 그래디언트 노름을 보존하여 학습 깊이와 조건화에 도움을 준다.
  • 표준 선형 모듈을 OLM으로 교체하면 CNN 아키텍처와 데이터셋에서 일관된 성능 향상을 얻을 수 있다.
  • CIFAR-100에서 Wide ResNet을 사용한 경우 테스트 에러가 OLM으로 20.04%에서 18.61%로 개선되며 CIFAR-10에서도 관련 이득이 있다.
  • VGG 스타일 네트워크에서 OLM(및 변형)을 적용하면 CIFAR-10/100에서 최첨단 또는 경쟁력 있는 결과를 달성하며 예를 들어 WRN-28-10-OLM은 CIFAR-10에서 3.73%, CIFAR-100에서 18.76%를 달성한다.
  • BN-Inception에 OLM을 적용하면 CIFAR-10/100에서 베이스라인보다 성능이 향상된다; 예를 들어 CIFAR-100은 일반 BN-Inception 대비 22.02%로 감소한다(24.87%에서).

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.