[Paper Review] End to end learning and optimization on graphs
ClusterNet embeds graph nodes and maps the embeddings differentiably to graph-optimization decisions through a soft clustering layer, achieving better performance than two-stage and pure end-to-end baselines on learning+optimization tasks.
Real-world applications often combine learning and optimization problems on graphs. For instance, our objective may be to cluster the graph in order to detect meaningful communities (or solve other common graph optimization problems such as facility location, maxcut, and so on). However, graphs or related attributes are often only partially observed, introducing learning problems such as link prediction which must be solved prior to optimization. Standard approaches treat learning and optimization entirely separately, while recent machine learning work aims to predict the optimal solution directly from the inputs. Here, we propose an alternative decision-focused learning approach that integrates a differentiable proxy for common graph optimization problems as a layer in learned systems. The main idea is to learn a representation that maps the original optimization problem onto a simpler proxy problem that can be efficiently differentiated through. Experimental results show that our ClusterNet system outperforms both pure end-to-end approaches (that directly predict the optimal solution) and standard approaches that entirely separate learning and optimization. Code for our system is available at https://github.com/bwilder0/clusternet.
Motivation & Objective
- Motivate and formalize the integration of learning and optimization on graphs when the graph is partially observed.
- Introduce a differentiable proxy (soft K-means in embedding space) to approximate discrete graph optimization problems.
- Demonstrate end-to-end training that optimizes downstream decision quality rather than predictive accuracy.
- Show that the learned representations induce high-value solutions for downstream tasks across multiple domains.
Proposed method
- Embed graph nodes into a continuous space using a graph neural network (e.g., GCN) driven by observed edges and node features.
- Include a differentiable K-means clustering layer that assigns nodes to K clusters with soft assignments.
- Use a forward pass that updates cluster centers and soft assignments via a differentiable form of K-means (softmax over distances with inverse temperature beta).
- Differentiate through the clustering fixed point using the implicit function theorem to obtain gradients with respect to embeddings.
- Interpret cluster assignments as soft solutions to graph optimization problems (partitioning or subset selection) and compute a differentiable expected loss over these soft solutions.
- Provide two practical strategies for turning soft clusterings into discrete decisions: (i) partitioning via soft partitions and (ii) selecting a subset via probabilistic mass on cluster centers, followed by rounding (e.g., pipage rounding) at test time.
- Argue that the backward pass can be approximated efficiently, and provide approximation guarantees for doing so, enabling scalable end-to-end training.
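The clustering layer and the soft objective described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names `soft_assign`, `soft_kmeans`, and `soft_modularity` are hypothetical, squared Euclidean distance stands in for whatever similarity the paper uses, the embeddings are given directly instead of coming from a GCN, and modularity is used as an example partitioning objective. The sketch covers only the forward pass; no gradients are computed.

```python
import numpy as np

def soft_assign(x, mu, beta):
    """Soft assignments r (n, k): softmax over negative scaled distances."""
    # Assumption: squared Euclidean distance between embeddings and centers.
    d2 = ((x[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)
    logits = -beta * d2
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    r = np.exp(logits)
    return r / r.sum(axis=1, keepdims=True)

def soft_kmeans(x, mu_init, beta=1.0, n_iter=10):
    """Alternate soft assignments and weighted-mean center updates."""
    mu = np.asarray(mu_init, dtype=float)
    for _ in range(n_iter):
        r = soft_assign(x, mu, beta)
        # Each center becomes the assignment-weighted mean of the embeddings.
        mu = (r.T @ x) / (r.sum(axis=0)[:, None] + 1e-12)
    return mu, soft_assign(x, mu, beta)

def soft_modularity(adj, r):
    """Modularity relaxed to soft assignments: trace(r^T B r) / (2m)."""
    deg = adj.sum(axis=1)
    two_m = deg.sum()
    b = adj - np.outer(deg, deg) / two_m  # modularity matrix
    return np.trace(r.T @ b @ r) / two_m

# Toy usage: six 2-D "embeddings" forming two well-separated groups,
# and a graph of two triangles joined by a single edge (2-3).
x = np.array([[0., 0.], [0., 1.], [1., 0.],
              [10., 10.], [10., 11.], [11., 10.]])
adj = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    adj[i, j] = adj[j, i] = 1.0

mu, r = soft_kmeans(x, mu_init=x[[0, 3]], beta=1.0, n_iter=10)
print(r.argmax(axis=1), soft_modularity(adj, r))
```

Larger values of `beta` push the rows of `r` toward one-hot vectors, so the soft objective approaches the value of the corresponding hard partition; smaller values keep the assignments smooth. In an end-to-end setting the same computation would be written in an autodiff framework so that the loss can be backpropagated to the embedding network.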
Experimental results
Research questions
- RQ1: Can decision-focused learning improve downstream optimization quality over two-stage and pure end-to-end methods?
- RQ2: Does embedding-based soft clustering serve as an effective differentiable proxy for hard graph optimization problems (partitioning and subset selection)?
- RQ3: How well does the model generalize to unseen graphs, and how does fine-tuning affect performance on new graphs?
- RQ4: What are the computational trade-offs of exact versus approximate differentiable backpropagation through the clustering layer?
Key findings
- ClusterNet consistently outperforms both two-stage baselines (learning followed by optimization) and pure end-to-end approaches on community detection and facility location tasks.
- Two-stage approaches sometimes underperform compared to training only on observed edges, illustrating the value of end-to-end decision-focused learning.
- GCN-e2e (pure end-to-end) often only matches or falls short of ClusterNet, highlighting the benefit of incorporating algorithmic structure as a differentiable layer.
- ClusterNet generalizes across unseen graphs and distributions, with fine-tuning offering additional gains when available.
- Forward passes are efficient (up to 0.23 seconds on the largest graph), and the architecture supports scalable approximation in the backward pass.
This review was created by AI and reviewed by human editors.