from typing import Any, Generator, Iterable, Literal, Optional, Union, overload
from pydantic import ValidationError
import httpx
from ...utils.logger import logger
from ...variants.variants import VariantInterface
from ..constants import PRODUCTION_BASE_URL, SSE_DONE
from ..exceptions import ServerErrorException
from ..utils import HTTPWrapper
from .interfaces import (
ChatCompletion,
ChatMessage,
LogitsResponse,
StreamingChatCompletionChunk,
)
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant who should follow the users requests. Be brief and to the point, but also be friendly and engaging."
class ChatAPICompletions:
"""OpenAI compatible chat completions API."""
def __init__(self, api_key: str, base_url: str = PRODUCTION_BASE_URL):
self.api_key = api_key
self.base_url = base_url
self._http = HTTPWrapper()
def _get_headers(self):
return {
"Authorization": f"Bearer {self.api_key}",
}
@overload
def create(
self,
messages: Iterable[Union[ChatMessage, dict[str, str]]],
model: Union[str, VariantInterface],
*,
stream: Literal[False] = False,
        max_completion_tokens: Optional[int] = 2048,
        top_p: float = 0.9,
        temperature: float = 0.6,
        stop: Optional[Union[str, list[str]]] = None,
        timeout: Optional[int] = 320,
        seed: Optional[int] = 42,
) -> ChatCompletion: ...
@overload
def create(
self,
messages: Iterable[Union[ChatMessage, dict[str, str]]],
model: Union[str, VariantInterface],
*,
stream: Literal[True] = True,
        max_completion_tokens: Optional[int] = 2048,
        top_p: float = 0.9,
        temperature: float = 0.6,
        stop: Optional[Union[str, list[str]]] = None,
        timeout: Optional[int] = 320,
        seed: Optional[int] = 42,
) -> Generator[StreamingChatCompletionChunk, Any, Any]: ...
def create(
self,
messages: Iterable[Union[ChatMessage, dict[str, str]]],
model: Union[str, VariantInterface],
stream: bool = False,
max_completion_tokens: Optional[int] = 2048,
top_p: float = 0.9,
temperature: float = 0.6,
        stop: Optional[Union[str, list[str]]] = None,
timeout: Optional[int] = 320,
seed: Optional[int] = 42,
__system_prompt: str = DEFAULT_SYSTEM_PROMPT,
) -> Union[ChatCompletion, Generator[StreamingChatCompletionChunk, Any, Any]]:
"""Create a chat completion."""
url = f"{self.base_url}/api/inference/v1/chat/completions"
headers = self._get_headers()
if __system_prompt != DEFAULT_SYSTEM_PROMPT:
logger.warning(
"We recommend using Goodfire's default system prompt to maximize intervention stability."
)
messages = [*messages]
if __system_prompt:
messages.insert(0, {"role": "system", "content": __system_prompt})
        if stop is None:
            # Default stop sequences: Llama-style end-of-turn markers.
            stop = ["<|eot_id|>", "<|begin_of_text|>"]
        payload: dict[str, Any] = {
"messages": messages,
"stream": stream,
"max_completion_tokens": max_completion_tokens,
"top_p": top_p,
"temperature": temperature,
"stop": stop,
"seed": seed,
}
if isinstance(model, str):
payload["model"] = model
else:
payload["model"] = model.base_model
payload["controller"] = model.controller.json()
if stream:
def _stream_response() -> Generator[StreamingChatCompletionChunk, Any, Any]:
try:
for chunk in self._http.stream(
"POST",
url,
headers={
**headers,
"Accept": "text/event-stream",
"Connection": "keep-alive",
},
json=payload,
timeout=timeout,
):
                        decoded = chunk.decode("utf-8")
                        if decoded == SSE_DONE:
                            break
                        # Skip keep-alive lines and anything without a data payload,
                        # which would otherwise raise an IndexError below.
                        if "data: " not in decoded:
                            continue
                        json_chunk = decoded.split("data: ", 1)[1].strip()
                        yield StreamingChatCompletionChunk.model_validate_json(
                            json_chunk
                        )
                except (httpx.RemoteProtocolError, ValidationError) as exc:
                    raise ServerErrorException() from exc
return _stream_response()
else:
response = self._http.post(
url,
headers={
**headers,
"Accept": "application/json",
},
json=payload,
timeout=timeout,
)
try:
return ChatCompletion.model_validate(response.json())
            except ValidationError as exc:
                raise ServerErrorException("Server error") from exc
class ExperimentalChatAPI:
"""Experimental chat API."""
def __init__(self, chat_api: "ChatAPI"):
self.chat_api = chat_api
self._warned_user = False
self._http = HTTPWrapper()
def _warn_user(self):
if not self._warned_user:
print("Warning: The experimental chat API is subject to change.")
self._warned_user = True
def logits(
self,
messages: Iterable[Union[ChatMessage, dict[str, str]]],
model: Union[str, VariantInterface],
top_k: Optional[int] = None,
vocabulary: Optional[list[str]] = None,
) -> LogitsResponse:
"""Compute logits for a chat completion."""
        self._warn_user()
        payload: dict[str, Any] = {
            # Materialize the iterable so it is JSON-serializable.
            "messages": [*messages],
            "k": top_k,
            "vocabulary": vocabulary,
        }
if isinstance(model, str):
payload["model"] = model
else:
payload["model"] = model.base_model
payload["controller"] = model.controller.json()
response = self._http.post(
f"{self.chat_api.base_url}/api/inference/v1/chat/compute-logits",
            headers=self.chat_api._get_headers(),
json=payload,
)
return LogitsResponse.model_validate(response.json())
class ChatAPI:
"""OpenAI compatible chat API.
Example:
>>> for token in client.chat.completions.create(
... [
... {"role": "user", "content": "hello"}
... ],
... model="meta-llama/Meta-Llama-3-8B-Instruct",
... stream=True,
... max_completion_tokens=50,
... ):
... print(token.choices[0].delta.content, end="")
"""
def __init__(self, api_key: str, base_url: str = PRODUCTION_BASE_URL):
self.api_key = api_key
self.base_url = base_url
self.completions = ChatAPICompletions(api_key, base_url)
self._experimental = ExperimentalChatAPI(self)
def _get_headers(self):
return {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}