Are SAE features from the Base Model still meaningful to LLaVA?

post by Shan23Chen (shan-chen) · 2024-12-05T19:24:34.727Z · LW · GW · 0 comments

Contents

  TL;DR:
  Introduction
    Some Background on LLaVA:
    Key Components
      Training:
    Evaluating SAE Transferability with LLaVA
  Experimental Design
    Dataset
    Features and Evaluation
  Is there any signal?
      1. Dolphin 🐬
      Layer 0
      Layer 6
      Layer 10
      Layer 12
      Layer 12-it
      2. Skyscraper 🏙️
      Layer 0
      Layer 6
      Layer 10
      Layer 12
      Layer 12-it
      3. Boy 👦
      Layer 0
      Layer 6
      Layer 10
      Layer 12
      Layer 12-it
      4. Cloud ☁️
      Layer 0
      Layer 6
      Layer 10
      Layer 12
      Layer 12-it
  Classification Analysis
    Classification Setup
      Layer Evaluation
  Classification Findings
    Performance Summary
      So, in a way, we actually nearly recovered the full ViT performance here!
    1. How Many Features Do We Need?
      2. Which Layer Performs Best?
      3. To Binarize or Not to Binarize?
      4. Data Efficiency: How Much Training Data Do We Need?
    Big Picture Insights
    What’s Next?

Shan Chen, Jack Gallifant, Kuleen Sasse, Danielle Bitterman[1]
Please read this as a work in progress: we are sharing it the way colleagues would at a lab meeting (https://www.bittermanlab.org) to help motivate potential parallel research.

TL;DR:

Introduction

The pursuit of universal and interpretable features has long captivated researchers in AI, with Sparse Autoencoders (SAEs) emerging as a promising tool for extracting meaningful representations. Universality, in this context, refers to the ability of features to transcend domains, languages, modalities, model architectures, sizes, and training strategies. Recent advances have shed light on key properties of these representations, including their dataset-dependent nature, their relationship with the granularity of training data, and their transferability across tasks. Notably, studies such as Kissane et al. (2024) [LW · GW] and Kutsyk et al. (2024) [AF · GW] have demonstrated the intriguing ability of features to transfer from base models to fine-tuned models, and others have even hinted at their generalization across layers (Ghilardi et al. 2024). However, one critical question remains underexplored: can features trained in unimodal contexts (e.g., text-only or image-only models) effectively generalize to multimodal systems?

In this work, we focus on bridging this "modality gap" by investigating the applicability of SAE-derived features in multimodal settings. Specifically, we explore LLaVA (Liu et al. 2024), a popular multimodal model that integrates vision and language tasks. Leveraging the CIFAR-100 dataset, which provides a challenging fine-grained classification task, we assess the transferability and interpretability of features learned from base models in this multimodal context. Through a detailed layer-wise analysis, we investigate the semantic evolution of tokens and evaluate the utility of these features in downstream classification tasks.

While previous work has largely focused on the unimodal-to-unimodal transfer of features, our experiments aim to answer whether features extracted from base models can effectively bridge the gap to multimodal applications. This exploration aligns with ongoing efforts to understand how features encode information, how transferable they are across different contexts, and how they can be interpreted when applied to diverse tasks.

This write-up details our exploratory experiments, including:

  1. Layer-wise Feature Analysis: Understanding how token-level representations evolve semantically across layers within the LLaVA framework.
  2. Classification Performance: Evaluating the utility of SAE-derived features on CIFAR-100 for fine-grained image classification.
  3. Transferability: Assessing the generalization of features trained in unimodal contexts to a multimodal setting.

Our findings contribute to advancing the interpretability and universality of features in large models, paving the way for more robust, explainable, and cross-modal AI systems.

Some Background on LLaVA:

LLaVA (Liu et al. 2023) is a multimodal framework that integrates vision and language tasks. By combining a Vision Encoder and a Language Model, LLaVA processes both image and textual inputs to generate coherent and contextually appropriate language-based outputs. 

A visual representation of the LLaVA model architecture from (Liu et al. 2023). This diagram illustrates the flow of information from image input through the Vision Encoder, projection layer, and into the Language Model, culminating in generating text outputs.

Key Components

Vision Encoder:

Language Model:

Token Structure:

Output:

The final output of the LLaVA model is a text-based response that reflects both the visual content of the input image and the language instructions provided. This enables a wide range of applications, from answering questions about an image to generating detailed image captions.

Training:

LLaVA’s multimodal alignment is realized during visual instruction tuning: the Language Model is fine-tuned on multimodal instruction-following data in which each textual instruction is paired with corresponding visual inputs. During this process, the model learns to interpret visual data in conjunction with textual context, aligning visual features with language features.

Evaluating SAE Transferability with LLaVA

LLaVA’s architecture provides an ideal testbed for evaluating the transferability of SAEs. By leveraging its unified token space and multimodal alignment, we can assess how well unimodal features extracted by SAEs adapt to multimodal contexts. Specifically, LLaVA’s ability to process and integrate image and text tokens allows us to analyze the semantic evolution of SAE-derived features across its layers, offering insights into their utility and generalization capabilities in multimodal scenarios.

In this study, we utilize the Intel Gemma-2B LLaVA 1.5-based model (Intel/llava-gemma-2b) as the foundation for our experiments. For feature extraction, we incorporate pre-trained SAEs from jbloom/Gemma-2b-Residual-Stream-SAEs, trained on the Gemma-1-2B model. These SAEs include 16,384 features (an expansion factor of 8 × 2048) and are designed to capture sparse and interpretable representations.
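
To make the feature-extraction step concrete, the sketch below shows how such a residual-stream SAE turns an activation vector into sparse features: a single affine map followed by a ReLU. It assumes the SAE parameters have been loaded as plain tensors named W_enc, b_enc, and b_dec, which matches the layout of several open SAE releases; the exact parameter names, and whether the decoder bias is subtracted before encoding, may differ for a given checkpoint.

```python
import torch

def sae_encode(x: torch.Tensor,
               W_enc: torch.Tensor,   # (d_model, n_features), e.g. (2048, 16384)
               b_enc: torch.Tensor,   # (n_features,)
               b_dec: torch.Tensor    # (d_model,)
               ) -> torch.Tensor:
    """Encode residual-stream activations into sparse SAE feature activations.

    x: (..., d_model) residual-stream activations at the hooked layer.
    Returns a (..., n_features) tensor of non-negative feature activations,
    most of which are zero for a well-trained sparse autoencoder.
    """
    # Many open residual-stream SAEs subtract the decoder bias before encoding;
    # drop this term if the checkpoint you load does not use it.
    return torch.relu((x - b_dec) @ W_enc + b_enc)
```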

Our analysis focuses on evaluating the layer-wise integration of these features within LLaVA to determine their effectiveness in bridging unimodal-to-multimodal gaps. Specifically, we assess their impact on semantic alignment and classification performance. We hypothesized that text-trained SAE features would still be meaningful to LLaVA.

Experimental Design

Dataset

We used the CIFAR-100 dataset (Krizhevsky et al. 2009), which comprises:

Features and Evaluation

  1. Feature Representation: We extracted SAE features for the 576 tokens derived from image patches.
  2. Linear Classifier: A linear probe using the End-of-Image (EoI) token.
  3. Evaluation Metric: The Macro F1 score to evaluate the balance between precision and recall across all classes.
  4. Layer-Wise Analysis: Feature dynamics were analyzed across layers 0, 6, 10, and 12, as well as 12-it (the same layer, but with an SAE trained on the instruction-tuned model).
Figure 1: How we extracted and applied SAE features to LLaVA.
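
As a rough, hedged sketch of the extraction step in Figure 1: we run LLaVA with output_hidden_states=True, slice out the residual-stream activations at the image-patch positions, and encode them with the sae_encode helper sketched earlier. The prompt string, the processor class, and the way image-token positions are located are illustrative assumptions; in particular, only recent transformers versions expand the <image> placeholder into one input position per patch, and the Intel model card may require a model-specific processor and chat template.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "Intel/llava-gemma-2b"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

def image_token_sae_features(image: Image.Image, layer: int,
                             W_enc, b_enc, b_dec) -> torch.Tensor:
    """Return SAE features at the image-patch positions for one layer."""
    prompt = "<image>\nDescribe the image."  # placeholder; the real chat template differs
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
    # hidden_states[0] is the embedding output and hidden_states[i] the residual
    # stream after decoder block i-1, so the post-block stream of `layer` is
    # index layer + 1 (worth double-checking against the SAE's hook point).
    resid = out.hidden_states[layer + 1][0]                       # (seq_len, d_model)
    image_mask = inputs["input_ids"][0] == model.config.image_token_index
    patch_resid = resid[image_mask]                               # (576, d_model) for a 24x24 grid
    return sae_encode(patch_resid.float(), W_enc, b_enc, b_dec)   # (576, n_features)
```

The End-of-Image position used for the linear probe can be read off analogously, e.g. the position immediately following the image-patch block.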

Is there any signal?

We implemented the outlined procedure and analyzed the retrieved features to evaluate whether meaningful features could be identified through this transfer method. As a first step, a preliminary cleaning process was conducted to refine the feature set before delving into the detailed retrieved features and their auto-interpretability explanations.

The objective of the cleaning process was to eliminate features that appeared to be disproportionately represented across instances, which could introduce noise, diminish interpretability, or indicate unaligned or non-transferable features. Considering the CIFAR-100 dataset, which comprises 100 labels with 100 instances per label, the expected maximum occurrence of any feature under uniform distribution is approximately 100. To address potential anomalies, a higher threshold of 1000 occurrences was selected as the cutoff for identifying and excluding overrepresented features. This conservative threshold ensured that dominant, potentially less informative features were removed while retaining those likely to contribute meaningfully to the analysis.
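
In code, this cleaning step just counts, for each of the 16,384 SAE features, how many images it fires on and drops everything above the cutoff; a minimal sketch (the feature_presence matrix name is illustrative):

```python
import numpy as np

def filter_overrepresented(feature_presence: np.ndarray,
                           max_count: int = 1000) -> np.ndarray:
    """feature_presence: (n_images, n_features) binary matrix, 1 if a feature
    fired anywhere among an image's tokens. Returns indices of features kept."""
    counts = feature_presence.sum(axis=0)     # number of images each feature fires on
    return np.where(counts <= max_count)[0]   # drop features active on >1000 images
```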

After cleaning, we examined the retrieved features across different model layers (0–12 of 19 layers). A clear trend emerged: deeper layers exhibited increasingly useful features.

Below, we provide examples of retrieved features from both high-performing and underperforming classes, demonstrating the range of interpretability outcomes:

1. Dolphin 🐬

Layer 0

Layer 6

Layer 10

Layer 12

Layer 12-it

2. Skyscraper 🏙️

Layer 0

Layer 6

Layer 10

Layer 12

Layer 12-it

3. Boy 👦

Layer 0

Layer 6

Layer 10

Layer 12

Layer 12-it

4. Cloud ☁️

Layer 0

Layer 6

Layer 10

Layer 12

Layer 12-it

Classification Analysis

Building on the feature extraction process, we shifted focus to an equally critical question: Could the extracted features meaningfully classify CIFAR-100 labels? Specifically, we aimed to determine whether these features could reliably distinguish between diverse categories such as "dolphin" and "skyscraper." Additionally, we investigated how choices like binarization and layer selection influenced the robustness and effectiveness of the classification process.

Here, we outline our methodology, key findings, and their broader implications.

Classification Setup

We implemented a linear classification pipeline to evaluate the retrieved features' predictive utility. Features were collected from multiple layers of the model and underwent the following preparation steps:

Feature Pooling

Features were aggregated along the token dimension using two strategies:

Activation Transformation

We explored the impact of activation scaling on performance:

Layer Evaluation

Features were extracted from Layers 6, 10, and 17 of the model. A linear classifier was trained using the features of each layer, and performance was assessed with Macro F1 scores. This ensured a balanced evaluation across all CIFAR-100 categories, allowing us to identify robustness, efficiency, and interpretability trends across different configurations.

Figure 2: Details on feature pooling; once we obtain the features, we can feed them into our classifier. 
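
Putting the setup together, here is a hedged sketch of the pooling, optional binarization, and linear-probe evaluation described above. It uses scikit-learn's LogisticRegression as the linear classifier and a simple fired/not-fired binarization; our actual hyperparameters and the exact binarization rule may differ.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

def pool_features(feats: np.ndarray, how: str = "sum") -> np.ndarray:
    """feats: (n_tokens, n_features) SAE activations for one image.
    Aggregate along the token dimension by summing or taking the max."""
    return feats.sum(axis=0) if how == "sum" else feats.max(axis=0)

def binarize(x: np.ndarray) -> np.ndarray:
    """Cap each activation to a 0/1 indicator of whether the feature fired."""
    return (x > 0).astype(np.float32)

def evaluate_layer(X_train, y_train, X_test, y_test, use_binary: bool = True) -> float:
    """Fit a linear probe on pooled features and report the Macro F1 score."""
    if use_binary:
        X_train, X_test = binarize(X_train), binarize(X_test)
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train, y_train)
    return f1_score(y_test, clf.predict(X_test), average="macro")
```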

Classification Findings

Performance Summary

So, in a way, we actually nearly recovered the full ViT performance here!

1. How Many Features Do We Need?

We tested a range of feature selection methods, from summing activations over all tokens to taking only the top-1 activation per token.

What We Found:

Takeaway: Retaining a larger set of features preserves more discriminative information for CIFAR-100, though the effect may differ across layers.

2. Which Layer Performs Best?

We tested features from Layers 6, 10, and 17 to see which part of the model provided the best representations.

What We Found:

Layer 10 Superiority: Features from Layer 10 consistently achieved the highest Macro F1 scores, balancing generalization and specificity.

Takeaway: Mid-level features (Layer 10) offered the best trade-off for CIFAR-100 classification.

3. To Binarize or Not to Binarize?

We compared binarized activations, where each activation is capped to a simple on/off value, with non-binarized ones. The idea is that binarization reduces noise and keeps things simple.

What We Found:

Binarized vs. Non-Binarized: Binarized features outperformed non-binarized counterparts, particularly with smaller feature sets.

Takeaway: Binarization improves performance, especially under limited feature budgets.

Binarized features help.

4. Data Efficiency: How Much Training Data Do We Need?

We tested how well the features worked when we varied the amount of training data, from small splits (1 train, 9 test) to larger splits (5 train, 5 test).
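
Assuming the splits above refer to train/test fractions (10% up to 50% training data) and reusing the evaluate_layer probe sketched earlier, the sweep looks roughly like this:

```python
from sklearn.model_selection import train_test_split

# X: (n_images, n_features) pooled SAE features, y: CIFAR-100 labels.
for train_frac in [0.1, 0.2, 0.3, 0.4, 0.5]:
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, train_size=train_frac, stratify=y, random_state=0)
    score = evaluate_layer(X_tr, y_tr, X_te, y_te, use_binary=True)
    print(f"train_frac={train_frac:.1f}  macro_F1={score:.3f}")
```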

What We Found:

Layer 17 Limitations: Performance for Layer 17 improved with increased data but lagged under low-data conditions.

Takeaway: Binarized middle-layer features (e.g., Layer 10) were the most data-efficient option.

Data efficiency when using SAE features for classification.

Big Picture Insights

So, what did we learn from all this? Here are the three big takeaways:

  1. More Features = Better Results: Don’t be stingy with feature selection. Larger feature sets lead to better classification, especially with middle layers.
  2. Binarization is a Game-Changer: It’s simple but effective. By capping activations, binarization reduces noise and makes features more robust, particularly in low-resource scenarios.
  3. Layer Choice Matters: Not all layers are created equal. Middle layers (like Layer 10) provided the best representations, balancing generalization and specificity.

What’s Next?

These findings open up exciting transfer learning and feature design possibilities in multimodal systems. We’ve shown that thoughtful feature selection and transformation can make a big difference even with simple linear classifiers.

For future work, we’re interested in exploring:

  1. ^

    The authors acknowledge financial support from the Google PhD Fellowship (SC); the Woods Foundation (DB, SC, JG); the NIH (NIH R01CA294033 (SC, JG, DB) and NIH U54CA274516-01A1 (SC, DB)); and the American Cancer Society and American Society for Radiation Oncology, ASTRO-CSDG-24-1244514-01-CTPS Grant, DOI: https://doi.org/10.53354/ACS.ASTRO-CSDG-24-1244514-01-CTPS.pc.gr.222210 (DB).

  2. ^

    It is very interesting that the models focus more on the clothing and potential parental care here.

    This leads me to another question: should we train SAEs jointly on vision and text, or separately, for multimodal systems?

  3. ^

    We actually ran ImageNet-1k too, and it is still running due to its size. We are seeing 0.49 Macro F1 for layer 6...
