from typing import Any, Literal, Optional, overload
from pydantic import BaseModel
from ...api.constants import PRODUCTION_BASE_URL
from ...api.utils import HTTPWrapper
from ...controller.controller import Controller
from ...features.features import Feature
from ...variants._experimental import ProgrammableVariant
from ...variants.fast import Variant
from ...variants.variants import VariantInterface
class VariantMetaData(BaseModel):
name: str
base_model: str
id: str
[docs]
class VariantsAPI:
"""Client for interacting with the Goodfire Variants API."""
def __init__(self, api_key: str, base_url: str = PRODUCTION_BASE_URL):
self.base_url = base_url
self.api_key = api_key
self._http = HTTPWrapper()
def _get_headers(self):
return {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
@overload
def get(self, variant_id: str, fast_variant: Literal[True] = True) -> Variant: ...
@overload
def get(
self, variant_id: str, fast_variant: Literal[False] = False
) -> ProgrammableVariant: ...
[docs]
def get(self, variant_id: str, fast_variant: bool = True):
"""Get a model variant by ID."""
url = f"{self.base_url}/api/inference/v1/model-variants/{variant_id}"
headers = self._get_headers()
response = self._http.get(url, headers=headers)
response_json = response.json()
if response_json.get("fastmodel_config") or fast_variant:
model = Variant(
response_json["base_model"],
)
if config := response_json.get("fastmodel_config"):
for edit in config:
model.set(
Feature(
uuid=edit["feature_id"],
label=edit["feature_label"],
max_activation_strength=edit["max_activation_strength"],
index_in_sae=edit["index_in_sae"],
),
edit["value"],
edit["mode"],
)
else:
controller = Controller.from_json(
response_json["controller"],
response_json.get("name", "controller"),
response_json["id"],
)
model = ProgrammableVariant(
base_model=response_json["base_model"],
controller=controller,
)
return model
[docs]
def list(self):
"""List all model variants."""
url = f"{self.base_url}/api/inference/v1/model-variants/"
headers = self._get_headers()
response = self._http.get(url, headers=headers)
response_json = response.json()
return [
VariantMetaData(
name=variant["name"],
base_model=variant["base_model"],
id=variant["id"],
)
for variant in response_json["model_variants"]
]
[docs]
def create(self, variant: VariantInterface, name: str):
"""Create a new model variant with the specified name."""
payload: dict[str, Any] = {
"tokens": [],
"base_model": variant.base_model,
"name": name,
}
if isinstance(variant, Variant):
payload["fastmodel_config"] = variant.json()["fastmodel_config"]
else:
payload["controller"] = variant.controller.json()
url = f"{self.base_url}/api/inference/v1/model-variants/"
headers = self._get_headers()
response = self._http.post(
url,
headers=headers,
json=payload,
)
response_json = response.json()
return response_json["id"]
[docs]
def update(
self, id: str, variant: VariantInterface, new_name: Optional[str] = None
):
"""Update an existing model variant."""
payload: dict[str, Any] = {
"tokens": [],
"base_model": variant.base_model,
}
if isinstance(variant, Variant):
payload["fastmodel_config"] = variant.json()["fastmodel_config"]
else:
payload["controller"] = variant.controller.json()
if new_name:
payload["name"] = new_name
url = f"{self.base_url}/api/inference/v1/model-variants/{id}"
headers = self._get_headers()
self._http.put(
url,
headers=headers,
json=payload,
)
[docs]
def delete(self, id: str):
"""Delete a model variant by ID."""
url = f"{self.base_url}/api/inference/v1/model-variants/{id}"
headers = self._get_headers()
self._http.delete(url, headers=headers)