Latent Explorer
To use the interactive PyQt5 window, download this notebook and run it locally.
[ ]:
!pip install goodfire==0.2.5
Skip this cell if you’re only demoing in Colab.
[ ]:
!pip install PyQt5
[ ]:
%matplotlib qt
Run all cells from here on.
[ ]:
import os

# Load the Goodfire API key.
# - In Colab: add GOODFIRE_API_KEY to your Colab secrets; it is read from there.
# - Locally (needed for the PyQt5 window): google.colab is not installed, so the
#   key is read from the GOODFIRE_API_KEY environment variable instead.
try:
    from google.colab import userdata
except ImportError:
    GOODFIRE_API_KEY = os.environ.get("GOODFIRE_API_KEY")
else:
    GOODFIRE_API_KEY = userdata.get("GOODFIRE_API_KEY")
[ ]:
import goodfire

# Model variant whose features are explored below.
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")
# Goodfire API client, shared by the feature search and the explorer below.
client = goodfire.Client(GOODFIRE_API_KEY)
[ ]:
# Search for the single feature matching "whales" (top_k=1); its first entry
# (origin_feature[0]) seeds the explorer below. The second returned value is unused.
origin_feature, _ = client.features.search("whales", top_k=1)
[ ]:
from typing import Dict, List, Optional, Union
import numpy as np
class LatentExplorer:
    """Interactive 3D scatter-plot browser for a feature's neighbourhood.

    ``chart`` plots an origin feature together with the neighbouring features
    returned by ``client.features._experimental.neighbors``, positioned with
    ``client.features._experimental.dimension_reduction``. Hovering a point
    shows its label, clicking a point re-centres the chart on that feature,
    and the arrow keys rotate the view.
    """

    def __init__(self, client: goodfire.Client, model: Union[str, goodfire.Variant]):
        """Create an explorer.

        Args:
            client: Goodfire API client used for neighbour search and
                dimension reduction.
            model: Model name or variant passed through to the experimental
                feature APIs.
        """
        self.client = client
        self.model = model
        # Figure currently on screen; replaced on every ``chart`` call.
        self.active_figure = None
        # Scatter-point index -> feature drawn at that index.
        self.scatter_points: Dict[int, goodfire.Feature] = {}
        # Hover label currently shown, if any.
        self.active_annotation = None
        # Scatter artist and the 3D coordinates it was drawn from (set by ``chart``).
        self.scatter = None
        self.dim_reduction = None

    @staticmethod
    def _host_window(fig):
        """Return the top-level window hosting *fig*, or None if unavailable.

        ``canvas.window()`` exists on Qt canvases (QWidget.window()); other
        backends lack it, in which case window geometry is simply not carried over.
        """
        get_window = getattr(fig.canvas, "window", None)
        return get_window() if callable(get_window) else None

    def _create_scatter_plot(
        self,
        dim_reduction: List[List[float]],
        all_features: goodfire.FeatureGroup,
        origin: goodfire.Feature,
        ax,
    ):
        """Draw the neighbours and the origin on *ax* as a pickable 3D scatter.

        Row ``i`` of *dim_reduction* is drawn for the ``i``-th feature of
        *all_features* (only the first three columns are used as x, y, z).
        The origin is appended last, at (0, 0, 0).

        Returns:
            The artist returned by ``ax.scatter3D``.

        Raises:
            ValueError: if *dim_reduction* does not have one row per feature.
        """
        features = list(all_features)
        points = np.asarray(dim_reduction, dtype=float)
        if points.size == 0:
            # No neighbours: use a (0, 3) array so slicing/stacking still work.
            points = np.empty((0, 3))
        if len(points) != len(features):
            raise ValueError(
                f"dimension reduction returned {len(points)} points "
                f"for {len(features)} features"
            )

        # Neighbours in orange, then the origin in red (and larger) at (0, 0, 0).
        display_points = np.vstack([points[:, :3], [0.0, 0.0, 0.0]])
        colors = ["orange"] * len(features) + ["red"]
        sizes = [50] * len(features) + [100]

        scatter = ax.scatter3D(
            display_points[:, 0],
            display_points[:, 1],
            display_points[:, 2],
            c=colors,
            s=sizes,
            picker=True,
            alpha=0.6,
        )

        # Map scatter indices back to features for hover/pick handling;
        # the origin occupies the last index.
        self.scatter_points = dict(enumerate(features))
        self.scatter_points[len(features)] = origin
        # Keep the plotted coordinates so hover labels can be placed on points.
        self.dim_reduction = display_points
        return scatter

    def _handle_hover(self, event, ax):
        """Show the label of the scatter point under the cursor, if any."""
        if self.scatter is not None and event.inaxes == ax:
            hit, info = self.scatter.contains(event)
        else:
            hit, info = False, {}

        if not hit:
            # Cursor left the points (or the axes): clear any stale label.
            if self.active_annotation is not None:
                self.active_annotation.remove()
                self.active_annotation = None
                event.canvas.draw_idle()
            return

        if self.active_annotation is not None:
            self.active_annotation.remove()
        idx = info["ind"][0]
        feature = self.scatter_points[idx]
        x, y, z = self.dim_reduction[idx]
        self.active_annotation = ax.text(
            x,
            y,
            z,
            feature.label,
            fontsize=8,
            alpha=0.9,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
        )
        # Redraw the canvas that produced the event, not whatever figure is current.
        event.canvas.draw_idle()

    def _handle_pick(self, event, horizon: int, origin: goodfire.Feature):
        """Re-centre the explorer on the clicked feature.

        Closes the current figure and calls ``chart`` with the picked feature
        as the new origin, passing *origin* as ``_previous_feature`` and
        carrying over the window geometry when the backend exposes it.
        """
        import matplotlib.pyplot as plt
        from matplotlib.collections import PathCollection

        if not isinstance(event.artist, PathCollection):  # only scatter points
            return
        picked_index = event.ind[0]
        if picked_index not in self.scatter_points:
            return
        selected_feature = self.scatter_points[picked_index]

        window = self._host_window(self.active_figure)
        old_geometry = window.geometry() if window is not None else None
        plt.close(self.active_figure)
        self.chart(
            origin=selected_feature,
            horizon=horizon,
            _previous_feature=origin,
            window_geometry=old_geometry,
        )

    def _setup_3d_controls(self, fig, ax):
        """Rotate *ax* in 10-degree steps with the arrow keys."""
        # key -> (elevation delta, azimuth delta), in degrees
        steps = {
            "left": (0, -10),
            "right": (0, 10),
            "up": (10, 0),
            "down": (-10, 0),
        }

        def on_key(event):
            delta = steps.get(event.key)
            if delta is None:
                # Not an arrow key: nothing changed, so skip the redraw.
                return
            d_elev, d_azim = delta
            ax.view_init(elev=ax.elev + d_elev, azim=ax.azim + d_azim)
            fig.canvas.draw()

        fig.canvas.mpl_connect("key_press_event", on_key)

    def chart(
        self,
        origin: goodfire.Feature,
        horizon: int = 100,
        _previous_feature: Optional[goodfire.Feature] = None,
        window_geometry: Optional[object] = None,
    ):
        """Open a 3D chart of *origin* and its neighbouring features.

        Args:
            origin: Feature placed at the centre (drawn in red at (0, 0, 0)).
            horizon: ``top_k`` passed to the neighbour search.
            _previous_feature: Feature the user navigated from. Currently
                accepted for interface compatibility only; it is not plotted.
            window_geometry: Window geometry to restore (as returned by
                ``window().geometry()`` on a Qt canvas), or None.
        """
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D

        local_group = self.client.features._experimental.neighbors(
            origin, model=self.model, top_k=horizon
        )
        # NOTE(review): _previous_feature is not plotted, because coordinates
        # come only from dimension_reduction over local_group. Include it only
        # if the API can place it in the same projection.
        dim_reduction = self.client.features._experimental.dimension_reduction(
            origin, local_group, self.model, 3
        )

        fig = plt.figure(figsize=(10, 8))
        self.active_figure = fig
        ax = fig.add_subplot(111, projection="3d")
        self._setup_3d_controls(fig, ax)

        self.scatter = self._create_scatter_plot(dim_reduction, local_group, origin, ax)

        fig.canvas.mpl_connect(
            "pick_event", lambda event: self._handle_pick(event, horizon, origin)
        )
        fig.canvas.mpl_connect(
            "motion_notify_event", lambda event: self._handle_hover(event, ax)
        )

        fig.set_facecolor("black")
        ax.set_facecolor("black")
        ax.set_title(f"Latent Explorer: {origin.label}", color="white")
        ax.grid(False)
        ax.set_axis_off()

        # Legend proxy artists: not added to the axes, so nothing extra is plotted.
        # markersize ~ sqrt(scatter size) to match the plotted points.
        legend_elements = [
            Line2D([], [], marker="o", linestyle="None", color="red",
                   alpha=0.6, markersize=10, label="Current Feature"),
            Line2D([], [], marker="o", linestyle="None", color="orange",
                   alpha=0.6, markersize=7, label="Other Features"),
        ]
        ax.legend(handles=legend_elements, loc="upper right")
        plt.tight_layout()

        # Restore the previous window position before showing: if plt.show()
        # blocks, the window is already closed by the time it returns.
        if window_geometry:
            window = self._host_window(fig)
            if window is not None:
                window.setGeometry(window_geometry)
        plt.show()
[ ]:
# Open the explorer on the "whales" feature found above, with up to 100 neighbours.
explorer = LatentExplorer(client=client, model=variant)
explorer.chart(origin=origin_feature[0], horizon=100)
Warning: The experimental features API is subject to change.