[AAAI-2024] Label Attentive Distillation for GNN-Based Graph Classification
1. Introduction and Motivation
Graph Neural Networks (GNNs) have gained substantial traction in recent years due to their powerful ability to model graph-structured data. They are used extensively in applications such as social networks, recommendation systems, molecular biology, and more. Traditional GNNs, while effective, face limitations when it comes to graph-level tasks, particularly in graph classification. Aiming to improve the performance, this study introduces a novel approach, Label Attentive Distillation (LAD-GNN).
The key problem in existing GNN approaches can be named to be embedding misalignment. Traditional GNN models perform node-level feature aggregation without considering graph-level label information. This results in embeddings that do not align well with their corresponding graph labels. The misalignment of these embeddings reduces the discriminative power of graph representations, leading lessened classification accuracy.
Figure 1. The shortcomings of the conventional pipeline of GNN-based graph classification.
To address this issue, the authors proposed the LAD-GNN algorithm, which introduces a distillation framework where a “teacher” model guides a “student” GNN. The teacher model uses a label-attentive encoder that incorporates label information into the node embedding process, while the student model learns to produce class-friendly node embeddings based on the teacher’s supervision that leads to improved alignment.
2. Related Work
Presented study builds upon several important research areas, including Graph Neural Networks (GNNs), label enhancement methods, and knowledge distillation.
Graph Neural Networks (GNNs)
Graph Neural Networks (GNNs) have emerged as powerful tools for modeling graph-structured data, with foundational models like Graph Convolutional Networks (GCN) (Kipf and Welling, 2017) and Graph Attention Networks (GAT) (Veličković et al., 2018) paving the way for advancements in the field. These models excel in node-level tasks such as node classification and link prediction by aggregating features from neighboring nodes. However, when applied to graph-level tasks like classification, these methods often face embedding misalignment issues. This occurs because traditional GNNs primarily focus on local neighborhood aggregation, where node embeddings are generated by combining features from immediate neighbors without incorporating the global graph-label information. Consequently, the readout functions, which aggregate diverse and potentially conflicting local embeddings, struggle to form a cohesive and class-discriminative global representation. This misalignment can result in suboptimal graph representations, limiting their effectiveness for tasks requiring a holistic understanding of the graph structure. LAD-GNN directly addresses this limitation by incorporating label-attentive mechanisms to align embeddings with graph-level labels, thereby enhancing classification accuracy.
Label Enhancement Methods
Label-enhanced techniques aim to improve model performance by integrating label information during training. For example, Label-Enhanced Node Embedding methods (Wang, 2021) combine node features with label information to improve node-level tasks, enabling better feature representation. Similarly, Label-Enhanced Graph Structure Optimization (Chen et al., 2019) refines the adjacency matrix using label data to facilitate smoother information propagation among nodes sharing similar labels. While effective for node-level applications, these approaches do not explicitly address the challenges of graph-level tasks, such as embedding alignment or the integration of label information across the entire graph structure. LAD-GNN builds upon these approaches by extending the use of label information beyond node-level enhancements, embedding label features directly into the graph representation to support graph-level classification tasks.
Knowledge Distillation
Knowledge distillation, introduced by Hinton et al. (2015), traditionally involves using a larger, more complex teacher model to train a smaller student model through “soft labels” or probabilistic outputs. This method is commonly used to reduce model size while retaining performance. In the context of GNNs, distillation has been applied to compress models, enabling deployment on resource-constrained devices (Yang et al., 2021).
The motivation to use knowledge distillation in LAD-GNN stems from its ability to transfer enriched, task-specific knowledge from the teacher to the student model, addressing the critical embedding misalignment problem that hampers graph-level classification tasks. By leveraging a teacher model equipped with a label-attentive encoder, LAD-GNN can generate ideal embeddings that integrate both node features and global label information. These embeddings provide structured intermediate supervision, guiding the student model to produce class-discriminative and label-aligned embeddings. Additionally, the distillation process ensures that the student model inherits the teacher’s ability to generate high-quality embeddings while remaining computationally efficient for deployment.
Instead of focusing on compression, LAD-GNN uses knowledge distillation as a method for embedding refinement and task-specific optimization, enabling it to effectively align embeddings with graph labels and improve graph classification performance. This innovative application of distillation transforms it into a tool for resolving embedding misalignment challenge. This approach is particularly effective in this case because graph-level tasks require aligning local node embeddings with global graph labels, which distillation achieves through the teacher’s enriched embeddings. Moreover, the teacher-student framework is computationally efficient, as the teacher is only needed during training, making the method scalable.
3. Problem Formulation
Given a graph dataset $D = (G,Y) = { (G_ {i}, y_ i) }_ {i=1}^{N}$, where $G_ i \in G$ represents the $i$ -th graph in the dataset and $y_ i \in Y$ is the corresponding graph label, the goal is to learn a low-dimensional graph representation for predicting the graph labels. Each graph $G_ i$ consists of:
- $A_ i \in \mathbb{R}^{n_ i \times n_ i}$ : an adjacency matrix describing the relationships between the nodes in the graph, where $n$ is the number of nodes.
- $X_ i \in \mathbb{R}^{n_ i \times d}$: a node feature matrix where $d$ is the dimension of the node features.
The objective is to learn a function $f$ that maps the graph $G_ i$ to its label $y_ i$ by finding an optimal graph representation.
4. Method
Figure 2. LAD-GNN framework.
The proposed LAD-GNN method involves two main stages: Label-Attentive Teacher Training and Distillation-Based Student Learning.
Label-Attentive Teacher Training
The teacher model generates a resulting embedding where both node and label embeddings are fused using an attention mechanism. A label-attentive encoder is employed where graph labels are encoded and combined with node embeddings to form the final embedding used for classification.
Distillation-Based Student Learning
After the teacher model is trained, the student model is trained using a distillation-based approach which transfers knowledge from the teacher model to the student model. The student model learns from the teacher’s ideal embedding through intermediate supervision, so distillation process involves minimizing two losses: Classification Loss (ensures accurate graph label predictions) and Distillation Loss (aligns the student’s embeddings with the teacher’s ideal embeddings).
The idea behind having a teacher model and a student model in LAD-GNN is rooted in the concept of knowledge distillation. The teacher model is designed to learn enriched and label-augmented embeddings using the label-attentive encoder. These embeddings act as a reference or “ground truth” for the student model. The student model, in turn, is trained to mimic these embeddings while also optimizing its classification performance. As a result it allows to addresses the “embedding misalignment” problem.
This two-step process achieves the following:
Intermediate Supervision: By training the student to replicate the teacher’s ideal embeddings, LAD-GNN ensures that the student captures the critical graph-level label information embedded by the teacher.
Separation of Complexity: The teacher model can afford to be computationally intensive during training, allowing it to learn sophisticated representations. The student, being simpler, focuses on leveraging this guidance, ensuring efficient inference without the teacher model during deployment.
Improved Generalization: The distillation process helps the student model generalize better by effectively transferring the “label knowledge” encapsulated by the teacher, even in cases of noisy or limited data.
Framework of LAD-GNN
GNN Backbone
The GNN backbone forms the core of the LAD-GNN architecture and is responsible for generating the initial node embeddings, which are later refined using the label-attentive encoder. The GNN backbone can be formalized as:
$H_ {v}^{(l+1)} = \text{UPT}(H_ {v}^{(l)}, \text{AGG}({ H_ {u}^{(l)}\vert u \in N_ v })), \forall v \in V$ where $H_ {v}^{(l+1)}$ is the node embedding at layer $l + 1$, $N_ v$ represents the set of neighbors for node $v$, $AGG$ is the aggregation function that combines information from the neighbors, and $UPT$ is the update function applied to node $v$.
After multiple layers of message-passing and aggregation, the final node embeddings $H$ are aggregated into a graph-level representation $Z_ G$ using a readout (pooling) function:
$Z_ G = POOL({H_ v\vert v \in V})$
where $POOL$ can be an average, max pooling, or an attention-based pooling method. The graph-level representation $Z_ G$ is then passed to the classifier for label prediction:
$\hat{\mathcal{Y}} = g(Z_ G; \phi_ g)$
where $\hat{\mathcal{Y}} \in \mathbb{R}^{N \times c}$ represents the predicted labels with $c$ classes and $\phi_ g$ is the classifier’s parameters.
Label-Attentive Teacher Model
The teacher model integrates label information into the node embeddings through the label-attentive encoder, a critical component of LAD-GNN. The teacher model is a critical component of LAD-GNN because it serves as the foundation for generating the “ideal embeddings” that guide the student model’s learning process. The teacher model incorporates label-attentive encoding to align node embeddings with graph-level labels, ensuring that these embeddings are more discriminative for the graph classification task. The encoder transforms the label $y_ G$ of graph $G$ into a latent embedding using a Multi-Layer Perceptron (MLP): $H_ l = h(y_ G\vert G \in \mathcal{G})$ Where $H_ l$ is the latent label embedding.
This embedding is then combined with the node embeddings $H_ v$ generated by the GNN backbone using a scaled dot-product attention mechanism: $H’_ v = Attention(H_ v, H_ l) = Softmax(\frac{QK^T}{\sqrt{d_ k}} \cdot\tau)V$ where $Q=H_ lW^Q$, $K=H_ vW^K$, and $V=H_ vW^V$ are projections of the label and node embeddings, $d_ k$ is the dimensionality of the projections and $\tau$ is the attention temperature coefficient.
The output $H’_ v$ is normalized using a layer normalization function layernorm (LN), then passed through a feed-forward network (FFN) to generate the final node embedding: $H_ v^{(T)}=FFN(LN(H’_ v + H_ v)) + H’_ v$
The teacher model is optimized using a classification loss:
$\mathcal{L}_ {cls} = \frac{1}{N}\sum_ {i = 1}^{N}{-(y_ {i}\log(\hat{y_ {i}}) + ( 1 - y_ {i} )\log( 1 - \hat{y_ {i}} ))}$
Distillation-Based Student Learning
Once the teacher model is trained, the student model learns to mimic the teacher’s embeddings using a distillation loss aiming to minimize the mean squared error (MSE) between the node embeddings from the teacher and student models:
$\mathcal{L}_ {dis} = \frac{1}{N}\sum_ {i = 1}^{N}{\vert\vert H_ i^{(T)} - H_ i^{(S)}\vert \vert^2_ 2}$
The student model is trained to minimize a combination of the classification loss $L_ {cls}$ and the distillation loss $L_ {dis}$, weighted by a hyperparameter $\lambda$:
$\mathcal{L} = \mathcal{L}_ {cls} + \lambda\cdot\mathcal{L}_ {dis}$
By distilling knowledge from the teacher model, the student GNN can learn to generate class-friendly node embeddings that align better with the graph labels.
The computational complexity of LAD-GNN is comparable to that of traditional GNN backbones, with an added cost for the distillation process. The primary increase in complexity arises from the label-attentive encoder, but this is considered negligible compared to the overall complexity of the GNN model.
Algorithm of LAD-GNN.
5. Experiments
Datasets
The effectiveness of LAD-GNN is demonstrated through extensive experiments on ten different datasets across multiple domains to ensure a broad understanding of the method’s generalizability. The datasets used in the experiments include:
-
Chemical Datasets: MUTAG, PTC, NCI1, NCI109: These datasets are composed of molecular graphs used for graph classification tasks like predicting chemical compound properties.
-
Bioinformatics Datasets: PROTEINS, ENZYMES: Graphs representing protein structures where the task is to classify proteins or predict enzyme classes.
-
Social Network Datasets: COLLAB, IMDB-BINARY, IMDB-MULTI, REDDIT-BINARY, REDDIT-MULTI-5K: These are graphs representing social networks, where the task is to classify graphs into different social categories or communities.
The experiments are conducted using 7 different GNN backbones: GCN, GAT, GraphSAGE, GIN, DGCNN, SAGPool, and MEWISPool. The performance of LAD-GNN is compared to the base models to measure its effectiveness in improving graph classification accuracy.
Experimental Results
The experiments are evaluated using the standard graph classification metric, accuracy, with the following design considerations:
A ten-fold cross-validation setting is used across all datasets to ensure robust and unbiased results. The datasets are split into train, validation, and test sets with an 80/10/10 split ratio.
Performance Comparison with GNN Backbones The first set of experiments compares the classification accuracy of the baseline GNN models (GCN, GAT, GraphSAGE, GIN, DGCNN, SAGPool, and MEWISPool) with and without the LAD-GNN framework. LAD-GNN consistently improves the performance of all backbones across all datasets. Notable results include:
-
IMDB-BINARY dataset: LAD-GNN achieves a 16.8% improvement in accuracy when combined with the GraphSAGE backbone (from 58.9% to 75.5%).
-
MUTAG dataset: LAD-GNN provides an increase of 4.7% in accuracy on GraphSAGE and similar improvements on other backbones.
-
NCI1 dataset: The performance of GraphSAGE combined with LAD-GNN jumps from 69.9% to 86.7%, representing a significant 16.8% boost.
Table 1. Performance comparison of LAD-GNN with GNN Backbones.
Comparison with Other GNN Training Methods To further validate the effectiveness of LAD-GNN, the authors compare it with nine other state-of-the-art GNN training strategies, including:
- Manual graph augmentation methods: DropEdge, M-Mixup, and G-Mixup.
- Graph auto-augmentation methods: JOAOv2, AD-GCL, and AutoGCL.
- Graph distillation methods: KD, GFKD, and DFAD-GNN.
Table 2 in the paper shows that LAD-GNN outperforms all these competing methods on most datasets. Specifically:
-
LAD-GNN achieves the highest accuracy on six out of seven datasets, including significant gains on the NCI1, PTC, and PROTEINS datasets.
-
On the MUTAG dataset, LAD-GNN is slightly outperformed by IGSD (Iterative Graph Self-Distillation), though the difference is small. This comparison shows that LAD-GNN is more effective at improving graph classification tasks than other augmentation and distillation-based methods.
Table 2. Performance comparison of LAD-GNN with with other GNN training tethods.
To further analyze the performance of LAD-GNN, authors provide Receiver Operating Characteristic (ROC) curves and Area Under the Curve (AUC) values for the MUTAG and PTC datasets. ROC curves are used to visualize the performance of classification models at various threshold settings, and AUC values summarize the model’s ability to discriminate between classes. The ROC curves for LAD-GNN show that it consistently outperforms the baseline GNN backbones (GCN, GAT, GIN, GraphSAGE) on both datasets. Specifically:
-
On the MUTAG dataset, LAD-GNN achieves an AUC of 0.86, which is higher than GCN (0.75), GAT (0.77), and GIN (0.79).
-
On the PTC dataset, LAD-GNN achieves an AUC of 0.80, outperforming GCN (0.71), GAT (0.70), and GraphSAGE (0.76). These results further confirm that LAD-GNN improves the graph classification performance not only in terms of accuracy but also in terms of the overall quality of predictions.
Figure 3. The ROC curves on MUTAG and PTC datasets of GCN, GAT, GIN, SAGE, and the proposed LAD-GNN.
Visualization of Embeddings The t-SNE visualizations of the learned graph embeddings from the original GNN backbones and LAD-GNN are included in the paper. The t-SNE plots reveal that LAD-GNN generates more tightly clustered embeddings, with better separation between classes. This indicates that the label-attentive distillation process helps the model learn more discriminative representations for graph classification tasks. The visualization results are particularly important because they provide an intuitive understanding of how LAD-GNN improves the alignment between node embeddings and graph labels. By incorporating label information into the node embedding process, LAD-GNN produces embeddings that are more class-friendly and easier to separate in the feature space.
Figure 4: The t-SNE visualization of graph representations for LAD-GNN and the original GNN
Hyperparameter Sensitivity The authors analyze the sensitivity of two key hyperparameters:
-
$\lambda$: The weight of the distillation loss in the overall objective function.
-
$\tau$: The attention temperature coefficient in the label-attentive encoder.
The experiments show that the performance of LAD-GNN is sensitive to $\lambda$, with different datasets requiring different values for optimal performance. For example, the best values for $\lambda$ range between 0.1 and 10 depending on the dataset. In contrast, $\tau$ is found to have less impact on the model’s performance, with $\tau = 0.1$ being chosen as a standard value for all datasets. This hyperparameter analysis provides valuable insights into the importance of fine-tuning $\lambda$ for each specific dataset, while keeping $\tau$ relatively stable.
Figure 5: Hyperparameter sensitivitivity of (a) $\lambda$ and (b) $\tau$ on the 4 datasets
6. Conclusion, Key Takeaway and Limitations
In conclusion, the paper Label Attentive Distillation for GNN-Based Graph Classification offers a significant improvement in graph-level tasks by addressing the embedding misalignment issue present in traditional GNN models. By introducing a novel distillation framework where a teacher model with a label-attentive encoder guides a student GNN, the authors demonstrate substantial improvements in classification accuracy across a wide range of datasets and GNN backbones.
The key takeaway of this study is that introduced LAD-GNN algorithm holds ability to align node embeddings with graph-level labels leads to more discriminative graph representations, improving classification accuracy without substantially increasing computational cost. The method is highly adaptable and versatile and can be integrated with a wide range of GNN backbones. We can say that LAD-GNN as a valuable addition to the GNN ecosystem, offering performance improvements without sacrificing model simplicity or computational efficiency.
While LAD-GNN demonstrates promising results in addressing embedding misalignment for graph-level classification tasks, we can notice that discussed paper has several limitations.
- First, it focuses primarily on knowledge distillation as the solution without thoroughly exploring or comparing it to alternative methods (for example, contrastive learning or adversarial training) which might also be effective for aligning embeddings.
- Second, the reliance on a teacher-student framework introduces additional training complexity, as it requires a two-step process where the teacher model must first be trained before guiding the student. While this approach is efficient for inference, the added computational burden during training may pose challenges for very large or resource-intensive datasets.
- Lastly, the paper is limited to static graph classification tasks, leaving its applicability to dynamic graphs, link prediction, or other graph-based problems unexplored, which constrains the broader generalizability of the approach.
Despite these limitations, LAD-GNN offers an innovative approach to embedding alignment, paving the way for future research to refine and expand its applicability across diverse graph-based tasks.
Author Information
- Author name: Ignatova Elizaveta
- Affiliation: KAIST Bio and Brain Engineering Department
- Research Topic: Bioinformatic Analysis of Autoimmune Diseases Multi-Omics
References
- Kipf, T. N.; and Welling, M. 2017. Semi-Supervised Classification with Graph Convolutional Networks. In International Conference on Learning Representations (ICLR).
- Veličković, P.; Cucurull, G.; Casanova, A.; Romero, A.; Lio, P.; and Bengio, Y. 2018. Graph Attention Networks. International Conference on Learning Representations.
- Wang, Y. 2021. Bag of tricks of semi-supervised classification with graph neural networks. arXiv preprint arXiv:2103.13355.
- Chen, D.; Liu, X.; Lin, Y.; Li, P.; Zhou, J.; Su, Q.; and Sun, X. 2019. Highwaygraph: Modelling long-distance node relations for improving general graph neural network. arXiv preprint arXiv:1911.03904.
- Hinton, G.; Vinyals, O.; Dean, J.; et al. 2015. Distill- ing the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2(7).
- Yang, C.; Liu, J.; and Shi, C. 2021. Extract the knowledge of graph neural networks and go beyond it: An effective knowl- edge distillation framework. In Proceedings of the Web Con- ference, 1227–1237.