Skip to main content
QUICK REVIEW

[논문 리뷰] Distributed Statistical Machine Learning in Adversarial Settings: Byzantine Gradient Descent

Yudong Chen, Lili Su|arXiv (Cornell University)|2017. 05. 16.
Stochastic Gradient Optimization Techniques참고 문헌 24인용 수 142
한 줄 요약

본 논문은 Byzantine Gradient Descent를 제안합니다. 이는 최대 ~2(1+ε)q 명의 Byzantine 워커를 허용하는 강건한 분산 학습 알고리즘이며, 각 log N 라운드당 오차가 ∼max{√(dq/N), √(d/N)}인 지수 수렴을 달성합니다.

ABSTRACT

We consider the problem of distributed statistical machine learning in adversarial settings, where some unknown and time-varying subset of working machines may be compromised and behave arbitrarily to prevent an accurate model from being learned. This setting captures the potential adversarial attacks faced by Federated Learning -- a modern machine learning paradigm that is proposed by Google researchers and has been intensively studied for ensuring user privacy. Formally, we focus on a distributed system consisting of a parameter server and $m$ working machines. Each working machine keeps $N/m$ data samples, where $N$ is the total number of samples. The goal is to collectively learn the underlying true model parameter of dimension $d$. In classical batch gradient descent methods, the gradients reported to the server by the working machines are aggregated via simple averaging, which is vulnerable to a single Byzantine failure. In this paper, we propose a Byzantine gradient descent method based on the geometric median of means of the gradients. We show that our method can tolerate $q \le (m-1)/2$ Byzantine failures, and the parameter estimate converges in $O(\log N)$ rounds with an estimation error of $\sqrt{d(2q+1)/N}$, hence approaching the optimal error rate $\sqrt{d/N}$ in the centralized and failure-free setting. The total computational complexity of our algorithm is of $O((Nd/m) \log N)$ at each working machine and $O(md + kd \log^3 N)$ at the central server, and the total communication cost is of $O(m d \log N)$. We further provide an application of our general results to the linear regression problem. A key challenge arises in the above problem is that Byzantine failures create arbitrary and unspecified dependency among the iterations and the aggregated gradients. We prove that the aggregated gradient converges uniformly to the true gradient function.

연구 동기 및 목표

  • 연합 학습과 같이 적대적( Byzantine) 오류가 존재하는 분산 통계 학습의 필요성을 제시한다.
  • Byzantine 오류를 허용하는 강건한 그래디언트 집계 방법을 개발한다.
  • Byzantine 오류 하에서 수렴 보장성과 추정 오차를 특징화한다.
  • 제안된 방법의 계산 및 통신 비용을 분석한다.
  • 이러한 접근법을 설명하기 위한 선형 회귀에의 응용을 제공한다.

제안 방법

  • 서버가 배치 평균과 기하 중앙값에 기반한 강건한 계획으로 그래디언트를 집계하는 Byzantine Gradient Descent를 제안한다.
  • 작업 기계를 m개에서 k개의 배치로 분할하고 그래디언트의 배치 평균을 계산한다.
  • 이 k개의 배치 평균의 기하 중앙값을 계산하여 업데이트를 위한 집계된 그래디언트를 형성한다.
  • 강한 볼록성 및 Lipschitz 기울기 가정하에 η = L/(2M^2)로 선택된 학습률 η의 경사하강 스텝을 사용한다.
  • 형식적 수렴 정리를 제공하여 로그 N 라운드에서 지수 수렴을 보이며, 오차 한계는 √(dq/N)와 √(d/N)에 비례하여 증가한다.
  • 계산 비용은 각 작업자당 O((Nd/m) log N), 파라미터 서버에서는 O(md + qd log^3 N), 통신 비용은 O(md log N)이다.

실험 결과

연구 질문

  • RQ1분산 학습 알고리즘이 각 워커의 로컬 데이터를 사용하면서 Byzantine(임의) 실패를 허용할 수 있는가?
  • RQ2수렴을 파괴하지 않으면서 Byzantine 영향력을 완화하는 그래디언트의 강건한 집계 규칙은 무엇인가?
  • RQ3 Byzantine 오류 하에서 분산 학습의 수렴 속도와 통계적 오차 한계는 어떻게 되는가?
  • RQ4 fault 허용성과 통계적 정확도 사이의 균형을 맞추려면 시스템 매개변수(k, q, m, N, d)를 어떻게 선택해야 하는가?
  • RQ5이 방법이 선형 회귀와 같은 구체적 문제에 어떻게 적용되는가?

주요 결과

  • 제안된 Byzantine Gradient Descent 방법은 어떤 고정된 ε>0에 대해 최대 2(1+ε)q ≤ m Byzantine 실패를 허용한다.
  • 추정기가 O(log N) 라운드에서 수렴하며 오차 한계는 max{√(dq/N), √(d/N)}이다.
  • minimax 최적 속도 √(d/N)은 Byzantine 환경에서 √q의 인수에 의해 달성 가능하다.
  • 계산 총 비용은 각 워커당 O((Nd/m) log N)이며 파라미터 서버에서 O(md + qd log^3 N), 통신 비용은 O(md log N)이다.
  • 선형 회귀의 경우 이 프레임워크가 적대적 워커에 대한 적용성과 강건성을 보인다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.