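# Contents: k-means clustering of the scikit-learn digits dataset, with the
# ten cluster centers displayed as 8x8 grayscale images.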
from sklearn import datasets
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
# Load the bundled handwritten-digits dataset; each row of `digits.data` is
# reshaped to 8x8 below for display.
digits = datasets.load_digits()

# Partition the samples into 10 clusters. A fixed random_state keeps the
# centroid initialisation reproducible between runs.
kmeans = KMeans(n_clusters=10, random_state=0).fit(digits.data)

# One subplot per cluster, laid out as a 2 x 5 grid.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))

# `axes.flat` walks the grid row by row, which visits the same cells as
# indexing with (i // 5, i % 5) for i = 0..9.
for ax, center in zip(axes.flat, kmeans.cluster_centers_):
    ax.imshow(center.reshape(8, 8), cmap=plt.cm.binary)
    ax.axis('off')

# NOTE: to show an actual member of cluster k rather than its centroid, plot
# digits.data[kmeans.labels_ == k][0].reshape(8, 8) on the chosen axes.
plt.show()