How to save sklearn tree plot as file (Vector Graphics)

The Scikit-Learn (sklearn) Python package has a nice function, sklearn.tree.plot_tree, for plotting (decision) trees. Its documentation is on the scikit-learn website.

However, the default plot just by using the command

tree.plot_tree(clf)

may be low resolution if you save it from an IDE such as Spyder.

The solution is to save the figure directly with matplotlib. First, import matplotlib.pyplot:

import matplotlib.pyplot as plt

Then, the following code saves the sklearn tree as an .eps file (change the file extension and the format argument for other formats):

plt.figure()
tree.plot_tree(clf, filled=True)
plt.savefig('tree.eps', format='eps', bbox_inches='tight')
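To save the same plot in other vector formats, only the file name and the format argument need to change. The sketch below is an illustration, not part of the original post: it fits the iris example tree itself so that it runs on its own, and the output names tree.pdf and tree.svg are arbitrary choices. It also assumes a non-interactive backend (Agg) so it works without a display.

```python
# Sketch: save one sklearn tree plot in several vector formats.
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit inside an IDE
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

# Fit a small example tree (same setup as the scikit-learn docs example).
iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

plt.figure()
tree.plot_tree(clf, filled=True)

# Write the current figure once per format; matplotlib picks the writer from `format`.
for fmt in ("eps", "pdf", "svg"):
    plt.savefig(f"tree.{fmt}", format=fmt, bbox_inches="tight")

plt.close()
```

EPS, PDF and SVG are all vector formats, so any of them keeps full resolution when zoomed; SVG is often the most convenient for web pages, PDF for LaTeX documents.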

Here, clf is your decision tree classifier, which must be created and fitted before the tree is plotted:

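# Example from https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html
from sklearn import tree
from sklearn.datasets import load_iris
iris = load_iris()  # load the iris dataset used in the example
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(iris.data, iris.target)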
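Putting the pieces together, a complete script might look like the sketch below. The feature_names and class_names arguments are optional additions (not in the original snippet) that label each node with the iris feature and species names; the Agg backend line is an assumption for running outside an IDE and can be dropped inside Spyder.

```python
# Sketch: fit the example classifier, plot it with labelled nodes, save as EPS.
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop inside an IDE
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

# Fit the classifier first, as in the scikit-learn example.
iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(iris.data, iris.target)

# plot_tree returns the list of annotation artists it drew.
plt.figure()
annotations = tree.plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
)

# Save as vector graphics; bbox_inches="tight" fits the saved area to the plot.
plt.savefig("tree.eps", format="eps", bbox_inches="tight")
plt.close()
```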

The outcome is a tree in a vector graphics format (.eps) that retains its full resolution when zoomed in. The bbox_inches="tight" argument prevents the image from being truncated: without it, parts of the sklearn tree are sometimes cropped off, leaving an incomplete plot.
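If a deeper tree still looks cramped, one option is to give it a larger figure and pass the axes to plot_tree through its ax parameter, while still saving with bbox_inches="tight". This is a sketch rather than the post's method: the 12 x 8 inch size and the file name tree_large.eps are arbitrary assumptions, and the Agg backend line is only there so it runs without a display.

```python
# Sketch: larger figure plus tight bounding box for a roomier tree plot.
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop inside an IDE
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# Arbitrary 12 x 8 inch canvas; tune it to your tree's depth and width.
fig, ax = plt.subplots(figsize=(12, 8))
tree.plot_tree(clf, filled=True, ax=ax)

# bbox_inches="tight" sets the saved area to the bounding box of the drawn
# artists (plus a small padding), so edge nodes are not cut off.
fig.savefig("tree_large.eps", format="eps", bbox_inches="tight")
plt.close(fig)
```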

Author: mathtuition88

Math and Education Blog
