Decision tree visual example

A decision tree is one of the many machine learning algorithms, and it can be visualized.
It’s used as a classifier: given input data, is it class A or class B? In this lecture we will visualize a decision tree using the Python module pydotplus together with the graph drawing tool Graphviz.

Related course: Data Science and Machine Learning with Python – Hands On!

Whether you want to do decision tree analysis, understand the decision tree algorithm or model, or just need a decision tree maker, you’ll need to visualize the decision tree.


Install
You need to install pydotplus and Graphviz, plus scikit-learn for the decision tree itself. Graphviz is usually installed with your system package manager, and the Python modules with pip.
Graphviz is a tool for drawing graphs described in dot files. pydotplus is a Python interface to Graphviz’s DOT language.
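pydotplus only builds the graph description. Writing an image such as a PNG runs the Graphviz dot program, so a missing Graphviz install usually appears only at the very last step. Here is a minimal sketch of a pre-flight check, assuming that dot must be findable on your PATH. The function name check_visualization_tools is hypothetical and only for illustration.

```python
import importlib.util
import shutil


def check_visualization_tools():
    """Report whether the tools used in this tutorial appear to be available."""
    return {
        # Python packages, normally installed with pip
        "pydotplus": importlib.util.find_spec("pydotplus") is not None,
        "sklearn": importlib.util.find_spec("sklearn") is not None,
        # Graphviz program that pydotplus runs to render images;
        # assumed here to need to be on the PATH
        "graphviz dot": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for name, found in check_visualization_tools().items():
        print(f"{name}: {'found' if found else 'missing'}")
```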

Data Collection
We start with the data collection. Let’s build a decision tree that tells a man from a woman: given the input features height, hair length and voice pitch, it predicts whether a person is a man or a woman.

The training data consists of five people. Each is described by height, hair length and voice pitch (0 or 1) and labeled man or woman.

In code that looks like:

import collections

import pydotplus
from sklearn import tree

# Data Collection: each row is [height, hair length, voice pitch]
X = [[180, 15, 0],
     [177, 42, 0],
     [136, 35, 1],
     [174, 65, 0],
     [141, 28, 1]]

# Class label for each row of X
Y = ['man', 'woman', 'woman', 'man', 'woman']

data_feature_names = ['height', 'hair length', 'voice pitch']
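The rendered tree shows a gini value in every node. By default scikit-learn’s DecisionTreeClassifier picks the splits that most reduce Gini impurity (criterion='gini'). The sketch below computes that number by hand for the training data above, both for the root and for one candidate split on voice pitch at threshold 0.5. The helpers gini and split_impurity are hypothetical illustrations, not scikit-learn functions.

```python
from collections import Counter

# Training data from above: [height, hair length, voice pitch] and labels
X = [[180, 15, 0], [177, 42, 0], [136, 35, 1], [174, 65, 0], [141, 28, 1]]
Y = ['man', 'woman', 'woman', 'man', 'woman']


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    if not labels:
        return 0.0
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def split_impurity(X, Y, feature, threshold):
    """Weighted Gini impurity after splitting on row[feature] <= threshold."""
    left = [y for row, y in zip(X, Y) if row[feature] <= threshold]
    right = [y for row, y in zip(X, Y) if row[feature] > threshold]
    n = len(Y)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


print(round(gini(Y), 3))                       # 0.48  (root: 2 men, 3 women)
print(round(split_impurity(X, Y, 2, 0.5), 3))  # 0.267 (split on voice pitch)
```

A split is chosen when its weighted impurity is lower than the parent’s. Here it drops from 0.48 to about 0.267.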

Train Classifier
The next step is to train the classifier (the decision tree) on the training data.
Supervised learning algorithms always need this training step before they can classify new data.

# Training
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
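Before drawing anything with Graphviz, you can check what the tree learned in plain text. This is a minimal sketch, assuming a scikit-learn version that provides sklearn.tree.export_text (0.21 or later). random_state=0 is added only to make the run repeatable, because two splits tie on this tiny dataset. The sample [160, 40, 1] is made up for illustration.

```python
from sklearn import tree

X = [[180, 15, 0], [177, 42, 0], [136, 35, 1], [174, 65, 0], [141, 28, 1]]
Y = ['man', 'woman', 'woman', 'man', 'woman']
data_feature_names = ['height', 'hair length', 'voice pitch']

# random_state is an addition for reproducibility (ties between splits)
clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, Y)

# A fully grown tree separates these five distinct samples perfectly
print(clf.score(X, Y))  # 1.0

# Text view of the learned rules; no Graphviz needed
rules = tree.export_text(clf, feature_names=data_feature_names)
print(rules)

# Classify a new, made-up person: [height, hair length, voice pitch]
print(clf.predict([[160, 40, 1]]))
```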

Decision Tree Visualization
We then visualize the tree with the following code:

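# Visualize data: export the fitted tree as a DOT-language string
dot_data = tree.export_graphviz(clf,
                                feature_names=data_feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# Fill colors for the first (left) and second (right) child of each split
colors = ('turquoise', 'orange')
edges = collections.defaultdict(list)

# Group the child node ids by their parent node id
for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

# Each split node has exactly two children; sort them so the
# lower id (the left child) gets the first color
for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

# Render the graph with Graphviz and save it as a PNG image
graph.write_png('tree.png')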
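The recoloring loop is easier to follow without Graphviz in the way. The sketch below replays the same group-by-parent and sort steps on plain (parent, child) pairs. child_colors is a hypothetical helper, and the edge list describes a made-up tree, not the one trained above.

```python
from collections import defaultdict


def child_colors(edge_list, colors=('turquoise', 'orange')):
    """Map each child node id to a fill color, mirroring the loop above.

    Children of each parent are sorted so the lower id, which is the left
    ("True") branch in scikit-learn's export, receives colors[0].
    """
    children = defaultdict(list)
    for parent, child in edge_list:
        children[parent].append(int(child))
    fill = {}
    for parent, kids in children.items():
        # zip stops at the shorter sequence, so it never indexes past the end
        for color, kid in zip(colors, sorted(kids)):
            fill[kid] = color
    return fill


# Made-up tree: node 0 splits into 1 and 4, node 1 splits into 2 and 3
edges = [('0', '1'), ('0', '4'), ('1', '2'), ('1', '3')]
print(child_colors(edges))
# {1: 'turquoise', 4: 'orange', 2: 'turquoise', 3: 'orange'}
```

The root never appears as a child, so it keeps the fill color that export_graphviz assigned.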

This will save the visualization to the image tree.png, which looks like this:

[Image: the generated decision tree, tree.png]
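If installing Graphviz is a problem, scikit-learn 0.21 and later can also draw the tree with matplotlib through sklearn.tree.plot_tree. This is a minimal sketch, assuming matplotlib is installed. The output file name tree_matplotlib.png is made up for this example. The node colors follow scikit-learn’s own scheme, not the turquoise/orange recoloring above.

```python
import matplotlib

matplotlib.use("Agg")  # render to a file without needing a display
import matplotlib.pyplot as plt
from sklearn import tree

X = [[180, 15, 0], [177, 42, 0], [136, 35, 1], [174, 65, 0], [141, 28, 1]]
Y = ['man', 'woman', 'woman', 'man', 'woman']
data_feature_names = ['height', 'hair length', 'voice pitch']

clf = tree.DecisionTreeClassifier(random_state=0).fit(X, Y)

fig, ax = plt.subplots(figsize=(8, 6))
# class_names must follow the order of clf.classes_
tree.plot_tree(clf,
               feature_names=data_feature_names,
               class_names=list(clf.classes_),
               filled=True,
               rounded=True,
               ax=ax)
fig.savefig("tree_matplotlib.png")  # hypothetical output file name
plt.close(fig)
```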

If you want to make predictions, check out the decision tree article.
