A neural network can be applied to classification problems: given an example, determine its class.
TensorFlow includes an implementation of a neural network classifier, which we’ll use on CSV data (the iris dataset).
Iris Dataset
The iris dataset is split into two files: the training set and the test set. The network is first trained on the training set; after training is completed, it can be used to make predictions.
What does the iris dataset contain?
It contains 3 classes of plants (0, 1, 2); the class is the last value in each row of the file.
It has 4 attributes:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
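Each row therefore holds five comma-separated values: the four measurements followed by the class label. As a minimal plain-Python sketch (the `IrisSample` type and `parse_row` helper are hypothetical names used only for illustration), one row can be mapped onto named fields like this:

```python
from typing import NamedTuple


class IrisSample(NamedTuple):
    """One row of the iris CSV files (hypothetical helper type)."""
    sepal_length: float  # cm
    sepal_width: float   # cm
    petal_length: float  # cm
    petal_width: float   # cm
    species: int         # class label: 0, 1 or 2


def parse_row(line: str) -> IrisSample:
    """Split a CSV row: the first four values are floats, the last is the class."""
    values = line.strip().split(",")
    measurements = [float(v) for v in values[:4]]
    return IrisSample(*measurements, species=int(values[4]))


sample = parse_row("6.4,2.8,5.6,2.2,2")
print(sample.petal_length, sample.species)  # prints: 5.6 2
```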
In short: plants of different types were collected and measured, and the measurements were stored in text files.
You can download the iris dataset from GitHub.
The training set is a simple file that looks like this:
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,2.5,4.5,1.7,2
4.9,3.1,1.5,0.1,0
5.7,3.8,1.7,0.3,0
...
The test set looks similar:
5.9,3.0,4.2,1.5,1
6.9,3.1,5.4,2.1,2
5.1,3.3,1.7,0.5,0
6.0,3.4,4.5,1.6,1
...
The files have a header line, which we’ll ignore.
Neural network on csv data
The CSV files can be loaded with these two lines:
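training_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TRAINING, target_dtype=np.int32, features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(filename=IRIS_TEST, target_dtype=np.int32, features_dtype=np.float32)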
The names in capital letters hold the file names. The target (class) is loaded as an integer and the features as floats.
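Loading boils down to a few steps: the header line is not treated as data, and every remaining row is converted into four float features and one integer label. Here is a minimal standard-library sketch of those steps, not TensorFlow's implementation; the function name `load_iris_csv` and the placeholder header text are hypothetical, and the rows are the training-set excerpt shown above.

```python
import csv
import io
from typing import List, Tuple


def load_iris_csv(stream) -> Tuple[List[List[float]], List[int]]:
    """Read an iris-style CSV: skip the header, return (features, labels)."""
    reader = csv.reader(stream)
    next(reader)  # skip the header line
    features, labels = [], []
    for row in reader:
        if not row:  # tolerate blank lines
            continue
        features.append([float(v) for v in row[:-1]])  # measurements as floats
        labels.append(int(row[-1]))                     # class as integer
    return features, labels


# Rows copied from the training-set excerpt; the header text is a placeholder.
data = io.StringIO(
    "placeholder header\n"
    "6.4,2.8,5.6,2.2,2\n"
    "5.0,2.3,3.3,1.0,1\n"
    "4.9,2.5,4.5,1.7,2\n"
    "4.9,3.1,1.5,0.1,0\n"
    "5.7,3.8,1.7,0.3,0\n"
)
features, labels = load_iris_csv(data)
print(len(features), labels)  # prints: 5 [2, 1, 2, 0, 0]
```

With a downloaded file you would pass an open file object instead, for example one opened with `open(path, newline="")`, where `path` is wherever you saved the CSV.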
Specify that all four features are real-valued:
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
Create the neural network with one line of code. The hidden_units argument gives the number of units in each hidden layer, and all layers are fully connected: [5,10] means the first hidden layer has 5 nodes and the second has 10 nodes.
Then specify the number of possible classes with n_classes. Our dataset has only 3 types of flowers (0, 1, 2).
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[5, 10], n_classes=3)
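To make the [5,10] architecture concrete, here is a plain-Python sketch (not TensorFlow's implementation) of the network it describes: 4 input features, fully connected hidden layers of 5 and 10 units, and 3 output classes. It assumes ReLU hidden activations and a softmax output; the weights are random and untrained, and helper names such as `forward` are hypothetical. It also counts the trainable parameters: each layer has n_in × n_out weights plus n_out biases.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def init_layer(n_in: int, n_out: int, rng: random.Random):
    """Random weights (n_out x n_in) and zero biases for one dense layer."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    biases = [0.0] * n_out
    return weights, biases


def dense(x: List[float], weights: Matrix, biases: List[float]) -> List[float]:
    """Fully connected layer: every output unit sees every input."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, a) for a in v]


def softmax(v: List[float]) -> List[float]:
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


n_features, hidden_units, n_classes = 4, [5, 10], 3
sizes = [n_features] + hidden_units + [n_classes]  # 4 -> 5 -> 10 -> 3
rng = random.Random(0)
layers = [init_layer(a, b, rng) for a, b in zip(sizes, sizes[1:])]

# Parameter count: weights (n_in * n_out) plus biases (n_out) per layer.
n_params = sum(a * b + b for a, b in zip(sizes, sizes[1:]))
print(n_params)  # prints: 118


def forward(x: List[float]) -> List[float]:
    """Hidden layers use ReLU; the output layer gives class probabilities."""
    for weights, biases in layers[:-1]:
        x = relu(dense(x, weights, biases))
    weights, biases = layers[-1]
    return softmax(dense(x, weights, biases))


probs = forward([6.4, 2.8, 5.6, 2.2])
print(len(probs))  # prints: 3
```

The count works out as 25 + 60 + 33 = 118 parameters, which is what training has to adjust.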
Fit the model.
classifier.fit(input_fn=get_train_inputs, steps=2000)
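Here get_train_inputs is an input function (its definition is not shown in this excerpt) that returns the training features and labels, and steps=2000 asks for 2000 training updates of the weights. As a simplified stand-in for what happens during those steps, the sketch below trains a softmax regression model (no hidden layers, unlike DNNClassifier) with plain full-batch gradient descent on the five training rows shown earlier. The function names, learning rate and starting weights are assumptions chosen for illustration, not DNNClassifier's actual optimizer settings.

```python
import math
from typing import List, Tuple

# Training rows copied from the excerpt above: four features, then the class.
X = [[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0], [4.9, 2.5, 4.5, 1.7],
     [4.9, 3.1, 1.5, 0.1], [5.7, 3.8, 1.7, 0.3]]
y = [2, 1, 2, 0, 0]
N_CLASSES = 3


def softmax(v: List[float]) -> List[float]:
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


def predict_proba(W, b, x):
    """Class probabilities for one sample under a linear softmax model."""
    return softmax([sum(wi * xi for wi, xi in zip(W[k], x)) + b[k] for k in range(N_CLASSES)])


def loss(W, b) -> float:
    """Mean cross-entropy over the training set."""
    return -sum(math.log(predict_proba(W, b, x)[t]) for x, t in zip(X, y)) / len(X)


def fit(steps: int, lr: float = 0.01) -> Tuple[list, list]:
    """Full-batch gradient descent: one weight update per step."""
    W = [[0.0] * len(X[0]) for _ in range(N_CLASSES)]
    b = [0.0] * N_CLASSES
    for _ in range(steps):
        gW = [[0.0] * len(X[0]) for _ in range(N_CLASSES)]
        gb = [0.0] * N_CLASSES
        for x, t in zip(X, y):
            p = predict_proba(W, b, x)
            for k in range(N_CLASSES):
                err = p[k] - (1.0 if k == t else 0.0)  # d(loss)/d(logit_k) for this sample
                gb[k] += err / len(X)
                for j, xj in enumerate(x):
                    gW[k][j] += err * xj / len(X)
        for k in range(N_CLASSES):
            b[k] -= lr * gb[k]
            for j in range(len(X[0])):
                W[k][j] -= lr * gW[k][j]
    return W, b


W0 = [[0.0] * 4 for _ in range(N_CLASSES)]
b0 = [0.0] * N_CLASSES
W, b = fit(steps=2000)
print(loss(W0, b0), loss(W, b))  # loss before and after training
```

The learning rate is kept small because the features are not normalized; with a step size that is too large, gradient descent can overshoot instead of lowering the loss.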
Then you can evaluate the classifier on the test set:
# Define the test inputs
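Evaluation compares the predicted classes with the true labels of the test set; in tf.contrib.learn, classifier.evaluate returns a dict of metrics that includes "accuracy". The accuracy itself is just the fraction of correct predictions, sketched here in plain Python. The labels are those of the four test rows shown earlier; the predictions are made up for illustration.

```python
from typing import Sequence


def accuracy(predicted: Sequence[int], actual: Sequence[int]) -> float:
    """Fraction of samples whose predicted class equals the true class."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


test_labels = [1, 2, 0, 1]   # classes of the four test rows shown above
predictions = [1, 2, 0, 2]   # hypothetical classifier output
print(accuracy(predictions, test_labels))  # prints: 0.75
```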
Then, given 4 new samples, you can predict the type (class) of each flower:
# Classify new flower
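For each new sample the classifier scores every class, and the predicted class is the one with the highest probability. A minimal sketch of that last step, using made-up probability vectors rather than real model output (`predict_class` is a hypothetical helper):

```python
from typing import List, Sequence


def predict_class(probabilities: Sequence[float]) -> int:
    """Return the index (class 0, 1 or 2) with the highest probability."""
    return max(range(len(probabilities)), key=lambda k: probabilities[k])


# Hypothetical class probabilities for two new flowers.
batch: List[List[float]] = [[0.05, 0.80, 0.15], [0.10, 0.30, 0.60]]
print([predict_class(p) for p in batch])  # prints: [1, 2]
```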
Neural Network on CSV sample
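The example below summarizes what we covered. You can copy this code and run it; note that it uses the tf.contrib API, which was removed in TensorFlow 2.0, so it needs TensorFlow 1.x. Don’t forget to download the iris dataset (training and test files).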
# DNNClassifier on CSV input dataset.