# Source code for goodfire.features.features

from collections import OrderedDict
from typing import Any, Optional, Union, overload
from uuid import UUID

from .interfaces import CONDITIONAL_OPERATOR, JOIN_OPERATOR


class FeatureNotInGroupError(Exception):
    """Raised when a FeatureGroup is indexed with an ID it does not hold."""


class Feature:
    """A class representing a single feature aka a conceptual unit of the SAE.

    Handles individual feature operations and comparisons. Features can be
    combined into groups and compared using standard operators. Comparison
    operators do not return bools: they wrap this feature in a FeatureGroup and
    build a ``Conditional`` expression.

    Attributes:
        uuid (UUID): Unique identifier for the feature
        label (str): Human-readable label describing the feature
        max_activation_strength (float): Maximum activation strength of the
            feature in the training dataset
        index_in_sae (int): Index position in the SAE
    """

    def __init__(
        self, uuid: UUID, label: str, max_activation_strength: float, index_in_sae: int
    ):
        """Initialize a new Feature instance.

        Args:
            uuid: Unique identifier for the feature
            label: Human-readable label describing the feature
            max_activation_strength: Maximum activation strength of the feature
            index_in_sae: Index position in the SAE
        """
        self.uuid = uuid
        self.label = label
        self.max_activation_strength = max_activation_strength
        self.index_in_sae = index_in_sae

    def json(self) -> dict[str, Any]:
        """Serialize the feature to a plain dict.

        A ``UUID`` is emitted as its hex string; any other uuid value is passed
        through unchanged.
        """
        return {
            # Change to hex while passing through http.
            "uuid": self.uuid.hex if isinstance(self.uuid, UUID) else self.uuid,
            "label": self.label,
            "max_activation_strength": self.max_activation_strength,
            "index_in_sae": self.index_in_sae,
        }

    @staticmethod
    def from_json(data: dict[str, Any]) -> "Feature":
        """Build a Feature from a dict shaped like the output of ``json``.

        A string uuid is parsed into a ``UUID``. The input dict is left
        unmodified.
        """
        uuid = data["uuid"]
        if isinstance(uuid, str):
            uuid = UUID(uuid)
        return Feature(
            uuid=uuid,
            label=data["label"],
            max_activation_strength=data["max_activation_strength"],
            index_in_sae=data["index_in_sae"],
        )

    def __or__(self, other: Union["Feature", "FeatureGroup"]) -> "FeatureGroup":
        """Return a new FeatureGroup holding this feature followed by *other*.

        If *other* is a FeatureGroup, its features are appended individually
        rather than the group itself being stored as a member.
        """
        if isinstance(other, FeatureGroup):
            return FeatureGroup([self, *other])
        return FeatureGroup([self, other])

    def __repr__(self) -> str:
        return str(self)

    def __hash__(self):
        # Defined explicitly because overriding __eq__ would otherwise make the
        # class unhashable.
        return hash(self.uuid)

    def __str__(self):
        return f'Feature("{self.label}")'

    # The comparison operators below delegate to a single-feature FeatureGroup
    # and return a Conditional expression, not a bool.

    def __eq__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return FeatureGroup([self]) == other

    def __ne__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return FeatureGroup([self]) != other

    def __le__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return FeatureGroup([self]) <= other

    def __lt__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return FeatureGroup([self]) < other

    def __ge__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return FeatureGroup([self]) >= other

    def __gt__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return FeatureGroup([self]) > other
class FeatureGroup:
    """A collection of Feature instances with group operations.

    Provides functionality for managing and operating on groups of features,
    including union and intersection operations, indexing, and comparison
    operations.

    Features are stored in an insertion-ordered mapping keyed by integer IDs.
    ``add`` assigns increasing IDs and ``pop`` does not renumber, so IDs may
    have gaps after a ``pop``. Integer and list/tuple indexing look features up
    by ID; slicing is positional over the features currently in the group.

    Example:
        >>> feature_group = FeatureGroup([feature1, feature2, feature3, feature4])
        >>> # Access single feature by index
        >>> first_feature = feature_group[0]  # Returns Feature
        >>>
        >>> # Slice features
        >>> first_two = feature_group[0:2]  # Returns FeatureGroup with features 0,1
        >>> last_two = feature_group[-2:]  # Returns FeatureGroup with last 2 features
        >>>
        >>> # Multiple indexes using list or tuple
        >>> selected = feature_group[[0, 2]]  # Returns FeatureGroup with features 0,2
        >>> selected = feature_group[0, 3]  # Returns FeatureGroup with features 0,3
    """

    def __init__(self, features: Optional[list["Feature"]] = None):
        self._features: OrderedDict[int, "Feature"] = OrderedDict()
        if features:
            for feature in features:
                self.add(feature)

    def __iter__(self):
        yield from self._features.values()

    def __contains__(self, item: object) -> bool:
        """Return True if a Feature with the same uuid as *item* is in the group.

        Defined explicitly because ``Feature.__eq__`` returns a ``Conditional``
        (an always-truthy object) rather than a bool. The default ``in``
        fallback, which iterates and compares with ``==``, would therefore
        report any item as present in a non-empty group.
        """
        if not isinstance(item, Feature):
            return False
        return any(feature.uuid == item.uuid for feature in self._features.values())

    @overload
    def __getitem__(self, index: int) -> "Feature": ...

    @overload
    def __getitem__(self, index: list[int]) -> "FeatureGroup": ...

    @overload
    def __getitem__(self, index: slice) -> "FeatureGroup": ...

    @overload
    def __getitem__(self, index: tuple[int, ...]) -> "FeatureGroup": ...

    def __getitem__(self, index: Union[int, list[int], tuple[int, ...], slice]):
        """Look up features by ID (int, list, tuple) or by position (slice).

        Raises:
            FeatureNotInGroupError: If any requested ID is not in the group.
            ValueError: If a slice step of zero is given.
            TypeError: If *index* is not an int, list, tuple or slice.
        """
        if isinstance(index, int):
            if index not in self._features:
                raise FeatureNotInGroupError(f"Feature with ID {index} not in group.")
            return self._features[index]

        if isinstance(index, (list, tuple)):
            # Read-only iteration: the caller's list must not be consumed.
            failed_indexes = [i for i in index if i not in self._features]
            if failed_indexes:
                raise FeatureNotInGroupError(
                    f"Features with IDs {failed_indexes} not in group."
                )
            return FeatureGroup([self._features[i] for i in index])

        if isinstance(index, slice):
            if index.step == 0:
                raise ValueError("Step cannot be zero.")
            # List slicing handles negative bounds, negative steps, an explicit
            # stop of 0 and out-of-range bounds correctly.
            return FeatureGroup(list(self._features.values())[index])

        raise TypeError(
            "FeatureGroup indices must be int, list, tuple or slice, "
            f"not {type(index).__name__}"
        )

    def __repr__(self):
        return str(self)

    def pick(self, feature_indexes: list[int]):
        """Create a new FeatureGroup with selected features.

        Args:
            feature_indexes: List of indexes to select

        Returns:
            FeatureGroup: New group containing only the selected features

        Raises:
            KeyError: If an index is not present in the group.
        """
        return FeatureGroup([self._features[index] for index in feature_indexes])

    def json(self) -> dict[str, Any]:
        """Serialize the group as ``{"features": [<feature json>, ...]}``."""
        return {"features": [f.json() for f in self._features.values()]}

    @staticmethod
    def from_json(data: dict[str, Any]):
        """Rebuild a FeatureGroup from the output of ``json``."""
        return FeatureGroup([Feature.from_json(f) for f in data["features"]])

    def add(self, feature: "Feature"):
        """Add a feature to the group.

        Args:
            feature: Feature instance to add to the group
        """
        # Keys are only ever inserted in increasing order, so the last key is the
        # largest ID in use. Using len() here would reuse a live ID after a
        # pop() and silently overwrite that feature.
        next_id = next(reversed(self._features)) + 1 if self._features else 0
        self._features[next_id] = feature

    def pop(self, index: int):
        """Remove and return a feature at the specified index.

        The IDs of the remaining features are not renumbered.

        Args:
            index: Index of the feature to remove

        Returns:
            Feature: The removed feature

        Raises:
            KeyError: If *index* is not present in the group.
        """
        return self._features.pop(index)

    def union(self, feature_group: "FeatureGroup"):
        """Combine this group with another feature group.

        This group's IDs are kept; features from *feature_group* are appended
        with fresh IDs, so no existing entry is overwritten.

        Args:
            feature_group: Another FeatureGroup to combine with

        Returns:
            FeatureGroup: New group containing features from both groups
        """
        new_group = FeatureGroup()
        new_group._features = OrderedDict(self._features)
        for feature in feature_group:
            new_group.add(feature)
        return new_group

    def intersection(self, feature_group: "FeatureGroup"):
        """Create a new group with features common to both groups.

        Features are matched by ``uuid`` and kept in this group's order.

        Args:
            feature_group: Another FeatureGroup to intersect with

        Returns:
            FeatureGroup: New group containing only features present in both groups
        """
        # Match on uuid rather than with `in`/`==`: Feature.__eq__ returns a
        # truthy Conditional, which would make every feature look shared.
        shared_uuids = {feature.uuid for feature in feature_group}
        return FeatureGroup(
            [f for f in self._features.values() if f.uuid in shared_uuids]
        )

    def __or__(self, other: Union["FeatureGroup", "Feature"]):
        if isinstance(other, Feature):
            other = FeatureGroup([other])
        return self.union(other)

    def __and__(self, other: Union["FeatureGroup", "Feature"]):
        if isinstance(other, Feature):
            other = FeatureGroup([other])
        return self.intersection(other)

    def __len__(self):
        return len(self._features)

    def __str__(self):
        items = list(self._features.items())
        # Show at most 10 entries: the first 9, an ellipsis, then the last one.
        shown = items if len(items) <= 10 else items[:9]
        features_str = ",\n ".join(f'{index}: "{f.label}"' for index, f in shown)
        if len(items) > 10:
            last_index, last_feature = items[-1]
            features_str += ",\n ...\n " + f'{last_index}: "{last_feature.label}"'
        return f"FeatureGroup([\n {features_str}\n])"

    def _compare(
        self,
        other: Union["FeatureGroup", "Feature", "FeatureStatistic", float],
        operator: str,
    ) -> "Conditional":
        """Build a Conditional, wrapping a bare Feature operand in a FeatureGroup."""
        if isinstance(other, Feature):
            other = FeatureGroup([other])
        return Conditional(self, other, operator)

    # Comparison operators build Conditional expressions, not bools. Because
    # __eq__ is overridden without __hash__, instances are unhashable.

    def __eq__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return self._compare(other, "==")

    def __ne__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return self._compare(other, "!=")

    def __le__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return self._compare(other, "<=")

    def __lt__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return self._compare(other, "<")

    def __ge__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return self._compare(other, ">=")

    def __gt__(
        self, other: Union["FeatureGroup", "Feature", "FeatureStatistic", float]
    ) -> "Conditional":
        return self._compare(other, ">")
class FeatureStatistic:
    """A mapping of feature key -> numeric value with element-wise arithmetic.

    Arithmetic operators return a new FeatureStatistic and leave both operands
    unchanged. The other operand may be:

    * another FeatureStatistic whose keys are a subset of this one's (only the
      shared keys are combined), or
    * a plain int/float, which is applied to every value.

    Any other operand raises ValueError.
    """

    def __init__(self, initial_values: dict[UUID, float]):
        # Copy so that later changes to the caller's dict cannot leak in.
        self._values = dict(initial_values)

    def json(self) -> dict[str, Any]:
        """Serialize as ``{"values": {...}}``; the returned dict is a copy."""
        # NOTE(review): keys are emitted as stored; if they are UUID objects this
        # is not directly JSON-serializable (Feature.json converts to hex) — confirm.
        return {"values": dict(self._values)}

    @staticmethod
    def from_json(data: dict[str, Any]):
        """Rebuild a FeatureStatistic from the output of ``json``."""
        return FeatureStatistic(data["values"])

    def copy(self):
        """Return an independent copy of this statistic."""
        return FeatureStatistic({**self._values})

    def _check_keys(self, other: "FeatureStatistic"):
        """Raise ValueError if *other* has keys this statistic lacks."""
        extra_keys = other._values.keys() - self._values.keys()
        if extra_keys:
            raise ValueError(
                f"FeatureStatistic keys {list(extra_keys)} are not present "
                "in this statistic."
            )

    def _apply(self, other: Union["FeatureStatistic", float], op) -> "FeatureStatistic":
        """Return a new statistic with ``op(value, other_value)`` applied per key."""
        if isinstance(other, FeatureStatistic):
            self._check_keys(other)
            values = dict(self._values)
            for key, value in other._values.items():
                values[key] = op(values[key], value)
        elif isinstance(other, (int, float)) and not isinstance(other, bool):
            values = {key: op(value, other) for key, value in self._values.items()}
        else:
            raise ValueError(
                f"Unsupported operand type for FeatureStatistic: {type(other).__name__}"
            )
        return FeatureStatistic(values)

    def __add__(self, other: Union["FeatureStatistic", float]):
        return self._apply(other, lambda a, b: a + b)

    def __radd__(self, other: Union["FeatureStatistic", float]):
        # Addition is commutative, so `2.0 + stat` equals `stat + 2.0`.
        return self.__add__(other)

    def __sub__(self, other: Union["FeatureStatistic", float]):
        return self._apply(other, lambda a, b: a - b)

    def __neg__(self):
        return self * -1

    def __mul__(self, other: Union["FeatureStatistic", float]):
        return self._apply(other, lambda a, b: a * b)

    def __rmul__(self, other: Union["FeatureStatistic", float]):
        # Multiplication is commutative, so `2 * stat` equals `stat * 2`.
        return self.__mul__(other)

    def __pow__(self, other: Union["FeatureStatistic", float]):
        return self._apply(other, lambda a, b: a**b)

    def __floordiv__(self, other: Union["FeatureStatistic", float]):
        return self._apply(other, lambda a, b: a // b)

    def __truediv__(self, other: Union["FeatureStatistic", float]):
        return self._apply(other, lambda a, b: a / b)

    def __iter__(self):
        yield from self._values.values()

    def __len__(self):
        return len(self._values)


class ConditionalGroup:
    """Groups multiple conditions with logical operators.

    Manages groups of conditions that can be combined using AND/OR operations.
    """

    def __init__(
        self, conditionals: list["Conditional"], operator: "JOIN_OPERATOR" = "AND"
    ):
        """Initialize a new ConditionalGroup.

        Args:
            conditionals: List of Conditional instances to group
            operator: Logical operator to join conditions ("AND" or "OR")
        """
        self.conditionals = conditionals
        self.operator = operator

    def json(self) -> dict[str, Any]:
        """Convert the conditional group to a JSON-serializable dictionary.

        Returns:
            dict: Dictionary containing conditionals and operator
        """
        return {
            "conditionals": [c.json() for c in self.conditionals],
            "operator": self.operator,
        }

    @staticmethod
    def from_json(data: dict[str, Any]):
        """Create a ConditionalGroup instance from JSON data.

        Args:
            data: Dictionary containing conditionals and operator

        Returns:
            ConditionalGroup: New instance with the deserialized data
        """
        return ConditionalGroup(
            [Conditional.from_json(c) for c in data["conditionals"]],
            operator=data["operator"],
        )

    def _combine(
        self,
        other: Union["ConditionalGroup", "Conditional"],
        join_operator: "JOIN_OPERATOR",
    ) -> "ConditionalGroup":
        """Flatten self and *other* into one group joined by *join_operator*.

        Raises:
            ValueError: If either side is a multi-conditional group joined by a
                different operator. Flattening it would silently change its
                meaning, e.g. turning ``(a | b) & c`` into ``a & b & c``.
        """
        other_group = (
            ConditionalGroup([other]) if isinstance(other, Conditional) else other
        )
        for group in (self, other_group):
            # A single-conditional group has no effective operator, so it can be
            # merged under either one.
            if len(group.conditionals) > 1 and group.operator != join_operator:
                raise ValueError(
                    f"Cannot join a ConditionalGroup combined with {group.operator} "
                    f"using {join_operator}: flattening would change its meaning."
                )
        return ConditionalGroup(
            self.conditionals + other_group.conditionals, operator=join_operator
        )

    def __and__(
        self, other: Union["ConditionalGroup", "Conditional"]
    ) -> "ConditionalGroup":
        return self._combine(other, "AND")

    def __or__(
        self, other: Union["ConditionalGroup", "Conditional"]
    ) -> "ConditionalGroup":
        return self._combine(other, "OR")


class Conditional:
    """Represents a conditional expression comparing features.

    Handles comparison operations between features, feature groups, and statistics.
    """

    def __init__(
        self,
        left_hand: "FeatureGroup",
        right_hand: Union["Feature", "FeatureGroup", "FeatureStatistic", float],
        operator: "CONDITIONAL_OPERATOR",
    ):
        """Initialize a new Conditional.

        Args:
            left_hand: FeatureGroup for the left side of the comparison
            right_hand: Value to compare against (Feature, FeatureGroup,
                FeatureStatistic, or float)
            operator: Comparison operator to use
        """
        self.left_hand = left_hand
        self.right_hand = right_hand
        self.operator = operator

    def json(self) -> dict[str, Any]:
        """Convert the conditional to a JSON-serializable dictionary.

        Returns:
            dict: Dictionary containing the conditional expression data
        """
        return {
            "left_hand": self.left_hand.json(),
            "right_hand": (
                self.right_hand.json()
                if getattr(self.right_hand, "json", None)
                else self.right_hand
            ),
            "operator": self.operator,
        }

    @staticmethod
    def from_json(data: dict[str, Any]):
        """Create a Conditional instance from JSON data.

        The right-hand side is decoded according to the shape each class's
        ``json`` produces:

        * a dict with ``"features"`` -> FeatureGroup,
        * a dict with ``"uuid"`` -> Feature,
        * any other dict -> FeatureStatistic,
        * non-dict values are kept as-is.

        Args:
            data: Dictionary containing conditional expression data

        Returns:
            Conditional: New instance with the deserialized data
        """
        right_hand = data["right_hand"]
        if isinstance(right_hand, dict):
            if "features" in right_hand:
                right_hand = FeatureGroup.from_json(right_hand)
            elif "uuid" in right_hand:
                right_hand = Feature.from_json(right_hand)
            else:
                right_hand = FeatureStatistic.from_json(right_hand)
        return Conditional(
            FeatureGroup.from_json(data["left_hand"]),
            right_hand,
            data["operator"],
        )

    # NOTE(review): if `other` is a ConditionalGroup, it is nested inside the new
    # group. ConditionalGroup.from_json cannot parse nested groups — confirm
    # that this shape is accepted downstream.
    def __and__(self, other: "Conditional") -> ConditionalGroup:
        return ConditionalGroup([self, other], operator="AND")

    def __or__(self, other: "Conditional") -> ConditionalGroup:
        return ConditionalGroup([self, other], operator="OR")