# source: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import itertools
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          decimals=2,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix on the current matplotlib figure.

    Normalization can be applied by setting `normalize=True`.

    Parameters
    ----------
    cm : array-like
        Square confusion matrix; ``cm[i, j]`` is drawn in row ``i`` (labelled
        'True label') and column ``j`` (labelled 'Predicted label').
    classes : sequence of str
        Tick labels, one per row/column of ``cm``.
    normalize : bool, default False
        If True, divide each row by its sum so each displayed row sums to 1.
        Rows whose sum is zero are shown as all zeros instead of NaN.
    decimals : int, default 2
        Number of decimals used for the per-cell text annotations.
    title : str, default 'Confusion matrix'
        Plot title.
    cmap : matplotlib colormap, default ``plt.cm.Blues``
        Colormap passed to ``plt.imshow``.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize *before* drawing so the image colours, the colorbar and
        # the cell annotations all describe the same matrix.
        row_sums = cm.sum(axis=1, keepdims=True)
        # A row with no true samples would divide by zero and yield NaN,
        # which would also make ``cm.max()`` below NaN; show zeros instead.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # White text on cells above half the maximum, black text otherwise,
    # so annotations stay readable against the colormap.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, np.round(cm[i, j], decimals=decimals),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass after the axis labels exist so they are included.
    plt.tight_layout()
# EXAMPLE: Plot normalized confusion matrix
# NOTE(review): ``test_labels_1d`` (true labels) and ``pred`` (predictions)
# are not defined in this file; presumably they come from an earlier cell.
cnf_matrix = confusion_matrix(test_labels_1d, pred)
# With no ``labels=`` argument, confusion_matrix uses the sorted set of labels
# that appear in either array. Build the tick labels the same way so they
# always line up with the matrix axes, even if some class never occurs.
class_names = [str(label) for label in
               np.unique(np.concatenate([np.ravel(test_labels_1d),
                                         np.ravel(pred)]))]
plt.figure(figsize=(6, 6))
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')
plt.show()
import pandas as pd
# Simple crosstab with totals: ``margins=True`` adds an "All" row and column.
# Uses the same true/predicted label arrays as the confusion matrix example.
# NOTE: these are bare expressions; they are only displayed when evaluated as
# the last expression of a notebook cell (wrap in ``print`` in a script).
pd.crosstab(test_labels_1d, pred, rownames=["actual"], colnames=["predicted"],
            margins=True)
# Row-normalized crosstab in percent (each "actual" row sums to 100),
# rounded to two decimals.
(pd.crosstab(test_labels_1d, pred, rownames=["actual"], colnames=["predicted"],
             margins=False, normalize="index") * 100).round(2)
from IPython.display import SVG

# ``keras.utils.visualize_util`` exists only in Keras 1.x. Keras 2 renamed it
# to ``keras.utils.vis_utils``, and newer releases export ``model_to_dot``
# directly from ``keras.utils``. Try the newest location first.
try:
    from keras.utils import model_to_dot
except ImportError:
    try:
        from keras.utils.vis_utils import model_to_dot
    except ImportError:
        from keras.utils.visualize_util import model_to_dot

# Render ``model`` as an inline SVG graph (layer shapes included).
# NOTE(review): ``model`` is not defined in this file; presumably built earlier.
SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))
# ...and to save the same model diagram directly to an image file.
# Keras 1.x called this function ``keras.utils.visualize_util.plot``; Keras 2
# renamed it ``plot_model`` (in ``keras.utils.vis_utils`` / ``keras.utils``).
# It is bound as ``plot`` here either way, for compatibility with the old name.
try:
    from keras.utils import plot_model as plot
except ImportError:
    try:
        from keras.utils.vis_utils import plot_model as plot
    except ImportError:
        from keras.utils.visualize_util import plot

plot(model, show_shapes=True, to_file='/tmp/model.png')