Decision Trees with Features

Colab here

[ ]:
!pip install goodfire==0.2.11
Collecting goodfire==0.2.11
  Downloading goodfire-0.2.11-py3-none-any.whl.metadata (1.2 kB)
Requirement already satisfied: httpx<0.28.0,>=0.27.2 in /usr/local/lib/python3.10/dist-packages (from goodfire==0.2.11) (0.27.2)
Collecting ipywidgets<9.0.0,>=8.1.5 (from goodfire==0.2.11)
  Downloading ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB)
Requirement already satisfied: numpy<2.0.0,>=1.26.4 in /usr/local/lib/python3.10/dist-packages (from goodfire==0.2.11) (1.26.4)
Requirement already satisfied: pydantic<3.0.0,>=2.9.2 in /usr/local/lib/python3.10/dist-packages (from goodfire==0.2.11) (2.9.2)
Requirement already satisfied: anyio in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (3.7.1)
Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (2024.8.30)
Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (1.0.7)
Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (3.10)
Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (1.3.1)
Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (0.14.0)
Collecting comm>=0.1.3 (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11)
  Downloading comm-0.2.2-py3-none-any.whl.metadata (3.7 kB)
Requirement already satisfied: ipython>=6.1.0 in /usr/local/lib/python3.10/dist-packages (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (7.34.0)
Requirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.10/dist-packages (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (5.7.1)
Collecting widgetsnbextension~=4.0.12 (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11)
  Downloading widgetsnbextension-4.0.13-py3-none-any.whl.metadata (1.6 kB)
Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.10/dist-packages (from ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (3.0.13)
Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.9.2->goodfire==0.2.11) (0.7.0)
Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.9.2->goodfire==0.2.11) (2.23.4)
Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3.0.0,>=2.9.2->goodfire==0.2.11) (4.12.2)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (75.1.0)
Collecting jedi>=0.16 (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (4.4.2)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.7.5)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (3.0.48)
Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (2.18.0)
Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.2.0)
Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.1.7)
Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (4.9.0)
Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio->httpx<0.28.0,>=0.27.2->goodfire==0.2.11) (1.2.2)
Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.8.4)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.7.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=6.1.0->ipywidgets<9.0.0,>=8.1.5->goodfire==0.2.11) (0.2.13)
Downloading goodfire-0.2.11-py3-none-any.whl (27 kB)
Downloading ipywidgets-8.1.5-py3-none-any.whl (139 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 139.8/139.8 kB 6.6 MB/s eta 0:00:00
Downloading comm-0.2.2-py3-none-any.whl (7.2 kB)
Downloading widgetsnbextension-4.0.13-py3-none-any.whl (2.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.3/2.3 MB 10.4 MB/s eta 0:00:00
Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 14.0 MB/s eta 0:00:00

[ ]:
from google.colab import userdata

# Add your Goodfire API key to your Colab secrets
GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')

Quick start

Initialize the SDK

[ ]:
import goodfire

# Create the Goodfire API client, authenticated with GOODFIRE_API_KEY.
client = goodfire.Client(GOODFIRE_API_KEY)
[ ]:
import pandas as pd

# Parquet file of each dataset split, relative to the dataset root.
splits = {
    'train': 'data/train-00000-of-00001-ab9a3b4799b09589.parquet',
    'test': 'data/test-00000-of-00001-8bd1e21c671fb670.parquet',
    'valid': 'data/valid-00000-of-00001-303e4ba2afe838d4.parquet',
}

# Only the training split is loaded here.
df = pd.read_parquet(f"hf://datasets/ChanceFocus/en-fpb/{splits['train']}")
[ ]:
df

Extract features for each input

[ ]:
# Model variant passed as `model=` to the feature contrast and inspect calls.
variant = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")
[ ]:
# Shuffle every row once with a fixed seed so the slices taken later are reproducible.
shuffled_df = df.sample(frac=1, random_state=42)

# Partition the shuffled rows by the value of their 'answer' column.
positive_examples = shuffled_df.loc[shuffled_df['answer'] == 'positive']
negative_examples = shuffled_df.loc[shuffled_df['answer'] == 'negative']
neutral_examples = shuffled_df.loc[shuffled_df['answer'] == 'neutral']

# Rows per class fed to the feature contrast.
FEATURE_COMPUTE_SIZE = 100
# Upper bound on rows per class used when building the classifier data.
CLASSIFIER_FULL_SET_SIZE = 150
[ ]:
def _labelled_conversation(text, verdict):
  """Build one user/assistant exchange asking about `text` and answering with `verdict`."""
  return [
      {"role": "user", "content": f"Is the following good or bad news for investors? {text}"},
      {"role": "assistant", "content": verdict},
  ]


# Contrast the first FEATURE_COMPUTE_SIZE positive examples (answered "good")
# with the first FEATURE_COMPUTE_SIZE negative examples (answered "bad").
positive_news_features, negative_news_features = client.features.contrast(
    dataset_1=[
        _labelled_conversation(text, "good")
        for text in positive_examples[:FEATURE_COMPUTE_SIZE]['text'].tolist()
    ],
    dataset_2=[
        _labelled_conversation(text, "bad")
        for text in negative_examples[:FEATURE_COMPUTE_SIZE]['text'].tolist()
    ],
    dataset_1_feature_rerank_query="bull market",
    dataset_2_feature_rerank_query="bear market",
    model=variant,
    top_k=50,
)

# Combine both groups (`|`) into the candidate pool for the classifier search.
features_to_look_at = positive_news_features | negative_news_features
features_to_look_at
[ ]:
from itertools import combinations


class FeatureMixer:
  """Enumerates fixed-size combinations drawn from a group of features."""

  def __init__(self, feature_group):
    # Stored as given; it is only ever handed to itertools.combinations.
    self.feature_group = feature_group

  def grid(self, k_features_per_combo: int = 2):
    """Return every combination of `k_features_per_combo` features as a list of tuples.

    Combinations are produced in the order `itertools.combinations` yields
    them, i.e. following the iteration order of the feature group.
    """
    every_combo = combinations(self.feature_group, k_features_per_combo)
    return [*every_combo]

[ ]:
import pandas as pd
import concurrent.futures as futures
import tqdm


# Rows taken from each class: the size of the smallest class, capped at
# CLASSIFIER_FULL_SET_SIZE.
MIN_SAMPLES_PER_CLASS = min(
    CLASSIFIER_FULL_SET_SIZE,
    *map(len, (negative_examples, positive_examples, neutral_examples)),
)
# Thread count for the pool that issues feature-inspection requests.
MAX_WORKERS = 3

def _get_feature_acts_for_sample_class(
    sample_class: pd.DataFrame,
    features_to_use_for_classification: goodfire.FeatureGroup,
    k=100,
):
  if k < len(features_to_use_for_classification):
    raise ValueError("k must be greater than the number of features to use for classification")

  samples = []
  with futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures_list = []
    for idx, row in sample_class[0:MIN_SAMPLES_PER_CLASS].iterrows():
      text = row['text']
      futures_list.append(
          executor.submit(
            client.features.inspect,
            [
                {
                    "role": "user",
                    "content": f"is the following good or bad for investors? {text}"
                }
            ],
            model=variant,
            features=features_to_use_for_classification,
        )
      )

    for future in tqdm.tqdm(futures_list):
      context = future.result()

      features = context.top(k=k)
      samples.append(features)

  return samples


def _announce_and_compute(message, examples):
  """Print `message`, then fetch feature activations for `examples` with k=100."""
  print(message)
  return _get_feature_acts_for_sample_class(examples, features_to_look_at, k=100)


positive_class_features = _announce_and_compute("Computing positive news features...", positive_examples)

negative_class_features = _announce_and_compute("Computing negative news features...", negative_examples)

neutral_class_features = _announce_and_compute("Computing neutral news features...", neutral_examples)

[ ]:
from sklearn import tree
from sklearn import svm
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score
import numpy as np
import tqdm


# Grid search may take a while; curating the feature list beforehand speeds it up significantly.

def train_tree(x, y, depth):
  """Fit a depth-limited decision tree on half of the data and score it on the rest.

  Args:
    x: Feature rows, one per sample.
    y: Labels aligned with `x`. `f1_score` is called with its defaults, so this
      presumably expects binary labels with 1 as the positive class (the caller
      in this notebook passes -1 / 1).
    depth: Maximum depth of the tree.

  Returns:
    A `(model, pred, score, accuracy)` tuple: the fitted classifier, its
    predictions on the held-out half, the F1 score and the balanced accuracy
    on that half.
  """
  train_x, test_x, train_y, test_y = train_test_split(x, y, train_size=0.5, random_state=42)

  # Regularize by requiring each leaf to hold about 5% of the training rows.
  # Clamp to 1: with fewer than 20 training rows the integer division yields 0,
  # which scikit-learn rejects as a min_samples_leaf value.
  min_leaf = max(1, len(train_x) // 20)

  model = tree.DecisionTreeClassifier(
      max_depth=depth,
      min_samples_leaf=min_leaf,
      random_state=42
  )

  model.fit(train_x, train_y)

  pred = model.predict(test_x)

  # Score on the held-out half: balanced accuracy and F1.
  accuracy = balanced_accuracy_score(test_y, pred)
  score = f1_score(test_y, pred)

  return model, pred, score, accuracy


def find_best_combo(features, k_features_per_combo = 2):
  """Search all feature combinations of a given size for the best decision tree.

  For every combination produced by `FeatureMixer(features).grid(...)`, a tree
  of depth `len(combo)` is trained with `train_tree` on the activations found
  in the module-level `negative_class_features` (label -1) and
  `positive_class_features` (label 1). Neutral examples are deliberately left
  out: this is a binary negative-vs-positive task.

  Args:
    features: Iterable of features to combine; each must expose `.uuid`.
    k_features_per_combo: Number of features in each combination.

  Returns:
    A `(best_combo, best_score, best_model, best_accuracy)` tuple for the
    combination with the highest F1 score. Ties keep the earliest combination.
    If no combination scores above 0 the result is `(None, 0, None, 0)`.
  """
  combos = FeatureMixer(features).grid(k_features_per_combo=k_features_per_combo)

  # Local pool size for this search; shadows the module-level MAX_WORKERS.
  MAX_WORKERS = 8

  def _index_by_uuid(row):
    """Map feature uuid -> activation for one sample (assumes uuids are hashable)."""
    lookup = {}
    for feature_act in row:
      # setdefault keeps the first activation seen for a uuid, which is what a
      # first-match linear scan over the row would return.
      lookup.setdefault(feature_act.feature.uuid, feature_act.activation)
    return lookup

  # Index every sample once, instead of re-scanning each row for every
  # feature of every combination.
  all_rows = (
      [_index_by_uuid(row) for row in negative_class_features]
      + [_index_by_uuid(row) for row in positive_class_features]
  )
  y = [-1] * len(negative_class_features) + [1] * len(positive_class_features)

  def _test_combo(combo):
    """Train and score a decision tree that uses only the features in `combo`."""
    uuids = [feature.uuid for feature in combo]
    # A feature missing from a sample is filled with 0.0 so that every row has
    # the same width. This assumes "absent" means "did not activate"; confirm
    # against the SDK.
    x = [[row.get(uuid, 0.0) for uuid in uuids] for row in all_rows]

    model, pred, score, accuracy = train_tree(x, y, depth=len(combo))

    return model, pred, score, accuracy, combo

  best_combo = None
  best_model = None
  best_score = 0
  best_accuracy = 0

  with futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    futures_list = [executor.submit(_test_combo, combo) for combo in combos]

    # Results are read in submission order, so on equal scores the earliest
    # combination wins.
    for future in tqdm.tqdm(futures_list):
      model, _pred, score, accuracy, combo = future.result()

      if score > best_score:
        best_score = score
        best_combo = combo
        best_model = model
        best_accuracy = accuracy

  return best_combo, best_score, best_model, best_accuracy


# Run the search for combination sizes 1 through 3 and keep each size's winner.
best_combo_at_k = {}
for combo_size in range(1, 4):
  best_combo, best_score, best_model, best_accuracy = find_best_combo(
      features_to_look_at,
      k_features_per_combo=combo_size,
  )
  print(combo_size, best_combo, best_score, best_accuracy, best_model)
  # Stored as (combo, score, model); the accuracy is only printed.
  best_combo_at_k[combo_size] = (best_combo, best_score, best_model)

100%|██████████| 85/85 [00:00<00:00, 146.55it/s]
1 (Feature("Ending or terminating relationships, employment, or accounts"),) 0.6859903381642513 0.5827991452991453 DecisionTreeClassifier(max_depth=1, min_samples_leaf=7, random_state=42)
100%|██████████| 3570/3570 [02:01<00:00, 29.38it/s]
2 (Feature("Company announces new business deal or acquisition"), Feature("Evaluative language indicating positive or negative assessment")) 0.7169811320754716 0.703525641025641 DecisionTreeClassifier(max_depth=2, min_samples_leaf=7, random_state=42)
100%|██████████| 98770/98770 [45:25<00:00, 36.23it/s]
3 (Feature("Currency symbols and abbreviations"), Feature("Ending or terminating relationships, employment, or accounts"), Feature("Suppressed or internalized negative emotions")) 0.7272727272727273 0.688034188034188 DecisionTreeClassifier(max_depth=3, min_samples_leaf=7, random_state=42)
[ ]:
# Inspect features to understand their nuances better

# NOTE: despite the name, this is the first feature of the best 3-feature
# combination (best_combo_at_k[3]), not the winner of the single-feature
# search (best_combo_at_k[1]).
best_individual_feature = best_combo_at_k[3][0][0]

client.features._experimental.neighbors(best_individual_feature)

# Seems to be associated with more precise numbers! Corroborated by checking samples where this feature was particularly active.
FeatureGroup([
   0: "Measuring precise liquid volumes",
   1: "Numerical character attributes in games (10-100 range)",
   2: "Units of measurement in scientific and technical contexts",
   3: "Numerical values in personal budgeting and cost estimation contexts",
   4: "Diverse units of measurement for physical attributes in natural science descriptions",
   5: "Large numerical measurements and quantities",
   6: "Large monetary amounts in thousands",
   7: "Numerical values in text",
   8: "Numerical values in financial and statistical contexts",
   9: "Percentage symbol recognition"
])
[ ]:
# Pick the tree to examine: the winner of the 3-feature search.

BEST_TREE_INDEX = 3
# Each entry of best_combo_at_k is a (combo, score, model) tuple.
_best_entry = best_combo_at_k[BEST_TREE_INDEX]
best_features = _best_entry[0]
best_tree = _best_entry[2]
[ ]:
# Let's visualize the tree

import graphviz


# Export the tree as DOT source, labelling split nodes with the feature labels.
# The class names follow the training labels used above: -1 (negative), 1 (positive).
dot_data = tree.export_graphviz(
    best_tree,
    out_file=None,
    feature_names=[feature.label for feature in best_features],
    class_names=['negative', 'positive'],
    filled=True,
    rounded=True,
    special_characters=True,
)
graph = graphviz.Source(dot_data)
graph

../_images/examples_decision_trees_18_0.svg