### Learn to classify MNIST handwritten digits using Keras

# Keras code to implement Figure 8.5 of
# Poole and Mackworth, Artificial Intelligence: foundations of
# computational agents, 3rd Edition, Cambridge, 2023

# Copyright (c) 2023 David Poole and Alan Mackworth. This program is released under
# CC Attribution-NonCommercial-ShareAlike  License http://creativecommons.org/licenses/by-nc-sa/4.0/

# If using Anaconda, make sure you first activate the conda environment that has TensorFlow installed (named tf here):
## conda activate tf

from tensorflow import keras
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow.keras import layers

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Two fully-connected layers: 512 ReLU hidden units feeding a 10-way softmax.
model = keras.Sequential([layers.Dense(512, activation="relu"),
                          layers.Dense(10, activation = "softmax")])

# Labels are integer class ids (they are used directly as indices into the
# prediction rows further down), hence the sparse cross-entropy loss.
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy",
      metrics=["accuracy"])

# Flatten each 2-D image into one feature vector for the Dense layers.
# The image count and pixels-per-image come from the arrays themselves
# instead of being hard-coded (60000 / 10000 images of 28*28 for MNIST).
# NOTE(review): pixel values are intentionally left unscaled (0-255), exactly
# as in the original script, so training results are unchanged.
train_images_flat = train_images.reshape((len(train_images), -1)).astype("float32")
test_images_flat = test_images.reshape((len(test_images), -1)).astype("float32")
model.fit(train_images_flat, train_labels, epochs=5, batch_size=128)

#model.evaluate(test_images_flat, test_labels)
#model.evaluate(train_images_flat, train_labels)

# Score every training image, then pair the probability the model gives to
# that image's correct label with the image index. Sorting ascending puts
# the worst-predicted training examples first.
train_predictions = model.predict(train_images_flat)

train_prediction_of_actual = sorted(
    (row[actual], idx)
    for idx, (row, actual) in enumerate(zip(train_predictions, train_labels))
)

# Training examples whose correct label receives only moderate probability.
train_mids = [pair for pair in train_prediction_of_actual if 0.3 < pair[0] <= 0.5]

def show_pred(i):
    """Print the label and prediction row for training image *i*, then plot it.

    Reads the module-level ``train_labels``, ``train_predictions`` and
    ``train_images`` arrays; displays the image with ``plt.show()``.
    """
    label = train_labels[i]
    scores = train_predictions[i]
    print("label =", label, "prediction =", scores)
    picture = train_images[i]
    plt.imshow(picture, cmap=plt.cm.binary)
    plt.show()

#show_pred(59915)

def show_predictions(predictions,prediction_of_actual,labels,images,extreme=True):
    """Plot a row of images, each annotated with its true label and the
    strongest competing class.

    Args:
        predictions: ``predictions[i][c]`` is the model's score for class ``c``
            on image ``i``.
        prediction_of_actual: sequence of ``(p, i)`` pairs; only the image
            index ``i`` is used. One subplot is drawn per pair, in order.
        labels: true class label for each image index.
        images: image arrays, passed to ``plt.imshow``.
        extreme: if true, show the competing class as ``1-<rest>`` and the
            true class in scientific notation; otherwise show both values
            with four decimals.

    Side effects: prints diagnostics for each image to stdout and calls
    ``plt.show()`` once after all subplots are drawn.
    """
    count = len(prediction_of_actual)
    for slot, (_, idx) in enumerate(prediction_of_actual, start=1):
        actual = labels[idx]
        row = predictions[idx]
        plt.subplot(1, count, slot)
        print("label =", actual, "prediction =", row)
        plt.imshow(images[idx], cmap=plt.cm.binary)
        # Hide tick marks but keep the axes frame.
        axes = plt.gca()
        axes.set_xticks([])
        axes.set_yticks([])
        plt.text(0, -15, "Actual=" + str(actual))
        # Highest-scoring class other than the true label. Tuples compare as
        # (score, class), so ties resolve to the larger class index.
        best_score, best_class = max(
            (score, cls) for cls, score in enumerate(row) if cls != actual
        )
        print("am=", best_class)
        # Total score on every class except best_class; the extreme display
        # shows P(best_class) as 1 minus this amount.
        rest = sum(score for cls, score in enumerate(row) if cls != best_class)
        print("others=", rest)
        actual_score = row[actual]
        if not extreme:
            plt.text(0, -5, f"P({str(best_class)})={best_score:.4f}")
            plt.text(0, -10, f"P({str(actual)})={actual_score:.4f}")
        else:
            plt.text(0, -5, f"P({str(best_class)})=1-{rest:.1e}")
            plt.text(0, -10, f"P({str(actual)})={actual_score:.1e}")
    plt.show()

#show_predictions(train_predictions,train_prediction_of_actual[:6],train_labels,train_images)

# Same analysis on the held-out test set: the probability given to each test
# image's correct label, paired with its index and sorted ascending so the
# worst-predicted test examples come first.
# (The former bare `test_prediction_of_actual[:10]` statement was an
# interactive-session echo with no effect in a script, so it is removed.)
test_predictions = model.predict(test_images_flat)
test_prediction_of_actual = sorted(
    (row[actual], idx)
    for idx, (row, actual) in enumerate(zip(test_predictions, test_labels))
)
# Test examples whose correct label receives only moderate probability.
# NOTE(review): the lower bound is 0.2 here but 0.3 for train_mids; confirm
# the difference is intended.
test_mids = [(p, i) for (p, i) in test_prediction_of_actual if 0.2 < p <= 0.5]

#show_predictions(test_predictions,test_prediction_of_actual[:6],test_labels,test_images)

#show_predictions(train_predictions,train_mids[-6:],train_labels,train_images,extreme=False)
#show_predictions(test_predictions,test_mids[-6:],test_labels,test_images,extreme=False)
