Creating a decision tree

In this unit, you'll use the TF-DF (TensorFlow Decision Forests) library to train, tune, and interpret a decision tree.


Before studying the dataset, do the following:

  1. Create a new Colab notebook.
  2. Install the TensorFlow Decision Forests library by placing the following line of code in your new Colab notebook:
    pip install tensorflow_decision_forests
  3. Import the following libraries:
    import numpy as np
    import pandas as pd
    import tensorflow_decision_forests as tfdf

The Palmer Penguins dataset

This Colab uses the Palmer Penguins dataset, which contains size measurements for three penguin species:

  • Chinstrap
  • Gentoo
  • Adelie

This is a classification problem: the goal is to predict the species of a penguin from the other features in the Palmer Penguins dataset.

Let’s meet the penguins.


Figure 16. Three different penguin species. Image by @allisonhorst


The following code calls a pandas function to load the Palmer Penguins dataset into memory:

# path: location of the Palmer Penguins CSV file.
path = ""
pandas_dataset = pd.read_csv(path)

# Display the first 3 examples.
pandas_dataset.head(3)

The following table formats the first 3 examples in the Palmer Penguins dataset:

Table 3. The first 3 examples in Palmer Penguins

   species  island     bill_length_mm  bill_depth_mm  flipper_length_mm  body_mass_g  sex     year
0  Adelie   Torgersen  39.1            18.7           181.0              3750.0       male    2007
1  Adelie   Torgersen  39.5            17.4           186.0              3800.0       female  2007
2  Adelie   Torgersen  40.3            18.0           195.0              3250.0       female  2007

The full dataset contains a mix of numerical features (for example, bill_depth_mm), categorical features (for example, island), and missing values. Unlike neural networks, decision forests support all these feature types natively, so you don't have to apply one-hot encoding or normalization, or add an extra is_present feature.

To make the results easier to interpret, the following code cell manually converts the penguin species into integer labels:

label = "species"

classes = list(pandas_dataset[label].unique())
print(f"Label classes: {classes}")
# >> Label classes: ['Adelie', 'Gentoo', 'Chinstrap']

pandas_dataset[label] = pandas_dataset[label].map(classes.index)
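Because each integer label is simply a position in classes, the same list can translate integer labels (for example, model predictions) back into species names. A minimal plain-Python sketch of that round trip, assuming the class order printed above (unique() returns species in the order they first appear in the CSV):

```python
# Class order as printed above.
classes = ["Adelie", "Gentoo", "Chinstrap"]

species = ["Adelie", "Chinstrap", "Gentoo", "Adelie"]

# Encode: the same mapping as pandas_dataset[label].map(classes.index).
encoded = [classes.index(s) for s in species]
print(encoded)  # [0, 2, 1, 0]

# Decode: index back into classes to recover the species names.
decoded = [classes[i] for i in encoded]
print(decoded)  # ['Adelie', 'Chinstrap', 'Gentoo', 'Adelie']
```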

The following code cell splits the dataset into a training set and testing set:

# Use ~10% of the examples as the testing set
# and the remaining ~90% as the training set.
test_indices = np.random.rand(len(pandas_dataset)) < 0.1
pandas_train_dataset = pandas_dataset[~test_indices]
pandas_test_dataset = pandas_dataset[test_indices]

print("Training examples: ", len(pandas_train_dataset))
# >> Training examples: 309

print("Testing examples: ", len(pandas_test_dataset))
# >> Testing examples: 35
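Because np.random.rand is called without a seed, the exact counts (309 and 35 above) vary from run to run. The following minimal plain-Python sketch performs the same ~10% random split with a fixed seed so that it is reproducible; the split_indices helper and the seed value are illustrative assumptions, not part of the original code. In NumPy, drawing from np.random.default_rng(seed).random(n) instead of np.random.rand(n) has the same effect.

```python
import random

def split_indices(num_examples, test_ratio=0.1, seed=42):
    """Hypothetical helper: send each example to the test set with probability test_ratio."""
    rng = random.Random(seed)  # seeded generator, so every run gives the same split
    train, test = [], []
    for i in range(num_examples):
        if rng.random() < test_ratio:
            test.append(i)
        else:
            train.append(i)
    return train, test

# 344 = 309 + 35, the total number of examples reported above.
train_idx, test_idx = split_indices(344)
print("Training examples:", len(train_idx))
print("Testing examples:", len(test_idx))
```

Calling split_indices(344) again with the same seed returns exactly the same two index lists.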

Training a model with default hyperparameters

We can train our first CART (Classification and Regression Trees) model without specifying any hyperparameters, because the tfdf.keras.CartModel class provides good default hyperparameter values. You will learn more about how this type of model works later in the course.

tf_train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pandas_train_dataset, label=label)
model = tfdf.keras.CartModel()
model.fit(tf_train_dataset)

The preceding call to tfdf.keras.CartModel did not specify columns to use as input features. Therefore, every column in the training set is used. The call also did not specify the semantics (for example, numerical, categorical, text) of the input features. Therefore, tfdf.keras.CartModel will automatically infer the semantics.
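TF-DF performs this inference internally, and its exact rules are not described here. As a rough illustration only, the hypothetical infer_semantic function below applies a simplified heuristic to columns from Table 3: a column whose non-missing values are all numbers is treated as numerical, and anything else as categorical. Missing values (None) are skipped rather than imputed.

```python
def infer_semantic(values):
    """Hypothetical, simplified rule; TF-DF's real inference logic differs."""
    present = [v for v in values if v is not None]  # skip missing values
    if present and all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in present):
        return "NUMERICAL"
    return "CATEGORICAL"

# The first three values come from Table 3; the trailing None in each column
# is a hypothetical missing value added for illustration.
columns = {
    "island": ["Torgersen", "Torgersen", "Torgersen", None],
    "bill_length_mm": [39.1, 39.5, 40.3, None],
    "body_mass_g": [3750.0, 3800.0, 3250.0, None],
    "sex": ["male", "female", "female", None],
    "year": [2007, 2007, 2007, None],
}

for name, values in columns.items():
    print(f"{name}: {infer_semantic(values)}")
```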

Call tfdf.model_plotter.plot_model_in_colab to display the resulting decision tree:

tfdf.model_plotter.plot_model_in_colab(model, max_depth=10)

In Colab, you can use the mouse to display details about specific elements.


Figure 17. Visualization of a decision tree in Colab.

Colab shows that the root condition evaluated 277 examples, even though the training dataset contains 309 examples. The remaining 32 examples were held out as a validation set, which the model uses to prune the tree.
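The counts imply that roughly 10% of the training set was held out. The short calculation below confirms the arithmetic; the implied ratio of about 0.1 is inferred from these counts, not taken from documented defaults.

```python
training_examples = 309  # rows in pandas_train_dataset (from the split above)
root_examples = 277      # examples evaluated by the root condition (Figure 17)

held_out = training_examples - root_examples
held_out_ratio = held_out / training_examples

print(held_out)                  # 32
print(round(held_out_ratio, 3))  # 0.104
```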

The first condition tests the value of bill_depth_mm. Tables 4 and 5 show the likelihood of different species depending on the outcome of the first condition.


Table 4. Likelihood of different species if bill_depth_mm ≥ 16.35

Species            Likelihood
Adelie (red)       62%
Chinstrap (green)  33%
Gentoo (blue)      4%


Table 5. Likelihood of different species if bill_depth_mm < 16.35

Species            Likelihood
Adelie (red)       1%
Chinstrap (green)  1%
Gentoo (blue)      97%

bill_depth_mm is a numerical feature. Therefore, the value 16.35 was found using the exact splitting algorithm for binary classification with numerical features.

If bill_depth_mm < 16.35 is True, a further test of body_mass_g ≥ 4175 perfectly separates 86 Gentoos from a group of 5 Gentoos and Adelies.
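The exact splitting algorithm is covered later in the course. As a rough preview, the following sketch sorts the examples by one numerical feature, tries a threshold halfway between every pair of consecutive distinct values, scores each candidate by information gain (entropy reduction), and then reports the class proportions on each side of the winning threshold. Those proportions are the kind of likelihoods shown in Tables 4 and 5. This illustrates the idea rather than TF-DF's implementation, which may differ in details. The helper names are hypothetical, and the toy bill depths were chosen so that the best threshold lands at 16.35; they are not taken from the dataset.

```python
from collections import Counter
from math import log2

def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def best_threshold(values, labels):
    """Exhaustive threshold search on one numerical feature.

    Candidates are midpoints between consecutive distinct sorted values.
    Returns (threshold, information_gain) for the condition value >= threshold.
    """
    pairs = sorted(zip(values, labels))
    parent = entropy(labels)
    n = len(pairs)
    best = (None, 0.0)
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # no threshold separates identical values
        threshold = (lo + hi) / 2
        left = [y for _, y in pairs[:i]]   # value < threshold
        right = [y for _, y in pairs[i:]]  # value >= threshold
        child = (len(left) * entropy(left) + len(right) * entropy(right)) / n
        gain = parent - child
        if gain > best[1]:
            best = (threshold, gain)
    return best

def branch_distribution(values, labels, threshold):
    """Fraction of each class on each side of the condition value >= threshold."""
    sides = {True: Counter(), False: Counter()}
    for v, y in zip(values, labels):
        sides[v >= threshold][y] += 1
    return {side: {y: c / sum(cnt.values()) for y, c in cnt.items()}
            for side, cnt in sides.items()}

# Hypothetical toy data: bill depths (mm) and species.
depths = [13.2, 14.5, 15.0, 16.3, 16.4, 17.8, 18.7, 19.1]
species = ["Gentoo", "Gentoo", "Gentoo", "Gentoo", "Adelie", "Chinstrap", "Adelie", "Adelie"]

threshold, gain = best_threshold(depths, species)
print(f"best threshold: {threshold:.2f}, information gain: {gain:.3f}")
print(branch_distribution(depths, species, threshold))
```

On this toy data, the False side of the condition contains only Gentoos, which mirrors the Gentoo-dominated branch in Table 5.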

The following code provides the training and test accuracy of the model:

model.compile(metrics=["accuracy"])
print("Train evaluation: ", model.evaluate(tf_train_dataset, return_dict=True))
# >> Train evaluation:  {'loss': 0.0, 'accuracy': 0.96116}

tf_test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(pandas_test_dataset, label=label)
print("Test evaluation: ", model.evaluate(tf_test_dataset, return_dict=True))
# >> Test evaluation:  {'loss': 0.0, 'accuracy': 0.97142}

It is rare, but possible, for the test accuracy to be higher than the training accuracy. One possible cause is that the test set differs from the training set, but that is unlikely here because the train/test split was random. A more likely explanation is that the test set is very small (only 35 examples), so the accuracy estimate is noisy.
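The reported accuracies correspond to whole numbers of correctly classified examples: 0.96116 ≈ 297/309 and 0.97142 ≈ 34/35. On the test set, a single example therefore moves the accuracy by 1/35 ≈ 0.029. The sketch below also computes a rough binomial standard error, sqrt(p(1 - p)/n). Treating the examples as independent draws is a simplifying assumption.

```python
from math import sqrt

def accuracy_and_stderr(correct, total):
    """Accuracy plus a rough binomial standard error (assumes independent examples)."""
    p = correct / total
    return p, sqrt(p * (1 - p) / total)

train_acc, train_se = accuracy_and_stderr(297, 309)
test_acc, test_se = accuracy_and_stderr(34, 35)

print(f"train: {train_acc:.5f} +/- {train_se:.3f}")
print(f"test:  {test_acc:.5f} +/- {test_se:.3f}")
print(f"one test example = {1 / 35:.3f} accuracy")
```

The test standard error (about 0.028) is almost three times the gap of about 0.01 between the two accuracies, so that gap is well within the noise.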

To train a better model, optimize the hyperparameters with the Keras tuner.

Important note: Don't tune hyperparameters on the test dataset. Instead, tune hyperparameters on a separate validation dataset (on large datasets) or using cross-validation (on small datasets). For simplicity, this course uses the test dataset for tuning.
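For small datasets, k-fold cross-validation replaces the single held-out set. The training examples are divided into k folds, each candidate is trained k times with a different fold held out for validation each time, and the candidate's score is the average of the k validation scores. A minimal sketch of the fold bookkeeping in plain Python follows. k_fold_indices is a hypothetical helper, and the train/evaluate step is left as a comment rather than tied to a particular TF-DF call.

```python
import random

def k_fold_indices(num_examples, k=5, seed=0):
    """Hypothetical helper: yield (train_indices, validation_indices) for each fold."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)       # shuffle once, reproducibly
    folds = [indices[i::k] for i in range(k)]  # k nearly equal-sized folds
    for i in range(k):
        validation = folds[i]
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, validation

for train, validation in k_fold_indices(309, k=5):
    # Train a candidate model on `train`, evaluate it on `validation`,
    # then average the k validation scores for that candidate.
    print(len(train), len(validation))
```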

The following code optimizes two hyperparameters:

  • the minimum number of examples in a condition node (min_examples)
  • the ratio of the training dataset used for pruning validation (validation_ratio)

Since we don't know the best values for these parameters, we provide various possibilities for the tuner to try. We've picked four possible values for min_examples and three for validation_ratio. Increasing the number of candidate hyperparameter values increases the chance of training a better model, but it also increases the training time.
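Four values of min_examples and three values of validation_ratio give 4 × 3 = 12 combinations. The short sketch below enumerates them with itertools.product. A random search samples from this space, so with a trial budget smaller than 12, some combinations may never be tried.

```python
from itertools import product

min_examples_values = [1, 2, 5, 10]
validation_ratio_values = [0.0, 0.05, 0.10]

grid = list(product(min_examples_values, validation_ratio_values))
print(len(grid))  # 12
for min_examples, validation_ratio in grid[:3]:
    print(min_examples, validation_ratio)
```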

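pip install keras-tuner
import keras_tuner as kt

def build_model(hp):
  model = tfdf.keras.CartModel(
      min_examples=hp.Choice("min_examples",
          # Try four possible values for "min_examples" hyperparameter.
          # min_examples=10 would limit the growth of the decision tree,
          # while min_examples=1 would lead to deeper decision trees.
          [1, 2, 5, 10]),
      validation_ratio=hp.Choice("validation_ratio",
          # Three possible values for the "validation_ratio" hyperparameter.
          [0.0, 0.05, 0.10]),
      )
  # Track accuracy so the tuner can rank trials by "val_accuracy".
  model.compile(metrics=["accuracy"])
  return model

tuner = kt.RandomSearch(
    build_model,
    objective="val_accuracy",
    max_trials=10,           # illustrative trial budget (the grid has 12 combinations)
    directory="/tmp/tuner",  # illustrative location for the tuner's trial results
    project_name="tune_cart")

# Train each trial on the training set and score it on tf_test_dataset.
tuner.search(x=tf_train_dataset, validation_data=tf_test_dataset)
best_model = tuner.get_best_models()[0]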

print("Best hyperparameters: ", tuner.get_best_hyperparameters()[0].values)
# >> Best hyperparameters:  {'min_examples': 2, 'validation_ratio': 0.0}

The candidate hyperparameter values (for example, [1, 2, 5, 10] for min_examples) depend on your understanding of the dataset and on the computing resources available.

The following code retrains and evaluates the model using those optimized hyperparameters:

model = tfdf.keras.CartModel(min_examples=2, validation_ratio=0.0)
model.fit(tf_train_dataset)

model.compile(metrics=["accuracy"])
print("Test evaluation: ", model.evaluate(tf_test_dataset, return_dict=True))
# >> Test evaluation:  {'loss': 0.0, 'accuracy': 1.0}

An accuracy of 1.0 means the model correctly classifies every example in this small test set. Because the same test set was also used for tuning, this score is likely optimistic.

The following code plots the new decision tree:

tfdf.model_plotter.plot_model_in_colab(model, max_depth=10)


Figure 18. A decision tree with six levels of nodes.


As expected from the new hyperparameter values, this decision tree is deeper than before because:

  • The minimum number of examples was reduced (from 5 to 2).
  • Validation pruning was disabled (validation_ratio=0.0) leading to more available training examples and no pruning.
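To see why a smaller min_examples produces a deeper tree, here is a toy recursive tree builder on a single numerical feature. It assumes a simplified, hypothetical stopping rule: a split is allowed only if both children keep at least min_examples examples and Gini impurity decreases. Splits are chosen greedily by weighted Gini impurity. These are illustrative choices, not TF-DF's exact behavior, and the data is hypothetical.

```python
from collections import Counter

def impurity(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def tree_depth(values, labels, min_examples):
    """Depth of a greedily grown tree (a single leaf has depth 0).

    Hypothetical stopping rule: split only if both children keep at least
    `min_examples` examples and the weighted impurity decreases.
    """
    pairs = sorted(zip(values, labels))
    n = len(pairs)
    best_i, best_score = None, impurity(labels)
    for i in range(min_examples, n - min_examples + 1):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # cannot split between identical values
        left = [y for _, y in pairs[:i]]
        right = [y for _, y in pairs[i:]]
        score = (len(left) * impurity(left) + len(right) * impurity(right)) / n
        if score < best_score:
            best_i, best_score = i, score
    if best_i is None:
        return 0  # leaf: no allowed split improves purity
    left, right = pairs[:best_i], pairs[best_i:]
    return 1 + max(
        tree_depth([v for v, _ in left], [y for _, y in left], min_examples),
        tree_depth([v for v, _ in right], [y for _, y in right], min_examples),
    )

# Hypothetical toy data with mixed labels, so pure leaves need several splits.
values = list(range(8))
labels = ["A", "A", "A", "B", "A", "B", "B", "B"]

for m in [1, 2, 5, 10]:
    print(f"min_examples={m}: depth {tree_depth(values, labels, m)}")
```

On this toy data, lowering min_examples from 2 to 1 adds one level, and values of 5 or more leave only a single leaf because 8 examples cannot be split into two children of at least 5.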

Usage and limitations

As mentioned earlier, a single decision tree often has lower quality than modern machine learning methods like random forests, gradient boosted trees, and neural networks. However, decision trees are still useful in the following cases:

  • As a simple and inexpensive baseline to evaluate more complex approaches.
  • When interpretability matters more than model quality.
  • As a proxy for interpreting decision forest models, which the course explores later.