The Scikit-Learn (sklearn) Python package has a convenient function, sklearn.tree.plot_tree, for plotting (decision) trees. Its documentation is at https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html.
However, the default plot produced by simply calling tree.plot_tree(clf) can come out low resolution if you save it from an IDE such as Spyder.
The solution is to first import matplotlib.pyplot:
import matplotlib.pyplot as plt
Then the following code saves the sklearn tree as .eps (to use another format, change the file extension and the format argument accordingly):
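plt.figure()
tree.plot_tree(clf, filled=True)  # filled=True colors each node by its majority class
plt.savefig('tree.eps', format='eps', bbox_inches="tight")  # vector output, trimmed to the drawn content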
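As a rough sketch of the "change the format accordingly" option, the loop below saves the same tree as EPS, PDF and SVG, all of which are vector formats that Matplotlib's savefig supports. It fits the iris example classifier itself so that it runs on its own; the temporary output directory and the non-interactive Agg backend are assumptions made only so the snippet works without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

# Fit the same example classifier used in the post
iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

out_dir = tempfile.mkdtemp()  # assumption: throwaway folder for the output files
saved = []
# eps, pdf and svg are all vector formats that savefig can write
for fmt in ("eps", "pdf", "svg"):
    plt.figure()
    tree.plot_tree(clf, filled=True)
    path = os.path.join(out_dir, "tree." + fmt)
    plt.savefig(path, format=fmt, bbox_inches="tight")
    plt.close()  # close the current figure before drawing the next one
    saved.append(path)

print(saved)
```

Closing each figure inside the loop keeps memory use flat if you export many trees in one session.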
To clarify, clf is your decision tree classifier, which must be defined and fitted before plotting the tree:
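# Example from https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html
from sklearn.datasets import load_iris
from sklearn import tree

iris = load_iris()  # load the iris dataset used in the documentation example
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(iris.data, iris.target)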
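Putting the pieces together, here is a minimal end-to-end sketch. It uses Matplotlib's object-oriented interface (an explicit figure and Axes, passed through plot_tree's ax parameter) instead of the pyplot "current figure" used above. save_tree_plot is a hypothetical helper name, and the figure size and the feature/class labels are illustrative choices, not something the original post requires.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside an IDE
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris


def save_tree_plot(clf, path, fmt="eps", figsize=(12, 8), **plot_kwargs):
    """Hypothetical helper: draw a fitted tree on its own figure and save it."""
    fig, ax = plt.subplots(figsize=figsize)  # figsize is an illustrative choice
    # Draw on this specific Axes rather than pyplot's "current" one
    tree.plot_tree(clf, ax=ax, **plot_kwargs)
    fig.savefig(path, format=fmt, bbox_inches="tight")
    plt.close(fig)  # release the figure once it has been written to disk
    return path


iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

out_path = save_tree_plot(
    clf,
    os.path.join(tempfile.mkdtemp(), "tree.eps"),
    filled=True,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),  # label leaves with species names
)
print(out_path)
```

Passing the Axes explicitly means savefig writes exactly the figure the tree was drawn on, even if other plots are open in the same IDE session.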
The outcome is a tree saved in a vector graphics format (.eps), so it keeps its full resolution when zoomed in. The bbox_inches="tight" argument prevents the image from being truncated: without it, parts of the sklearn tree are sometimes cropped off, leaving the plot incomplete.
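To see what bbox_inches="tight" changes, one rough check is to save the tree twice and compare the %%BoundingBox comment that Matplotlib writes into the header of each EPS file. Without the argument the box is the full figure size; with it, the box is fitted to the drawn content plus a small padding. eps_bounding_box is a hypothetical helper written for this sketch, and the exact numbers will depend on your Matplotlib version and figure size.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris


def eps_bounding_box(path):
    """Hypothetical helper: return (llx, lly, urx, ury) from an EPS %%BoundingBox line."""
    with open(path, "r", encoding="latin-1") as f:
        for line in f:
            if line.startswith("%%BoundingBox:"):
                return tuple(int(v) for v in line.split(":", 1)[1].split())
    raise ValueError("no %%BoundingBox line found in " + path)


iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
out_dir = tempfile.mkdtemp()

boxes = {}
# Save once with the default extent and once with the tight bounding box
for label, extra in (("default", {}), ("tight", {"bbox_inches": "tight"})):
    plt.figure()
    tree.plot_tree(clf, filled=True)
    path = os.path.join(out_dir, "tree_" + label + ".eps")
    plt.savefig(path, format="eps", **extra)
    plt.close()
    boxes[label] = eps_bounding_box(path)

print(boxes)
```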