from typing import Dict, List, Optional, Union
import numpy as np
from ..api.client import Client
from ..features.features import Feature, FeatureGroup
from ..variants.variants import VariantInterface
[docs]
class LatentExplorer:
"""Interactive visualization tool for exploring feature relationships in latent space.
Uses PCA dimensionality reduction and interactive plotting to visualize feature
neighborhoods and relationships. Supports clicking features to explore their local
neighborhoods.
Args:
client (Client): Client instance for API communication
model (Union[str, VariantInterface]): Model or variant to explore
"""
def __init__(self, client: Client, model: Union[str, VariantInterface]):
self.client = client
self.model = model
self.active_figure = None
self.scatter_points: Dict[int, Feature] = {}
self.active_annotation = None
def _create_scatter_plot(
self,
dim_reduction: List[List[float]],
all_features: FeatureGroup,
origin: Feature,
ax,
) -> None:
import matplotlib.pyplot as plt
"""Create scatter plot with proper picking setup"""
points = np.array(dim_reduction)
# Use the extra dimension for coloring
color_values = points[:, -1]
normalized_colors = (color_values - color_values.min()) / (
color_values.max() - color_values.min()
)
# Create colormap with viridis (better for continuous data)
colors = plt.cm.viridis(normalized_colors)
# Make origin point stand out with a different color
origin_idx = None
for idx, feature in enumerate(all_features):
if feature.uuid == origin.uuid:
origin_idx = idx
colors[idx] = [1, 0, 0, 1] # Bright red for origin
break
sizes = [100 if feature.uuid == origin.uuid else 50 for feature in all_features]
display_points = points[:, :3]
# Create scatter plot based on dimensions
scatter = ax.scatter3D(
display_points[:, 0],
display_points[:, 1],
display_points[:, 2],
c=colors,
s=sizes,
picker=True,
alpha=0.6,
)
# Add colorbar
if origin_idx is not None:
# Create a separate scatter for the colorbar that excludes the origin point
non_origin_colors = color_values[np.arange(len(color_values)) != origin_idx]
norm = plt.Normalize(non_origin_colors.min(), non_origin_colors.max())
sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis, norm=norm)
sm.set_array([])
plt.colorbar(sm, ax=ax)
# Store mapping of point indices to features
self.scatter_points = {idx: feature for idx, feature in enumerate(all_features)}
# Store dimension reduction data for hover functionality
self.dim_reduction = display_points
return scatter
def _handle_hover(self, event, ax):
import matplotlib.pyplot as plt
"""Handle mouse hover events"""
if not hasattr(self, "scatter") or event.inaxes != ax:
if self.active_annotation:
self.active_annotation.remove()
self.active_annotation = None
plt.draw()
return
cont, ind = self.scatter.contains(event)
if cont:
if self.active_annotation:
self.active_annotation.remove()
idx = ind["ind"][0]
feature = self.scatter_points[idx]
point = self.dim_reduction[idx]
self.active_annotation = ax.text(
point[0],
point[1],
point[2],
feature.label,
fontsize=8,
alpha=0.9,
bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
)
plt.draw()
elif self.active_annotation:
self.active_annotation.remove()
self.active_annotation = None
plt.draw()
def _handle_pick(self, event, horizon: int, origin: Feature):
import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
"""Handle pick events on scatter points"""
if isinstance(event.artist, PathCollection): # Verify it's a scatter point
ind = event.ind[0] # Get the index of picked point
if ind in self.scatter_points:
selected_feature = self.scatter_points[ind]
# Store current window position
old_geometry = self.active_figure.canvas.window().geometry()
# Close current plot
plt.close(self.active_figure)
# Create new chart
self.chart(
origin=selected_feature,
horizon=horizon,
_previous_feature=origin,
window_geometry=old_geometry,
)
def _setup_3d_controls(self, fig, ax):
"""Set up keyboard controls for 3D plot rotation"""
def on_key(event):
if event.key == "left":
ax.view_init(elev=ax.elev, azim=ax.azim - 10)
elif event.key == "right":
ax.view_init(elev=ax.elev, azim=ax.azim + 10)
elif event.key == "up":
ax.view_init(elev=ax.elev + 10, azim=ax.azim)
elif event.key == "down":
ax.view_init(elev=ax.elev - 10, azim=ax.azim)
fig.canvas.draw()
fig.canvas.mpl_connect("key_press_event", on_key)
def chart(
self,
origin: Feature,
horizon: int = 100,
_previous_feature: Optional[Feature] = None,
window_geometry: Optional[object] = None,
):
import matplotlib.pyplot as plt
# Get neighboring features
local_group = self.client.features._experimental.neighbors(
origin, model=self.model, top_k=horizon
)
# Combine feature groups
all_features = FeatureGroup([origin]) | local_group
if _previous_feature:
all_features = all_features | FeatureGroup([_previous_feature])
# Perform dimension reduction
dim_reduction = self.client.features._experimental.dimension_reduction(
origin, local_group, self.model, 4, "pca"
)
# Create figure and axis
fig = plt.figure(figsize=(10, 8))
self.active_figure = fig
ax = fig.add_subplot(111, projection="3d")
self._setup_3d_controls(fig, ax)
# Create scatter plot
self.scatter = self._create_scatter_plot(dim_reduction, local_group, origin, ax)
# Set up event handlers
fig.canvas.mpl_connect(
"pick_event", lambda event: self._handle_pick(event, horizon, origin)
)
fig.canvas.mpl_connect(
"motion_notify_event", lambda event: self._handle_hover(event, ax)
)
# Customize plot
ax.set_title(f"Latent Explorer: {origin.label}")
ax.grid(True, alpha=0.3, linestyle="--")
# Add legend with colorbar
legend_elements = [
plt.scatter([], [], c="red", alpha=0.6, s=100, label="Current Feature"),
plt.scatter([], [], c="blue", alpha=0.6, s=50, label="Other Features"),
]
ax.legend(handles=legend_elements, loc="upper right")
plt.tight_layout()
# Show the plot and restore window position if available
plt.show()
if window_geometry:
self.active_figure.canvas.window().setGeometry(window_geometry)