Decision Trees with Features
[ ]:
!pip install goodfire==0.2.11
Collecting goodfire==0.2.11
Downloading goodfire-0.2.11-py3-none-any.whl.metadata (1.2 kB)
Requirement already satisfied: httpx<0.28.0,>=0.27.2 in /usr/local/lib/python3.10/dist-packages (from goodfire==0.2.11) (0.27.2)
Collecting ipywidgets<9.0.0,>=8.1.5 (from goodfire==0.2.11)
Downloading ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB)
Requirement already satisfied: numpy<2.0.0,>=1.26.4 in /usr/local/lib/python3.10/dist-packages (from goodfire==0.2.11) (1.26.4)
Requirement already satisfied: pydantic<3.0.0,>=2.9.2 in /usr/local/lib/python3.10/dist-packages (from goodfire==0.2.11) (2.9.2)
Requirement already satisfied: anyio in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (3.7.1)
Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (2024.8.30)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (1.0.7)
Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (3.10)
Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (1.3.1)
Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (0.14.0)
Collecting comm>=0.1.3 (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11)
Downloading comm-0.2.2-py3-none-any.whl.metadata (3.7 kB)
Requirement already satisfied: ipython>=6.1.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (7.34.0)
Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.10/dist-packages (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (5.7.1)
Collecting widgetsnbextension~=4.0.12 (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11)
Downloading widgetsnbextension-4.0.13-py3-none-any.whl.metadata (1.6 kB)
Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.10/dist-packages (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (3.0.13)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.9.2->goodfire==0.2.11) (0.7.0)
Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.9.2->goodfire==0.2.11) (2.23.4)
Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.9.2->goodfire==0.2.11) (4.12.2)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (75.1.0)
Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11)
Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (4.4.2)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.7.5)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (3.0.48)
Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (2.18.0)
Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.2.0)
Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.1.7)
Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (4.9.0)
Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio->httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (1.2.2)
Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.8.4)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.7.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.2.13)
Downloading goodfire-0.2.11-py3-none-any.whl (27 kB)
Downloading ipywidgets-8.1.5-py3-none-any.whl (139 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 139.8/139.8 kB 6.6 MB/s eta 0:00:00
Downloading comm-0.2.2-py3-none-any.whl (7.2 kB)
Downloading widgetsnbextension-4.0.13-py3-none-any.whl (2.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 10.4 MB/s eta 0:00:00
Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 14.0 MB/s eta 0:00:00
[ ]:
from google.colab import userdata
# Add your Goodfire API key to your Colab secrets; it is read from there by name.
GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')
Quick start
Initialize the SDK
[ ]:
import goodfire
# Goodfire API client, reused by the contrast, inspect and neighbors calls later in the notebook.
client = goodfire.Client(GOODFIRE_API_KEY)
[ ]:
import pandas as pd
# Parquet file path of each split, relative to the dataset root used below.
splits = {'train': 'data/train-00000-of-00001-ab9a3b4799b09589.parquet', 'test': 'data/test-00000-of-00001-8bd1e21c671fb670.parquet', 'valid': 'data/valid-00000-of-00001-303e4ba2afe838d4.parquet'}
# Only the "train" split is loaded; the "test" and "valid" entries are not read here.
df = pd.read_parquet("hf://datasets/ChanceFocus/en-fpb/" + splits["train"])
[ ]:
# Bare expression so the notebook renders the loaded DataFrame as the cell output.
df
Extract features for each input
[ ]:
# Model variant passed as `model=` to the feature contrast and inspect calls below.
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")
[ ]:
# Number of rows per class fed to the feature contrast step.
FEATURE_COMPUTE_SIZE = 100
# Upper bound on the number of rows per class used when building the classifier data.
CLASSIFIER_FULL_SET_SIZE = 150

# Shuffle every row once with a fixed seed so the per-class slices taken later
# are reproducible between runs.
shuffled_df = df.sample(frac=1, random_state=42)

# Partition the shuffled rows by the value of their 'answer' column.
positive_examples = shuffled_df.loc[shuffled_df['answer'].eq('positive')]
negative_examples = shuffled_df.loc[shuffled_df['answer'].eq('negative')]
neutral_examples = shuffled_df.loc[shuffled_df['answer'].eq('neutral')]
[ ]:
# Contrast two sets of short chats: positive headlines answered with "good"
# versus negative headlines answered with "bad". Only the first
# FEATURE_COMPUTE_SIZE rows of each class are used.
positive_news_features, negative_news_features = client.features.contrast(
    dataset_1=[
        [
            {"role": "user", "content": f"Is the following good or bad news for investors? {text}"},
            {"role": "assistant", "content": "good"},
        ]
        for text in positive_examples['text'].iloc[:FEATURE_COMPUTE_SIZE].tolist()
    ],
    dataset_2=[
        [
            {"role": "user", "content": f"Is the following good or bad news for investors? {text}"},
            {"role": "assistant", "content": "bad"},
        ]
        for text in negative_examples['text'].iloc[:FEATURE_COMPUTE_SIZE].tolist()
    ],
    dataset_1_feature_rerank_query="bull market",
    dataset_2_feature_rerank_query="bear market",
    model=variant,
    top_k=50,
)

# Combine both result groups into the candidate pool used by the rest of the
# notebook (presumably a union of the two feature groups - confirm against the SDK).
features_to_look_at = positive_news_features | negative_news_features
features_to_look_at
[ ]:
from itertools import combinations
class FeatureMixer:
    """Enumerate fixed-size combinations drawn from a group of features."""

    def __init__(self, feature_group):
        # Any iterable works; it is re-iterated on every call to `grid`.
        self.feature_group = feature_group

    def grid(self, k_features_per_combo: int = 2):
        """Return every combination of `k_features_per_combo` features.

        The result is a list of tuples, in the order produced by
        `itertools.combinations` over the stored feature group.
        """
        return [*combinations(self.feature_group, k_features_per_combo)]
[ ]:
import pandas as pd
import concurrent.futures as futures
import tqdm
# Rows inspected per class: the size of the smallest of the three classes,
# capped at CLASSIFIER_FULL_SET_SIZE, so every class contributes equally.
MIN_SAMPLES_PER_CLASS = min(
len(negative_examples),
len(positive_examples),
len(neutral_examples),
CLASSIFIER_FULL_SET_SIZE
)
# Thread pool size for the concurrent feature-inspection requests below.
MAX_WORKERS = 3
def _get_feature_acts_for_sample_class(
    sample_class: pd.DataFrame,
    features_to_use_for_classification: goodfire.FeatureGroup,
    k=100,
):
    """Inspect the leading samples of one class and collect their feature activations.

    The first MIN_SAMPLES_PER_CLASS rows of `sample_class` are each wrapped in
    a single-turn user prompt and sent to `client.features.inspect`, using up
    to MAX_WORKERS threads at a time.

    Args:
        sample_class: Rows of a single class; only the 'text' column is read.
        features_to_use_for_classification: Features passed as `features=` to
            every inspect call.
        k: Value forwarded to `context.top(k=k)` for each inspected sample.
            Must be at least the number of features.

    Returns:
        A list with one entry per inspected row, in row order, each being the
        result of `context.top(k=k)` for that row.

    Raises:
        ValueError: If `k` is smaller than the number of features.
    """
    if k < len(features_to_use_for_classification):
        # The guard accepts k equal to the feature count, so the message says so.
        raise ValueError(
            "k must be greater than or equal to the number of features to use for classification"
        )

    texts = sample_class[0:MIN_SAMPLES_PER_CLASS]['text'].tolist()

    with futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Submit everything up front so requests overlap.
        futures_list = [
            executor.submit(
                client.features.inspect,
                [
                    {
                        "role": "user",
                        "content": f"is the following good or bad for investors? {text}"
                    }
                ],
                model=variant,
                features=features_to_use_for_classification,
            )
            for text in texts
        ]

        # Consume in submission order so results stay aligned with the input rows.
        samples = []
        for future in tqdm.tqdm(futures_list):
            context = future.result()
            samples.append(context.top(k=k))

    return samples
# Collect feature activations for each sentiment class using the same feature group.
# NOTE(review): the helper reads rows from the start of each class frame, and the
# contrast step also used the first FEATURE_COMPUTE_SIZE rows of the positive and
# negative frames - the feature-selection and classifier samples overlap; confirm
# this is intended.
print("Computing positive news features...")
positive_class_features = _get_feature_acts_for_sample_class(positive_examples, features_to_look_at, k=100)
print("Computing negative news features...")
negative_class_features = _get_feature_acts_for_sample_class(negative_examples, features_to_look_at, k=100)
# The neutral activations are computed here but find_best_combo only trains on
# the negative and positive classes.
print("Computing neutral news features...")
neutral_class_features = _get_feature_acts_for_sample_class(neutral_examples, features_to_look_at, k=100)
[ ]:
from sklearn import tree
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score
import numpy as np
import tqdm
# Grid search may take a while, you can curate the feature list to speed this process up significantly
def train_tree(x, y, depth):
    """Fit a small regularized decision tree on half the data and score it on the rest.

    Args:
        x: Feature rows, one per sample.
        y: Class label per sample. `f1_score` is called with its default
            arguments, so the labels are expected to be binary.
        depth: Maximum depth of the tree.

    Returns:
        Tuple `(model, pred, score, accuracy)`: the fitted classifier, its
        predictions for the held-out half, and the F1 score and balanced
        accuracy of those predictions.
    """
    # Fixed seed so every feature combination is judged on the same split.
    train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=0.5, random_state=42)

    # Regularize by requiring each leaf to hold about 5% of the training rows.
    # Clamp to 1: the integer division gives 0 for fewer than 20 training rows,
    # and scikit-learn rejects min_samples_leaf=0.
    model = tree.DecisionTreeClassifier(
        max_depth=depth,
        min_samples_leaf=max(1, len(train_x) // 20),
        random_state=42,
    )
    model.fit(train_x, train_y)
    pred = model.predict(test_x)

    # Score the held-out half with both metrics.
    accuracy = balanced_accuracy_score(test_y, pred)
    score = f1_score(test_y, pred)
    return model, pred, score, accuracy
def find_best_combo(features, k_features_per_combo = 2):
    """Grid-search feature combinations for the best negative-vs-positive tree.

    Every combination of `k_features_per_combo` features is used to train a
    decision tree (via `train_tree`, with depth equal to the combination size)
    on the module-level `negative_class_features` (label -1) and
    `positive_class_features` (label 1). Neutral samples are not used.

    Args:
        features: Iterable of features to combine; each must expose `.uuid`.
        k_features_per_combo: Number of features in each combination.

    Returns:
        Tuple `(best_combo, best_score, best_model, best_accuracy)` for the
        combination with the highest F1 score. Ties keep the earliest
        combination. If no combination scores above 0, `best_combo` and
        `best_model` are None and both metrics are 0.
    """
    def _activation_by_uuid(row):
        # Keep the first activation seen for a feature, as a first-match scan would.
        lookup = {}
        for feature_act in row:
            lookup.setdefault(feature_act.feature.uuid, feature_act.activation)
        return lookup

    # Index each sample once; scoring a combination is then a few dict lookups
    # instead of a rescan of every sample's activation list.
    negative_lookups = [_activation_by_uuid(row) for row in negative_class_features]
    positive_lookups = [_activation_by_uuid(row) for row in positive_class_features]

    def _select_feature_acts(combo, lookup):
        # A feature missing from a sample is skipped, not zero-filled.
        return [lookup[feature.uuid] for feature in combo if feature.uuid in lookup]

    def _test_combo(combo):
        # Train and score one decision tree on the activations of this combination.
        x_negative = [_select_feature_acts(combo, lookup) for lookup in negative_lookups]
        x_positive = [_select_feature_acts(combo, lookup) for lookup in positive_lookups]
        x = x_negative + x_positive
        y = [-1] * len(x_negative) + [1] * len(x_positive)
        model, pred, score, accuracy = train_tree(x, y, depth=len(combo))
        return model, pred, score, accuracy, combo

    combos = FeatureMixer(features).grid(k_features_per_combo=k_features_per_combo)

    best_combo = None
    best_model = None
    best_score = 0
    best_accuracy = 0

    # Local pool size for tree fitting; independent of the module-level MAX_WORKERS.
    max_workers = 8
    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures_list = [executor.submit(_test_combo, combo) for combo in combos]

        # Read results in submission order so ties resolve to the earliest combination.
        for future in tqdm.tqdm(futures_list):
            model, _pred, score, accuracy, combo = future.result()
            if score > best_score:
                best_score = score
                best_combo = combo
                best_model = model
                best_accuracy = accuracy

    return best_combo, best_score, best_model, best_accuracy
# Best result per combination size, keyed by the number of features (1, 2, 3).
# Each value is (best_combo, best_score, best_model).
best_combo_at_k = {}
for i in range(3):
    best_combo, best_score, best_model, best_accuracy = find_best_combo(
        features_to_look_at,
        k_features_per_combo=i + 1,
    )
    best_combo_at_k[i + 1] = (best_combo, best_score, best_model)
    print(i + 1, best_combo, best_score, best_accuracy, best_model)
100%|██████████| 85/85 [00:00<00:00, 146.55it/s]
1 (Feature("Ending or terminating relationships, employment, or accounts"),) 0.6859903381642513 0.5827991452991453 DecisionTreeClassifier(max_depth=1, min_samples_leaf=7, random_state=42)
100%|██████████| 3570/3570 [02:01<00:00, 29.38it/s]
2 (Feature("Company announces new business deal or acquisition"), Feature("Evaluative language indicating positive or negative assessment")) 0.7169811320754716 0.703525641025641 DecisionTreeClassifier(max_depth=2, min_samples_leaf=7, random_state=42)
100%|██████████| 98770/98770 [45:25<00:00, 36.23it/s]
3 (Feature("Currency symbols and abbreviations"), Feature("Ending or terminating relationships, employment, or accounts"), Feature("Suppressed or internalized negative emotions")) 0.7272727272727273 0.688034188034188 DecisionTreeClassifier(max_depth=3, min_samples_leaf=7, random_state=42)
[ ]:
# Inspect features to understand their nuances better
# NOTE(review): despite its name, this is the first feature of the best
# 3-feature combination, not the winner of the 1-feature search (that would be
# best_combo_at_k[1][0][0]) - confirm which one is intended.
best_individual_feature = best_combo_at_k[3][0][0]
client.features._experimental.neighbors(best_individual_feature)
# Seems to be associated with more precise numbers! Corroborated by checking samples where this feature was particularly active.
FeatureGroup([
0: "Measuring precise liquid volumes",
1: "Numerical character attributes in games (10-100 range)",
2: "Units of measurement in scientific and technical contexts",
3: "Numerical values in personal budgeting and cost estimation contexts",
4: "Diverse units of measurement for physical attributes in natural science descriptions",
5: "Large numerical measurements and quantities",
6: "Large monetary amounts in thousands",
7: "Numerical values in text",
8: "Numerical values in financial and statistical contexts",
9: "Percentage symbol recognition"
])
[ ]:
# Anyway, let's look at the best overall tree.
# best_combo_at_k is keyed by the number of features per combination; each value
# is (best_combo, best_score, best_model), so index 0 is the feature tuple and
# index 2 is the fitted tree.
BEST_TREE_INDEX = 3
best_features = best_combo_at_k[BEST_TREE_INDEX][0]
best_tree = best_combo_at_k[BEST_TREE_INDEX][2]
[ ]:
# Render the selected tree with Graphviz.
import graphviz

# Node labels come from the feature labels, in the same order as the columns the
# tree was trained on. The class names are listed as negative then positive,
# matching the -1 / 1 training labels in ascending order.
dot_data = tree.export_graphviz(
    best_tree,
    out_file=None,
    feature_names=[feat.label for feat in best_features],
    class_names=['negative', 'positive'],
    filled=True,
    rounded=True,
    special_characters=True,
)
graph = graphviz.Source(dot_data)
graph