Source code for goodfire.variants.fast

from collections import OrderedDict
from typing import Literal, Union, overload

from typing_extensions import TypedDict

from ..controller.controller import Controller
from ..features.features import Feature, FeatureGroup
from .variants import SUPPORTED_MODELS, VariantInterface


class FeatureDelta(TypedDict):
    """Describe one feature modification: how it is applied and with what value."""
    # How the value is applied: "nudge" or "pin".
    mode: Literal["nudge", "pin"]
    # The amount (float) or on/off state (bool) associated with the edit.
    value: Union[float, bool]


class FeatureEdits:
    """Insertion-ordered store of per-feature modifications.

    Edits live in an OrderedDict keyed by feature, so iteration reports them
    in the order each feature was first edited. Re-assigning an existing
    feature replaces its delta without changing its position.
    """

    def __init__(self):
        # Start with no edits recorded.
        store: OrderedDict[Feature, FeatureDelta] = OrderedDict()
        self._edits = store

    def __getitem__(self, feature: Feature) -> FeatureDelta:
        """Return the delta stored for ``feature``; raises KeyError if absent."""
        delta = self._edits[feature]
        return delta

    def __setitem__(self, feature: Feature, delta: FeatureDelta):
        """Store (or overwrite) the delta for ``feature``."""
        self._edits[feature] = delta

    def __delitem__(self, feature: Feature):
        """Drop the delta for ``feature``; a missing feature is silently ignored."""
        if feature in self._edits:
            del self._edits[feature]

    def __iter__(self):
        """Iterate over ``(feature, delta)`` pairs in insertion order.

        The pairs are copied into a list first, so the collection may be
        modified while the returned iterator is being consumed.
        """
        snapshot = [*self._edits.items()]
        return iter(snapshot)

    def __len__(self):
        """Return the number of features that currently have an edit."""
        count = len(self._edits)
        return count


class Variant(VariantInterface):
    """A class representing a variant of a base model with feature modifications.

    This class allows for creating variants of a base model by applying
    feature modifications through either nudging or pinning values.

    Args:
        base_model (str): Identifier of the base model to create variants from

    Attributes:
        base_model (str): The base model identifier
        edits (FeatureEdits): Collection of feature modifications
    """

    def __init__(self, base_model: SUPPORTED_MODELS):
        self.base_model = base_model
        self.edits: FeatureEdits = FeatureEdits()

    @overload
    def set(
        self,
        feature: Union[Feature, FeatureGroup],
        value: Union[float, None],
        mode: Literal["nudge"] = "nudge",
    ) -> None: ...

    # NOTE(review): the "pin" default below is only visible to type checkers;
    # at runtime the implementation's default mode is "nudge".
    @overload
    def set(
        self,
        feature: Union[Feature, FeatureGroup],
        value: Union[float, bool, None],
        mode: Literal["pin"] = "pin",
    ) -> None: ...

    def set(
        self,
        feature: Union[Feature, FeatureGroup],
        value: Union[float, bool, None],
        mode: Literal["nudge", "pin"] = "nudge",
    ):
        """Set or modify feature values in the variant.

        Args:
            feature (Union[Feature, FeatureGroup]): Feature(s) to modify
            value (Union[float, bool, None]): Value to apply:
                - float: For numerical adjustments
                - bool: For binary states (pin mode only)
                - None: To clear the modification
            mode (Literal["nudge", "pin"], optional): Modification mode:
                - "nudge": Bias the feature strength
                - "pin": Set the feature strength to a fixed value
                Defaults to "nudge".
        """
        if value is None:
            self.clear(feature)
            return

        targets = [feature] if isinstance(feature, Feature) else feature
        for f in targets:
            # Each feature gets its own dict so edits never share state.
            self.edits[f] = {"mode": mode, "value": value}

    def clear(self, feature: Union[Feature, FeatureGroup]):
        """Remove modifications for specified feature(s).

        Features that have no recorded modification are ignored.

        Args:
            feature (Union[Feature, FeatureGroup]): Feature(s) to clear
                modifications for
        """
        targets = [feature] if isinstance(feature, Feature) else feature
        for f in targets:
            del self.edits[f]

    def reset(self):
        """Remove all feature modifications."""
        self.edits = FeatureEdits()

    def __repr__(self):
        return str(self)

    def __str__(self):
        """Return a multi-line summary of the base model and every edit."""
        body = "".join(f"\n {feature}: {edit}," for feature, edit in self.edits)
        edits = "{" + body + "\n }"
        return f"Variant(\n base_model={self.base_model},\n edits={edits}\n)"

    def json(self):
        """Convert the variant to a JSON-compatible dictionary.

        Returns:
            dict: Dictionary containing base model and feature configurations
        """
        return {
            "base_model": self.base_model,
            "fastmodel_config": [
                {
                    "feature_id": str(feature.uuid),
                    "feature_label": feature.label,
                    "max_activation_strength": feature.max_activation_strength,
                    "index_in_sae": feature.index_in_sae,
                    "mode": edit["mode"],
                    "value": edit["value"],
                }
                for feature, edit in self.edits
            ],
        }

    @property
    def controller(self) -> Controller:
        """Get a controller instance with the variant's modifications applied.

        Boolean edit values are translated to numeric strengths before being
        applied: ``True`` becomes ``0.5`` and ``False`` becomes ``-0.3``.

        Returns:
            Controller: Controller instance with feature modifications
        """
        controller = Controller()

        for feature, edit in self.edits:
            value = edit["value"]
            if isinstance(value, bool):
                value = 0.5 if value else -0.3

            # Apply the converted value; using edit["value"] here would discard
            # the boolean translation above and hand a raw bool to the controller.
            if edit["mode"] == "nudge":
                controller[feature] += value
            else:
                controller[feature] = value

        return controller