Latent Explorer

Colab here

Hello! This notebook is highly interactive when run with the Qt matplotlib backend, which is not compatible with Colab. We recommend downloading this notebook and running it in your IDE of choice.

[ ]:
# Install the Goodfire SDK, pinned to release 0.2.11.
!pip install goodfire==0.2.11

Skip if you’re just demoing in Colab

[ ]:
# Install PyQt5, the Qt bindings needed by the `%matplotlib qt` backend.
!pip install PyQt5
[ ]:
# Switch matplotlib to the Qt backend so figures open in interactive windows.
%matplotlib qt

Run from here

[ ]:
import os

try:
    # In Colab: add your Goodfire API key to the notebook's secrets.
    from google.colab import userdata
except ImportError:
    # Not running in Colab: read the key from the environment instead.
    # This is None when the variable is not set.
    GOODFIRE_API_KEY = os.environ.get("GOODFIRE_API_KEY")
else:
    GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')
[ ]:
import goodfire

# API client constructed from the key held in GOODFIRE_API_KEY.
client = goodfire.Client(GOODFIRE_API_KEY)

# Model variant handle; it is passed to LatentExplorer as the model to query.
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")
[ ]:
# Search for features matching "whales", requesting a single result (top_k=1).
# The call's result is unpacked into two values and only the first is kept;
# it is indexed with [0] further down to obtain the starting feature.
origin_feature, _ = client.features.search("whales", top_k=1)
[ ]:
from typing import Dict, List, Optional, Union

import numpy as np


class LatentExplorer:
    """Interactive 3D scatter-plot browser for a feature and its neighbors.

    ``chart`` fetches the neighbors of an origin feature, reduces them to three
    dimensions through the client's experimental API and draws them with
    matplotlib.  Hovering over a point shows its label; clicking a point closes
    the current figure and re-charts around the clicked feature.

    Annotations that name ``goodfire`` types are written as strings so they
    are not evaluated when the class is defined.
    """

    # Shared by the scatter plot and its legend so the two always agree.
    _ORIGIN_COLOR = "red"
    _ORIGIN_SIZE = 100
    _NEIGHBOR_COLOR = "orange"
    _NEIGHBOR_SIZE = 50
    _POINT_ALPHA = 0.6

    # Degrees added to (elevation, azimuth) for each arrow key.
    _KEY_ROTATIONS = {
        "left": (0, -10),
        "right": (0, 10),
        "up": (10, 0),
        "down": (-10, 0),
    }

    def __init__(self, client: "goodfire.Client", model: Union[str, "goodfire.Variant"]):
        """Store the API client and model; no plot exists until ``chart`` runs."""
        self.client = client
        self.model = model
        self.active_figure = None
        # Scatter point index -> feature drawn at that index.
        self.scatter_points: Dict[int, "goodfire.Feature"] = {}
        self.active_annotation = None
        # Handle returned by ``ax.scatter3D`` for the active plot.
        self.scatter = None
        # Coordinates of the plotted points (neighbors followed by the origin).
        self.dim_reduction = None

    @staticmethod
    def _canvas_window(fig):
        """Return the top-level window of *fig*'s canvas, or None if it has none.

        ``canvas.window()`` is only provided by Qt canvases, so other backends
        simply skip window-geometry handling instead of raising.
        """
        window_getter = getattr(fig.canvas, "window", None)
        return window_getter() if callable(window_getter) else None

    def _clear_annotation(self) -> bool:
        """Remove the hover label if one is shown; return True if it was removed."""
        if self.active_annotation is None:
            return False
        self.active_annotation.remove()
        self.active_annotation = None
        return True

    def _create_scatter_plot(
        self,
        dim_reduction: List[List[float]],
        all_features: "goodfire.FeatureGroup",
        origin: "goodfire.Feature",
        ax,
    ):
        """Draw the neighbor points plus the origin marker on *ax*.

        Args:
            dim_reduction: One coordinate row per feature in *all_features*;
                the first three columns are plotted.
            all_features: Features corresponding, in order, to the rows of
                *dim_reduction*.
            origin: Feature represented by the extra marker at (0, 0, 0).
            ax: 3D axes providing ``scatter3D``.

        Returns:
            Whatever ``ax.scatter3D`` returns (the pickable scatter handle).

        Side effects:
            Rebuilds ``self.scatter_points`` and ``self.dim_reduction`` so the
            hover and pick handlers can map a point index back to a feature.
        """
        coords = np.asarray(dim_reduction, dtype=float)
        if coords.size == 0:
            # No neighbors: keep a well-formed (0, 3) array so only the origin
            # marker is drawn instead of failing on 2D indexing.
            coords = np.empty((0, 3))

        features = list(all_features)

        # The origin is appended last, at the coordinate origin.
        # NOTE(review): assumes the reduction is expressed relative to the
        # origin feature -- confirm against the API.
        display_points = np.vstack([coords[:, :3], [[0.0, 0.0, 0.0]]])
        colors = [self._NEIGHBOR_COLOR] * len(features) + [self._ORIGIN_COLOR]
        sizes = [self._NEIGHBOR_SIZE] * len(features) + [self._ORIGIN_SIZE]

        scatter = ax.scatter3D(
            display_points[:, 0],
            display_points[:, 1],
            display_points[:, 2],
            c=colors,
            s=sizes,
            picker=True,
            alpha=self._POINT_ALPHA,
        )

        # Point index -> feature, with the origin stored under the last index.
        self.scatter_points = dict(enumerate(features))
        self.scatter_points[len(features)] = origin

        # Kept so the hover handler can place a label at the point's position.
        self.dim_reduction = display_points

        return scatter

    def _handle_hover(self, event, ax):
        """Show the label of the point under the cursor, or clear the label."""
        canvas = ax.figure.canvas

        if self.scatter is None or event.inaxes != ax:
            if self._clear_annotation():
                canvas.draw_idle()
            return

        hit, details = self.scatter.contains(event)
        if not hit:
            if self._clear_annotation():
                canvas.draw_idle()
            return

        # Replace any previous label with one for the first point under the cursor.
        self._clear_annotation()
        idx = int(details["ind"][0])
        feature = self.scatter_points[idx]
        point = self.dim_reduction[idx]

        self.active_annotation = ax.text(
            point[0],
            point[1],
            point[2],
            feature.label,
            fontsize=8,
            alpha=0.9,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )
        # Redraw the hovered figure specifically, not whichever is "current".
        canvas.draw_idle()

    def _handle_pick(self, event, horizon: int, origin: "goodfire.Feature"):
        """Re-chart around the clicked point, reusing the old window geometry."""
        import matplotlib.pyplot as plt

        # Only react to picks on the active scatter; legend proxies and points
        # from stale figures do not match the current index -> feature map.
        if self.scatter is None or event.artist is not self.scatter:
            return
        if len(event.ind) == 0:
            return

        selected_feature = self.scatter_points.get(int(event.ind[0]))
        if selected_feature is None:
            return

        # Remember where the window was so the replacement opens in place.
        window = self._canvas_window(self.active_figure)
        old_geometry = window.geometry() if window is not None else None

        plt.close(self.active_figure)
        self.chart(
            origin=selected_feature,
            horizon=horizon,
            _previous_feature=origin,
            window_geometry=old_geometry,
        )

    def _setup_3d_controls(self, fig, ax):
        """Bind the arrow keys to rotate *ax* in 10-degree steps."""

        def on_key(event):
            rotation = self._KEY_ROTATIONS.get(event.key)
            if rotation is None:
                # Not a rotation key: nothing changed, so skip the redraw.
                return
            elev_step, azim_step = rotation
            ax.view_init(elev=ax.elev + elev_step, azim=ax.azim + azim_step)
            fig.canvas.draw_idle()

        fig.canvas.mpl_connect("key_press_event", on_key)

    def chart(
        self,
        origin: "goodfire.Feature",
        horizon: int = 100,
        _previous_feature: Optional["goodfire.Feature"] = None,
        window_geometry: Optional[object] = None,
    ):
        """Open an interactive 3D plot of *origin* and its nearest neighbors.

        Args:
            origin: Feature to centre the plot on.
            horizon: Passed as ``top_k`` when fetching neighbors.
            _previous_feature: Feature the user navigated from.  Accepted for
                the pick handler's call but not used in the plot.
            window_geometry: Geometry object captured from a previous figure
                window; when given, it is applied to the new window.
        """
        import matplotlib.pyplot as plt

        experimental = self.client.features._experimental

        local_group = experimental.neighbors(origin, model=self.model, top_k=horizon)

        # NOTE(review): the trailing 3 is presumably the number of output
        # dimensions -- confirm against the API.
        dim_reduction = experimental.dimension_reduction(
            origin, local_group, self.model, 3
        )

        fig = plt.figure(figsize=(10, 8))
        self.active_figure = fig
        # Any label still referenced belongs to a previous figure; forget it.
        self.active_annotation = None

        ax = fig.add_subplot(111, projection="3d")
        self._setup_3d_controls(fig, ax)

        self.scatter = self._create_scatter_plot(dim_reduction, local_group, origin, ax)

        fig.canvas.mpl_connect(
            "pick_event", lambda event: self._handle_pick(event, horizon, origin)
        )
        fig.canvas.mpl_connect(
            "motion_notify_event", lambda event: self._handle_hover(event, ax)
        )

        fig.set_facecolor("black")
        ax.set_facecolor("black")

        ax.set_title(f"Latent Explorer: {origin.label}", color="white")
        ax.grid(False)
        ax.set_axis_off()

        # Empty scatters serve as legend handles for the two point styles.
        legend_elements = [
            plt.scatter(
                [],
                [],
                c=self._ORIGIN_COLOR,
                alpha=self._POINT_ALPHA,
                s=self._ORIGIN_SIZE,
                label="Current Feature",
            ),
            plt.scatter(
                [],
                [],
                c=self._NEIGHBOR_COLOR,
                alpha=self._POINT_ALPHA,
                s=self._NEIGHBOR_SIZE,
                label="Other Features",
            ),
        ]
        ax.legend(handles=legend_elements, loc="upper right")

        plt.tight_layout()

        # Geometry is restored after show(), which relies on show() returning
        # while the window is still open (as with an interactive backend).
        plt.show()
        if window_geometry is not None:
            window = self._canvas_window(fig)
            if window is not None:
                window.setGeometry(window_geometry)

[ ]:
# Build an explorer bound to the client and model variant created earlier.
explorer = LatentExplorer(client, variant)

# Open the plot centred on the first search result; horizon=100 is forwarded
# as top_k when the explorer fetches neighboring features.
explorer.chart(
    origin_feature[0],
    horizon=100,
)
Warning: The experimental features API is subject to change.
../_images/examples_latent_explorer_12_1.png