Source code for goodfire.api.features.client

from typing import Any, Iterable, Optional, Union

import numpy as np

from ...features.features import Feature, FeatureGroup
from ...variants.variants import VariantInterface
from ..chat.interfaces import ChatMessage
from ..constants import PRODUCTION_BASE_URL
from ..utils import HTTPWrapper
from .interfaces import SearchFeatureResponse


class _ExperimentalFeaturesAPI:
    """A class for accessing experimental features of the Goodfire API."""

    def __init__(
        self,
        features_api: "FeaturesAPI",
    ):
        self.features_api = features_api

        self._warned_user = False

        self._http = HTTPWrapper()

    def _warn_user(self):
        if not self._warned_user:
            print("Warning: The experimental features API is subject to change.")
            self._warned_user = True

    def neighbors(
        self,
        features: Union[Feature, FeatureGroup],
        model: Union[str, VariantInterface] = "meta-llama/Meta-Llama-3-8B-Instruct",
        top_k: int = 10,
    ):
        """Get the nearest neighbors of a feature or group of features."""
        self._warn_user()
        if isinstance(features, Feature):
            features = FeatureGroup([features])

        url = f"{self.features_api.base_url}/api/inference/v1/attributions/neighbors"
        payload = {
            "feature_indices": [feature.index_in_sae for feature in features],
            "model": model if isinstance(model, str) else model.base_model,
            "top_k": top_k,
        }
        headers = self.features_api._get_headers()
        response = self._http.post(url, json=payload, headers=headers)

        response_body = response.json()

        results: list[Feature] = []
        for feature in response_body["neighbors"]:
            results.append(
                Feature(
                    uuid=feature["id"],
                    label=feature["label"],
                    max_activation_strength=feature["max_activation_strength"],
                    index_in_sae=feature["index_in_sae"],
                )
            )

        return FeatureGroup(results)
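
    # Usage sketch for ``neighbors`` (hedged: assumes ``features_api`` is an
    # already-constructed ``FeaturesAPI`` instance; the experimental API is
    # reached through its ``_experimental`` attribute as wired in
    # ``FeaturesAPI.__init__``, and the query string is illustrative):
    #
    #   >>> pirate_features, _ = features_api.search("pirate speech", top_k=1)
    #   >>> similar = features_api._experimental.neighbors(
    #   ...     pirate_features[0], top_k=10
    #   ... )
    #   >>> print(similar)  # FeatureGroup of the 10 nearest features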

    def dimension_reduction(
        self,
        center: Feature,
        features: FeatureGroup,
        model: Union[str, VariantInterface] = "meta-llama/Meta-Llama-3-8B-Instruct",
        dimensions: int = 2,
    ) -> list[list[float]]:
        """Reduce the dimensionality of a set of features around a center feature."""
        self._warn_user()
        url = f"{self.features_api.base_url}/api/inference/v1/attributions/dimension-reduction"

        feature_indices = [feature.index_in_sae for feature in features]

        if center.index_in_sae in feature_indices:
            feature_indices.remove(center.index_in_sae)

        payload = {
            "center_feature_index": center.index_in_sae,
            "feature_indices": feature_indices,
            "model": model if isinstance(model, str) else model.base_model,
            "dimensions": dimensions,
            "mode": "pca",
        }
        headers = self.features_api._get_headers()
        response = self._http.post(url, json=payload, headers=headers)

        return response.json()["reduced_features"]
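
    # Usage sketch for ``dimension_reduction`` (hedged: reuses the
    # ``features_api``, ``pirate_features`` and ``similar`` names from the
    # sketch above; the center feature is stripped from the payload, so one
    # coordinate list is returned per remaining feature):
    #
    #   >>> coords = features_api._experimental.dimension_reduction(
    #   ...     center=pirate_features[0],
    #   ...     features=similar,
    #   ...     dimensions=2,
    #   ... )
    #   >>> len(coords[0])  # 2 coordinates per feature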


class FeaturesAPI:
    """A class for accessing interpretable SAE features of AI models."""

    def __init__(
        self,
        goodfire_api_key: str,
        base_url: str = PRODUCTION_BASE_URL,
    ):
        self.goodfire_api_key = goodfire_api_key
        self.base_url = base_url

        self._experimental = _ExperimentalFeaturesAPI(self)

        self._http = HTTPWrapper()

    def _get_headers(self):
        return {
            "Authorization": f"Bearer {self.goodfire_api_key}",
            "Content-Type": "application/json",
        }
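
    # Construction sketch (hedged: GOODFIRE_API_KEY is a placeholder
    # environment variable; in the published SDK this class is typically
    # reached through the top-level client rather than constructed directly):
    #
    #   >>> import os
    #   >>> features_api = FeaturesAPI(os.environ["GOODFIRE_API_KEY"])
    #   >>> features_api._get_headers()["Content-Type"]
    #   'application/json'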
    def search(
        self,
        query: str,
        model: Union[str, VariantInterface] = "meta-llama/Meta-Llama-3-8B-Instruct",
        top_k: int = 10,
    ):
        """Search for features based on a query."""
        url = f"{self.base_url}/api/inference/v1/features/search"
        params = {
            "query": query,
            "page": 1,
            "perPage": top_k,
            "model": model if isinstance(model, str) else model.base_model,
        }
        headers = self._get_headers()
        response = self._http.get(url, params=params, headers=headers)

        response = SearchFeatureResponse.model_validate_json(response.text)

        features: list[Feature] = []
        relevance_scores: list[float] = []
        for feature in response.features:
            features.append(
                Feature(
                    uuid=feature.id,
                    label=feature.label,
                    max_activation_strength=feature.max_activation_strength,
                    index_in_sae=feature.index_in_sae,
                )
            )
            relevance_scores.append(feature.relevance)

        return FeatureGroup(features), relevance_scores
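
    # Usage sketch for ``search`` (hedged: relies only on the signature above;
    # the query string is illustrative):
    #
    #   >>> features, relevance = features_api.search("sarcasm", top_k=3)
    #   >>> for feature, score in zip(features, relevance):
    #   ...     print(feature.label, score)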
    def rerank(
        self,
        features: FeatureGroup,
        query: str,
        model: Union[str, VariantInterface] = "meta-llama/Meta-Llama-3-8B-Instruct",
        top_k: int = 10,
    ):
        """Rerank a set of features based on a query."""
        url = f"{self.base_url}/api/inference/v1/features/rerank"
        payload = {
            "query": query,
            "top_k": top_k,
            "model": model if isinstance(model, str) else model.base_model,
            "feature_ids": [str(feature.uuid) for feature in features],
        }
        headers = self._get_headers()
        response = self._http.post(url, json=payload, headers=headers)

        response = SearchFeatureResponse.model_validate_json(response.text)

        features_to_return: list[Feature] = []
        for feature in response.features:
            features_to_return.append(
                Feature(
                    uuid=feature.id,
                    label=feature.label,
                    max_activation_strength=feature.max_activation_strength,
                    index_in_sae=feature.index_in_sae,
                )
            )

        return FeatureGroup(features_to_return)
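
    # Usage sketch for ``rerank`` (hedged: reorders an existing FeatureGroup
    # against a new query; note that feature UUIDs, not SAE indices, are sent):
    #
    #   >>> reranked = features_api.rerank(features, query="humor", top_k=2)
    #   >>> print(reranked)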
    def inspect(
        self,
        messages: list[ChatMessage],
        model: Union[str, VariantInterface] = "meta-llama/Meta-Llama-3-8B-Instruct",
        features: Optional[Union[Feature, FeatureGroup]] = None,
    ):
        """Retrieve feature activations for a set of messages."""
        payload: dict[str, Any] = {
            "messages": messages,
            "aggregate_by": "count",
        }

        if isinstance(model, str):
            payload["model"] = model
        else:
            payload["model"] = model.base_model
            payload["controller"] = model.controller.json()

        include_feature_ids: Optional[set[str]] = None
        if features:
            if isinstance(features, Feature):
                include_feature_indexes = [features.index_in_sae]
                include_feature_ids = {str(features.uuid)}
            else:
                include_feature_indexes: list[int] = []
                include_feature_ids = set()
                for f in features:
                    include_feature_ids.add(str(f.uuid))
                    include_feature_indexes.append(f.index_in_sae)

            payload["include_feature_indexes"] = include_feature_indexes

        response = self._http.post(
            f"{self.base_url}/api/inference/v1/attributions/compute-features",
            headers=self._get_headers(),
            json=payload,
        )

        return ContextInspector(
            self, response.json(), include_feature_ids=include_feature_ids
        )
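
    # Usage sketch for ``inspect`` (hedged: the message dicts follow the
    # ChatMessage shape used elsewhere in this module, and the returned
    # ContextInspector is defined at the bottom of this file):
    #
    #   >>> inspector = features_api.inspect(
    #   ...     [{"role": "user", "content": "Tell me a joke"}],
    #   ...     features=features,  # optional Feature/FeatureGroup filter
    #   ... )
    #   >>> inspector.top(k=5)  # FeatureActivations for the strongest features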
    def contrast(
        self,
        dataset_1: list[list[ChatMessage]],
        dataset_2: list[list[ChatMessage]],
        model: Union[str, VariantInterface] = "meta-llama/Meta-Llama-3-8B-Instruct",
        dataset_1_feature_rerank_query: Optional[str] = None,
        dataset_2_feature_rerank_query: Optional[str] = None,
        top_k: int = 5,
    ):
        """Identify features that differentiate between two conversation datasets.

        Args:
            dataset_1: First conversation dataset
            dataset_2: Second conversation dataset
            model: Model identifier or variant interface
            dataset_1_feature_rerank_query: Optional query to rerank dataset_1 features
            dataset_2_feature_rerank_query: Optional query to rerank dataset_2 features
            top_k: Number of top features to return (default: 5)

        Returns:
            tuple: Two FeatureGroups containing:
                - Features steering towards dataset_1
                - Features steering towards dataset_2

            Each Feature has properties:
                - uuid: Unique feature identifier
                - label: Human-readable feature description
                - max_activation_strength: Feature activation strength
                - index_in_sae: Index in sparse autoencoder

        Raises:
            ValueError: If datasets are empty or have different lengths

        Example:
            >>> dataset_1 = [[
            ...     {"role": "user", "content": "Hi how are you?"},
            ...     {"role": "assistant", "content": "I'm doing well..."}
            ... ]]
            >>> dataset_2 = [[
            ...     {"role": "user", "content": "Hi how are you?"},
            ...     {"role": "assistant", "content": "Arr my spirits be high..."}
            ... ]]
            >>> features_1, features_2 = client.features.contrast(
            ...     dataset_1=dataset_1,
            ...     dataset_2=dataset_2,
            ...     model=model,
            ...     dataset_2_feature_rerank_query="pirate",
            ...     top_k=5
            ... )
        """
        if len(dataset_1) != len(dataset_2):
            raise ValueError("dataset_1 and dataset_2 must have the same length")

        if len(dataset_1) == 0:
            raise ValueError("dataset_1 and dataset_2 must have at least one element")

        url = f"{self.base_url}/api/inference/v1/attributions/contrast"
        payload = {
            "dataset_1": dataset_1,
            "dataset_2": dataset_2,
            "k_to_add": top_k * 4,
            "k_to_remove": top_k * 4,
            "model": model if isinstance(model, str) else model.base_model,
        }
        headers = self._get_headers()
        response = self._http.post(url, json=payload, headers=headers)

        response_body = response.json()

        dataset_1_features = FeatureGroup(
            [
                Feature(
                    uuid=feature["id"],
                    label=feature["label"],
                    max_activation_strength=feature["max_activation_strength"],
                    index_in_sae=feature["index_in_sae"],
                )
                for feature in response_body["dataset_1_features"]
            ]
        )
        dataset_2_features = FeatureGroup(
            [
                Feature(
                    uuid=feature["id"],
                    label=feature["label"],
                    max_activation_strength=feature["max_activation_strength"],
                    index_in_sae=feature["index_in_sae"],
                )
                for feature in response_body["dataset_2_features"]
            ]
        )

        if dataset_1_feature_rerank_query:
            dataset_1_features = self.rerank(
                dataset_1_features, dataset_1_feature_rerank_query, model, top_k=top_k
            )

        if dataset_2_feature_rerank_query:
            dataset_2_features = self.rerank(
                dataset_2_features, dataset_2_feature_rerank_query, model, top_k=top_k
            )

        return dataset_1_features, dataset_2_features
    def list(self, ids: "list[str]"):
        """Get features by their IDs."""
        url = f"{self.base_url}/api/inference/v1/features/"
        params = {
            "feature_id": ids,
        }
        headers = self._get_headers()
        response = self._http.get(url, params=params, headers=headers)

        response = SearchFeatureResponse.model_validate_json(response.text)

        return FeatureGroup(
            [
                Feature(
                    uuid=feature.id,
                    label=feature.label,
                    max_activation_strength=feature.max_activation_strength,
                    index_in_sae=feature.index_in_sae,
                )
                for feature in response.features
            ]
        )
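
    # Usage sketch for ``list`` (hedged: feature IDs are the string UUIDs
    # carried on Feature objects returned by the other endpoints):
    #
    #   >>> same_features = features_api.list([str(f.uuid) for f in features])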
class FeatureActivation:
    def __init__(self, feature: Feature, activation_strength: float):
        self.feature = feature
        self.activation = activation_strength

    def __repr__(self):
        return str(self)

    def __str__(self):
        return (
            f"FeatureActivation(feature={self.feature}, activation={self.activation})"
        )


class FeatureActivations:
    def __init__(self, acts: Iterable[tuple[Feature, float]]):
        self._acts = [FeatureActivation(feat, act) for feat, act in acts]

    def __getitem__(self, idx: int):
        return self._acts[idx]

    def __iter__(self):
        return iter(self._acts)

    def __len__(self):
        return len(self._acts)

    def __repr__(self):
        return str(self)

    def __str__(self):
        response_str = "FeatureActivations("
        for index, act in enumerate(self._acts[:10]):
            response_str += f"\n{index}: ({act.feature}, {act.activation})"

        if len(self._acts) > 10:
            response_str += "\n..."

        response_str = response_str.replace("\n", "\n ")

        response_str += "\n)"

        return response_str

    def vector(self):
        SAE_SIZE = 65536

        array = np.zeros(SAE_SIZE)
        feature_lookup: dict[int, Feature] = {}
        for act in self._acts:
            array[act.feature.index_in_sae] = act.activation
            feature_lookup[act.feature.index_in_sae] = act.feature

        return array, feature_lookup


class Token:
    def __init__(
        self, client: FeaturesAPI, token: str, feature_acts: list[dict[str, Any]]
    ):
        self._client = client
        self._token = token
        self._feature_acts = feature_acts

    def __repr__(self):
        return str(self)

    def __str__(self):
        return f'Token("{self._token}")'

    def inspect(self, k: int = 5):
        uuids = [act["id"] for act in self._feature_acts[:k]]
        features = self._client.list(uuids)

        return FeatureActivations(
            tuple(
                (feature, act["activation_strength"])
                for feature, act in zip(features, self._feature_acts)
            )
        )


class ContextInspector:
    def __init__(
        self,
        client: FeaturesAPI,
        context_response: dict[str, Any],
        include_feature_ids: Optional[set[str]] = None,
    ):
        self._client = client
        self.tokens: list[Token] = []
        self._feature_strengths: dict[str, list[float]] = {}

        if include_feature_ids:
            for id in include_feature_ids:
                self._feature_strengths[id] = [0, 0]

        for token_config in context_response["tokens"]:
            self.tokens.append(
                Token(client, token_config["token"], token_config["attributions"])
            )
            for act in token_config["attributions"]:
                if abs(act["activation_strength"]) > 0.25:
                    if not self._feature_strengths.get(act["id"]):
                        self._feature_strengths[act["id"]] = [0, 0]

                    self._feature_strengths[act["id"]][0] += 1
                    self._feature_strengths[act["id"]][1] += act["activation_strength"]

        for feature_strength in self._feature_strengths.values():
            if feature_strength[0]:
                feature_strength[1] /= feature_strength[0]

    def __repr__(self):
        return str(self)

    def __str__(self):
        response_str = "ContextInspector(\n"
        for token in self.tokens[:50]:
            response_str += f"{token._token}"

        response_str = response_str.replace("\n", "\n ")

        if len(self.tokens) >= 50:
            response_str += "..."

        response_str += "\n)"

        return response_str

    def top(self, k: int = 5):
        sorted_feature_ids = sorted(
            list(self._feature_strengths.items()),
            key=lambda row: row[1][0],
            reverse=True,
        )
        features = self._client.list([feat[0] for feat in sorted_feature_ids[:k]])

        return FeatureActivations(
            sorted(
                tuple(
                    (feature, self._feature_strengths[str(feature.uuid)][1])
                    for feature in features
                ),
                key=lambda row: row[1],
                reverse=True,
            )
        )
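

# End-to-end sketch of the inspection helpers above (hedged: assumes an
# ``inspector`` obtained from ``FeaturesAPI.inspect`` as in the earlier
# sketches; shown as comments because this module defines library code only):
#
#   >>> token = inspector.tokens[0]
#   >>> token.inspect(k=3)              # FeatureActivations for a single token
#   >>> acts = inspector.top(k=5)       # aggregated across the whole context
#   >>> vector, lookup = acts.vector()  # dense activation vector plus lookup
#   >>> vector.shape
#   (65536,)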