Latent Explorer
Hello! This notebook is fully interactive (hover, click and keyboard controls) only when matplotlib uses the Qt backend,
which is not available in Colab. We recommend downloading this notebook and running it locally in your IDE of choice.
[ ]:
# Install the Goodfire SDK, pinned to the version this notebook was written against.
!pip install goodfire==0.2.11
Skip the next two cells if you’re just demoing in Colab (the Qt backend is not available there).
[ ]:
# PyQt5 provides the Qt backend selected by `%matplotlib qt` in the next cell.
!pip install PyQt5
[ ]:
# Render figures in Qt windows so the explorer's hover, click and arrow-key handlers work.
%matplotlib qt
Run from here
[ ]:
import os

try:
    from google.colab import userdata
except ImportError:
    # Not running in Colab (e.g. locally with the Qt backend):
    # read the key from the environment instead.
    GOODFIRE_API_KEY = os.environ.get("GOODFIRE_API_KEY")
else:
    # Add your Goodfire API key to your Colab secrets.
    GOODFIRE_API_KEY = userdata.get("GOODFIRE_API_KEY")
[ ]:
import goodfire

# Fail early with an actionable message rather than an opaque SDK/HTTP error later.
if not GOODFIRE_API_KEY:
    raise RuntimeError(
        "GOODFIRE_API_KEY is not set: add it to your Colab secrets "
        "or export it as an environment variable."
    )

client = goodfire.Client(GOODFIRE_API_KEY)
# Model whose features are searched and explored below.
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")
[ ]:
# Pick the starting feature: the top search hit for "whales". `search` returns a
# pair; the first element is kept (indexed as origin_feature[0] below) and the
# second is discarded.
origin_feature, _ = client.features.search("whales", top_k=1)
[ ]:
from typing import Dict, List, Optional, Union
import numpy as np
class LatentExplorer:
    """Interactive 3D map of a Goodfire feature and its nearest neighbours.

    ``chart`` fetches up to ``horizon`` neighbours of a feature through the
    experimental neighbours API, projects them to 3D with the experimental
    dimension-reduction API and draws them with matplotlib. Hovering over a
    point shows its label, clicking a point re-centres the chart on that
    feature, and the arrow keys rotate the view.

    The event handlers need an interactive matplotlib backend (the notebook
    uses ``%matplotlib qt``). Preserving the window position between charts
    is best-effort and only works when the canvas is a Qt widget.
    """

    def __init__(self, client: goodfire.Client, model: Union[str, goodfire.Variant]):
        self.client = client
        self.model = model
        self.active_figure = None
        # Scatter-point index -> feature for the current chart; the last index is the origin.
        self.scatter_points: Dict[int, goodfire.Feature] = {}
        self.active_annotation = None
        # Populated by ``chart``; initialised here so the event handlers can rely on them.
        self.scatter = None
        self.dim_reduction = None

    def _create_scatter_plot(
        self,
        dim_reduction: List[List[float]],
        all_features: goodfire.FeatureGroup,
        origin: goodfire.Feature,
        ax,
    ):
        """Plot ``all_features`` in orange and ``origin`` in red at (0, 0, 0).

        Args:
            dim_reduction: Coordinates, one row per feature; only the first
                three columns are plotted. Row ``i`` is assumed to belong to
                the ``i``-th feature of ``all_features``.
            all_features: Features to plot (the origin's neighbours).
            origin: Feature the chart is centred on.
            ax: A 3D matplotlib axes.

        Returns:
            The pickable ``PathCollection`` created by ``ax.scatter3D``.
        """
        features = list(all_features)
        points = np.asarray(dim_reduction, dtype=float)
        if points.size == 0:
            # No neighbours: use an empty (0, 3) array so column slicing still works.
            points = np.empty((0, 3))

        # Neighbours first, then the origin as the final row.
        display_points = np.vstack([points[:, :3], [0.0, 0.0, 0.0]])
        colors = ["orange"] * len(features) + ["red"]
        sizes = [50] * len(features) + [100]

        scatter = ax.scatter3D(
            display_points[:, 0],
            display_points[:, 1],
            display_points[:, 2],
            c=colors,
            s=sizes,
            picker=True,
            alpha=0.6,
        )

        # Keep the index -> feature map in the same row order as display_points
        # so hover and pick indices resolve to the right feature.
        self.scatter_points = dict(enumerate(features))
        self.scatter_points[len(features)] = origin
        # Stored for the hover handler, which positions labels at these coordinates.
        self.dim_reduction = display_points
        return scatter

    def _clear_annotation(self, ax) -> None:
        """Remove the hover label, if one is shown, and schedule a redraw."""
        if self.active_annotation is not None:
            self.active_annotation.remove()
            self.active_annotation = None
            ax.figure.canvas.draw_idle()

    def _handle_hover(self, event, ax):
        """Show the label of the point under the cursor, or clear it otherwise."""
        if self.scatter is None or event.inaxes != ax:
            self._clear_annotation(ax)
            return

        hit, info = self.scatter.contains(event)
        feature = self.scatter_points.get(info["ind"][0]) if hit else None
        if feature is None:
            self._clear_annotation(ax)
            return

        if self.active_annotation is not None:
            self.active_annotation.remove()
        x, y, z = self.dim_reduction[info["ind"][0]][:3]
        self.active_annotation = ax.text(
            x,
            y,
            z,
            feature.label,
            fontsize=8,
            alpha=0.9,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )
        # draw_idle coalesces redraws instead of repainting on every mouse move.
        ax.figure.canvas.draw_idle()

    @staticmethod
    def _window_geometry(fig):
        """Return the Qt window geometry of ``fig``, or None if unavailable."""
        try:
            return fig.canvas.window().geometry()
        except AttributeError:
            # ``window()`` is a Qt widget method; other canvases (or fig=None) lack it.
            return None

    def _handle_pick(self, event, horizon: int, origin: goodfire.Feature):
        """Re-centre the chart on the clicked feature, keeping the window position."""
        import matplotlib.pyplot as plt

        # Only the feature scatter is pickable; ignore anything else.
        if event.artist is not self.scatter or len(event.ind) == 0:
            return
        selected_feature = self.scatter_points.get(event.ind[0])
        if selected_feature is None:
            return

        old_geometry = self._window_geometry(self.active_figure)
        plt.close(self.active_figure)
        self.chart(
            origin=selected_feature,
            horizon=horizon,
            _previous_feature=origin,
            window_geometry=old_geometry,
        )

    def _setup_3d_controls(self, fig, ax):
        """Rotate the 3D view by 10 degrees per arrow-key press."""
        step = 10
        # key -> (elevation delta, azimuth delta)
        deltas = {
            "left": (0, -step),
            "right": (0, step),
            "up": (step, 0),
            "down": (-step, 0),
        }

        def on_key(event):
            delta = deltas.get(event.key)
            if delta is None:
                return
            d_elev, d_azim = delta
            ax.view_init(elev=ax.elev + d_elev, azim=ax.azim + d_azim)
            fig.canvas.draw_idle()

        fig.canvas.mpl_connect("key_press_event", on_key)

    def chart(
        self,
        origin: goodfire.Feature,
        horizon: int = 100,
        _previous_feature: Optional[goodfire.Feature] = None,
        window_geometry: Optional[object] = None,
    ):
        """Open a new 3D chart centred on ``origin``.

        Args:
            origin: Feature to centre the chart on.
            horizon: Number of neighbours requested (passed as ``top_k``).
            _previous_feature: Feature the user navigated from. It is not
                plotted; the parameter is kept for backward compatibility.
            window_geometry: Qt window geometry to restore after showing the
                figure, or None to let the window manager place it.
        """
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D

        local_group = self.client.features._experimental.neighbors(
            origin, model=self.model, top_k=horizon
        )
        dim_reduction = self.client.features._experimental.dimension_reduction(
            origin, local_group, self.model, 3
        )

        fig = plt.figure(figsize=(10, 8))
        self.active_figure = fig
        # Any existing hover label belongs to the previous (closed) figure.
        self.active_annotation = None
        ax = fig.add_subplot(111, projection="3d")
        self._setup_3d_controls(fig, ax)

        self.scatter = self._create_scatter_plot(dim_reduction, local_group, origin, ax)

        fig.canvas.mpl_connect(
            "pick_event", lambda event: self._handle_pick(event, horizon, origin)
        )
        fig.canvas.mpl_connect(
            "motion_notify_event", lambda event: self._handle_hover(event, ax)
        )

        fig.set_facecolor("black")
        ax.set_facecolor("black")
        ax.set_title(f"Latent Explorer: {origin.label}", color="white")
        ax.grid(False)
        ax.set_axis_off()

        # Proxy artists for the legend, so no empty collections are added to the axes.
        # markersize is in points, scatter ``s`` in points^2 (10^2 = 100, ~7^2 = 50).
        legend_elements = [
            Line2D([], [], marker="o", linestyle="None", color="red",
                   alpha=0.6, markersize=10, label="Current Feature"),
            Line2D([], [], marker="o", linestyle="None", color="orange",
                   alpha=0.6, markersize=7, label="Other Features"),
        ]
        ax.legend(handles=legend_elements, loc="upper right")
        plt.tight_layout()

        plt.show()
        if window_geometry is not None:
            try:
                fig.canvas.window().setGeometry(window_geometry)
            except AttributeError:
                # Best-effort: position restoration only works on Qt canvases.
                pass
[ ]:
# Bind an explorer to the client/model and open a chart centred on the top
# search hit, requesting 100 neighbours.
explorer = LatentExplorer(client=client, model=variant)
starting_feature = origin_feature[0]
explorer.chart(starting_feature, horizon=100)
Warning: The experimental features API is subject to change.