CNN on the CIFAR100 Data

Biostat 203B

Author

Dr. Hua Zhou @ UCLA

Published

February 28, 2024

1 Setup

Display system information for reproducibility.

import IPython
print(IPython.sys_info())
{'commit_hash': '8b1204b6c',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/opt/venv/lib/python3.10/site-packages/IPython',
 'ipython_version': '8.21.0',
 'os_name': 'posix',
 'platform': 'Linux-6.6.12-linuxkit-aarch64-with-glibc2.35',
 'sys_executable': '/opt/venv/bin/python',
 'sys_platform': 'linux',
 'sys_version': '3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]'}
sessionInfo()
R version 4.3.2 (2023-10-31)
Platform: aarch64-unknown-linux-gnu (64-bit)
Running under: Ubuntu 22.04.3 LTS

Matrix products: default
BLAS:   /usr/lib/aarch64-linux-gnu/openblas-pthread/libblas.so.3 
LAPACK: /usr/lib/aarch64-linux-gnu/openblas-pthread/libopenblasp-r0.3.20.so;  LAPACK version 3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: Etc/UTC
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] digest_0.6.34     fastmap_1.1.1     xfun_0.42         Matrix_1.6-1.1   
 [5] lattice_0.21-9    reticulate_1.35.0 knitr_1.45        htmltools_0.5.7  
 [9] png_0.1-8         rmarkdown_2.25    cli_3.6.2         grid_4.3.2       
[13] compiler_4.3.2    rstudioapi_0.15.0 tools_4.3.2       evaluate_0.23    
[17] Rcpp_1.0.12       yaml_2.3.8        rlang_1.1.3       jsonlite_1.8.8   
[21] htmlwidgets_1.6.4

Load some libraries.

# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 1.2)
# Display all columns
pd.set_option('display.max_columns', None)

# Load Tensorflow and Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
library(keras)
library(jpeg)

In this example, we train a CNN (convolutional neural network) on the CIFAR100 data set, achieving a testing accuracy of about 44.5% after 30 epochs. A random guess would have an accuracy of about 1%.

  • The CIFAR100 database is a collection of \(32 \times 32\) color images commonly used for training and testing machine learning algorithms. Each image carries a fine label (one of 100 classes) and a coarse superclass label (one of 20 classes); see the sketch after this list.

  • 50,000 training images, 10,000 testing images.
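
As an aside, the Keras loader exposes both label sets through its label_mode argument; a minimal sketch (not used in the analysis below):

from tensorflow import keras

# "fine" (the default) yields the 100 class labels;
# "coarse" yields the 20 superclass labels
(_, y_fine), _ = keras.datasets.cifar100.load_data(label_mode = "fine")
(_, y_coarse), _ = keras.datasets.cifar100.load_data(label_mode = "coarse")
print(y_fine.max() + 1)    # 100
print(y_coarse.max() + 1)  # 20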

2 Prepare data

Acquire data:

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
169001437/169001437 [==============================] - 6s 0us/step
# Training set
x_train.shape
(50000, 32, 32, 3)
y_train.shape
(50000, 1)
# Test set
x_test.shape
(10000, 32, 32, 3)
y_test.shape
(10000, 1)
cifar100 <- dataset_cifar100()
x_train <- cifar100$train$x
y_train <- cifar100$train$y
x_test <- cifar100$test$x
y_test <- cifar100$test$y

Training set:

dim(x_train)
[1] 50000    32    32     3
dim(y_train)
[1] 50000     1

Testing set:

dim(x_test)
[1] 10000    32    32     3
dim(y_test)
[1] 10000     1

For CNN, we keep the \(32 \times 32 \times 3\) tensor structure, instead of vectorizing into a long vector.

# Rescale pixel values from [0, 255] to [0, 1]
x_train = x_train / 255
x_test = x_test / 255
# Train
x_train.shape
(50000, 32, 32, 3)
# Test
x_test.shape
(10000, 32, 32, 3)
# Rescale pixel values from [0, 255] to [0, 1]
x_train <- x_train / 255
x_test <- x_test / 255
dim(x_train)
[1] 50000    32    32     3
dim(x_test)
[1] 10000    32    32     3

Encode \(y\) as a binary class matrix:

y_train = keras.utils.to_categorical(y_train, 100)
y_test = keras.utils.to_categorical(y_test, 100)
# Train
y_train.shape
(50000, 100)
# Test
y_test.shape
(10000, 100)
# First training instance
y_train[0]
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)
y_train <- to_categorical(y_train, 100)
y_test <- to_categorical(y_test, 100)
dim(y_train)
[1] 50000   100
dim(y_test)
[1] 10000   100
# head(y_train)
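
Under the hood, to_categorical performs one-hot encoding. A minimal NumPy equivalent, for illustration only (label 19 matches the first training instance shown above):

import numpy as np

# One-hot encode integer labels by indexing rows of the identity matrix
labels = np.array([19, 29, 0])  # hypothetical integer labels
one_hot = np.eye(100, dtype = 'float32')[labels]
one_hot.shape        # (3, 100)
one_hot[0].argmax()  # 19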

Show a few images:

import matplotlib.pyplot as plt

# Display the first 25 training images (32x32 color), one figure each
for i in range(25):
  plt.figure()
  plt.imshow(x_train[i])
  plt.show()

par(mar = c(0, 0, 0, 0), mfrow = c(5, 5))
index <- sample(seq(50000), 25)
for (i in index) plot(as.raster(x_train[i,,, ]))

3 Define the model

Define a sequential model (a linear stack of layers) with four convolutional layers (32, 64, 128, and 256 filters), each followed by max pooling, then a dropout layer and a fully-connected layer (512 neurons) before the 100-class softmax output:

model = keras.Sequential(
  [
    keras.Input(shape = (32, 32, 3)),
    layers.Conv2D(
      filters = 32, 
      kernel_size = (3, 3),
      padding = 'same',
      activation = 'relu',
      # input_shape = (32, 32, 3)
      ),
    layers.MaxPooling2D(pool_size = (2, 2)),
    layers.Conv2D(
      filters = 64, 
      kernel_size = (3, 3),
      padding = 'same',
      activation = 'relu'
      ),
    layers.MaxPooling2D(pool_size = (2, 2)),
    layers.Conv2D(
      filters = 128, 
      kernel_size = (3, 3),
      padding = 'same',
      activation = 'relu'
      ),
    layers.MaxPooling2D(pool_size = (2, 2)),
    layers.Conv2D(
      filters = 256, 
      kernel_size = (3, 3),
      padding = 'same',
      activation = 'relu'
      ),
    layers.MaxPooling2D(pool_size = (2, 2)),
    layers.Flatten(),
    layers.Dropout(rate = 0.5),
    layers.Dense(units = 512, activation = 'relu'),
    layers.Dense(units = 100, activation = 'softmax')
]
)

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 32, 32, 32)        896       
                                                                 
 max_pooling2d (MaxPooling2  (None, 16, 16, 32)        0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 16, 16, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 8, 8, 64)          0         
 g2D)                                                            
                                                                 
 conv2d_2 (Conv2D)           (None, 8, 8, 128)         73856     
                                                                 
 max_pooling2d_2 (MaxPoolin  (None, 4, 4, 128)         0         
 g2D)                                                            
                                                                 
 conv2d_3 (Conv2D)           (None, 4, 4, 256)         295168    
                                                                 
 max_pooling2d_3 (MaxPoolin  (None, 2, 2, 256)         0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 1024)              0         
                                                                 
 dropout (Dropout)           (None, 1024)              0         
                                                                 
 dense (Dense)               (None, 512)               524800    
                                                                 
 dense_1 (Dense)             (None, 100)               51300     
                                                                 
=================================================================
Total params: 964516 (3.68 MB)
Trainable params: 964516 (3.68 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
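
As a sanity check on the Param # column: a convolutional layer with \(k \times k\) kernels, \(c\) input channels, and \(f\) filters has \(k^2 c f + f\) parameters. The first convolutional layer therefore has \(3^2 \cdot 3 \cdot 32 + 32 = 896\) parameters, and the first dense layer has \(1024 \cdot 512 + 512 = 524800\), matching the summary above.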

Plot the model:

tf.keras.utils.plot_model(
    model,
    to_file = "model.png",
    show_shapes = True,
    show_dtype = False,
    show_layer_names = True,
    rankdir = "TB",
    expand_nested = False,
    dpi = 96,
    layer_range = None,
    show_layer_activations = False,
)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.

model <- keras_model_sequential()  %>% 
  layer_conv_2d(
    filters = 32, 
    kernel_size = c(3, 3),
    padding = "same", 
    activation = "relu",
    input_shape = c(32, 32, 3)
    ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(
    filters = 64, 
    kernel_size = c(3, 3),
    padding = "same", 
    activation = "relu"
    ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(
    filters = 128, 
    kernel_size = c(3, 3),
    padding = "same", 
    activation = "relu"
    ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_conv_2d(
    filters = 256, 
    kernel_size = c(3, 3),
    padding = "same", 
    activation = "relu"
    ) %>%
  layer_max_pooling_2d(pool_size = c(2, 2)) %>%
  layer_flatten() %>%
  layer_dropout(rate = 0.5) %>%
  layer_dense(units = 512, activation = "relu") %>%
  layer_dense(units = 100, activation = "softmax")
  
summary(model)
Model: "sequential_1"
________________________________________________________________________________
 Layer (type)                       Output Shape                    Param #     
================================================================================
 conv2d_7 (Conv2D)                  (None, 32, 32, 32)              896         
 max_pooling2d_7 (MaxPooling2D)     (None, 16, 16, 32)              0           
 conv2d_6 (Conv2D)                  (None, 16, 16, 64)              18496       
 max_pooling2d_6 (MaxPooling2D)     (None, 8, 8, 64)                0           
 conv2d_5 (Conv2D)                  (None, 8, 8, 128)               73856       
 max_pooling2d_5 (MaxPooling2D)     (None, 4, 4, 128)               0           
 conv2d_4 (Conv2D)                  (None, 4, 4, 256)               295168      
 max_pooling2d_4 (MaxPooling2D)     (None, 2, 2, 256)               0           
 flatten_1 (Flatten)                (None, 1024)                    0           
 dropout_1 (Dropout)                (None, 1024)                    0           
 dense_3 (Dense)                    (None, 512)                     524800      
 dense_2 (Dense)                    (None, 100)                     51300       
================================================================================
Total params: 964516 (3.68 MB)
Trainable params: 964516 (3.68 MB)
Non-trainable params: 0 (0.00 Byte)
________________________________________________________________________________

Compile the model with an appropriate loss function, optimizer, and metrics:

model.compile(
  loss = "categorical_crossentropy",
  optimizer = "rmsprop",
  metrics = ["accuracy"]
)
model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)
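
As an aside, had we kept the integer labels instead of one-hot encoding them, we could compile with the sparse variant of the loss. A minimal sketch, not used below:

# Sketch only: expects integer labels y_train of shape (n, 1), making the
# explicit to_categorical step unnecessary
model.compile(
  loss = "sparse_categorical_crossentropy",
  optimizer = "rmsprop",
  metrics = ["accuracy"]
)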

4 Training and validation

We use an 80%/20% split for the training/validation sets. On my laptop, each epoch takes about half a minute.

batch_size = 128
epochs = 30

history = model.fit(
  x_train,
  y_train,
  batch_size = batch_size,
  epochs = epochs,
  validation_split = 0.2
)
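
If desired, training can be halted automatically once the validation loss stops improving. A minimal sketch using Keras early stopping (not used for the results reported here):

# Sketch only: stop after 5 epochs without validation-loss improvement and
# restore the best weights seen so far
early_stop = keras.callbacks.EarlyStopping(
  monitor = "val_loss",
  patience = 5,
  restore_best_weights = True
)
history = model.fit(
  x_train,
  y_train,
  batch_size = batch_size,
  epochs = epochs,
  validation_split = 0.2,
  callbacks = [early_stop]
)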

Plot training history:

Code
# Gather the per-epoch metrics into a long data frame for plotting
hist = pd.DataFrame(history.history)
hist['epoch'] = np.arange(1, epochs + 1)
hist = hist.melt(
  id_vars = ['epoch'],
  value_vars = ['loss', 'accuracy', 'val_loss', 'val_accuracy'],
  var_name = 'type',
  value_name = 'value'
)
# Split each metric name into train/validation and loss/accuracy facets
hist['split'] = np.where(['val' in s for s in hist['type']], 'validation', 'train')
hist['metric'] = np.where(['loss' in s for s in hist['type']], 'loss', 'accuracy')

# Accuracy trace plot
plt.figure()
sns.relplot(
  data = hist[hist['metric'] == 'accuracy'],
  kind = 'scatter',
  x = 'epoch',
  y = 'value',
  hue = 'split'
).set(
  xlabel = 'Epoch',
  ylabel = 'Accuracy'
);
plt.show()

Code
# Loss trace plot
plt.figure()
sns.relplot(
  data = hist[hist['metric'] == 'loss'],
  kind = 'scatter',
  x = 'epoch',
  y = 'value',
  hue = 'split'
).set(
  xlabel = 'Epoch',
  ylabel = 'Loss'
);
plt.show()

system.time({
history <- model %>% fit(
  x_train, y_train, 
  epochs = 30, batch_size = 128, 
  validation_split = 0.2
)
})
Epoch 1/30
313/313 - 12s - loss: 4.2324 - accuracy: 0.0479 - val_loss: 3.8928 - val_accuracy: 0.1007 - 12s/epoch - 37ms/step
Epoch 2/30
313/313 - 11s - loss: 3.6810 - accuracy: 0.1334 - val_loss: 3.4087 - val_accuracy: 0.1859 - 11s/epoch - 36ms/step
Epoch 3/30
313/313 - 11s - loss: 3.3213 - accuracy: 0.1960 - val_loss: 3.3177 - val_accuracy: 0.2009 - 11s/epoch - 36ms/step
Epoch 4/30
313/313 - 11s - loss: 3.0629 - accuracy: 0.2442 - val_loss: 2.9877 - val_accuracy: 0.2666 - 11s/epoch - 36ms/step
Epoch 5/30
313/313 - 11s - loss: 2.8499 - accuracy: 0.2852 - val_loss: 2.9542 - val_accuracy: 0.2652 - 11s/epoch - 36ms/step
Epoch 6/30
313/313 - 11s - loss: 2.6751 - accuracy: 0.3219 - val_loss: 2.7540 - val_accuracy: 0.3177 - 11s/epoch - 37ms/step
Epoch 7/30
313/313 - 12s - loss: 2.5119 - accuracy: 0.3538 - val_loss: 2.5667 - val_accuracy: 0.3526 - 12s/epoch - 39ms/step
Epoch 8/30
313/313 - 12s - loss: 2.3687 - accuracy: 0.3819 - val_loss: 2.5618 - val_accuracy: 0.3470 - 12s/epoch - 38ms/step
Epoch 9/30
313/313 - 11s - loss: 2.2353 - accuracy: 0.4117 - val_loss: 2.4659 - val_accuracy: 0.3695 - 11s/epoch - 37ms/step
Epoch 10/30
313/313 - 12s - loss: 2.1188 - accuracy: 0.4354 - val_loss: 2.4559 - val_accuracy: 0.3791 - 12s/epoch - 38ms/step
Epoch 11/30
313/313 - 12s - loss: 2.0101 - accuracy: 0.4614 - val_loss: 2.3046 - val_accuracy: 0.4076 - 12s/epoch - 37ms/step
Epoch 12/30
313/313 - 12s - loss: 1.9032 - accuracy: 0.4827 - val_loss: 2.3713 - val_accuracy: 0.3998 - 12s/epoch - 37ms/step
Epoch 13/30
313/313 - 12s - loss: 1.7979 - accuracy: 0.5089 - val_loss: 2.2715 - val_accuracy: 0.4226 - 12s/epoch - 37ms/step
Epoch 14/30
313/313 - 12s - loss: 1.7050 - accuracy: 0.5266 - val_loss: 2.3203 - val_accuracy: 0.4111 - 12s/epoch - 37ms/step
Epoch 15/30
313/313 - 11s - loss: 1.6266 - accuracy: 0.5452 - val_loss: 2.2828 - val_accuracy: 0.4242 - 11s/epoch - 37ms/step
Epoch 16/30
313/313 - 11s - loss: 1.5506 - accuracy: 0.5646 - val_loss: 2.2076 - val_accuracy: 0.4397 - 11s/epoch - 37ms/step
Epoch 17/30
313/313 - 11s - loss: 1.4676 - accuracy: 0.5815 - val_loss: 2.3869 - val_accuracy: 0.4154 - 11s/epoch - 36ms/step
Epoch 18/30
313/313 - 11s - loss: 1.4044 - accuracy: 0.5954 - val_loss: 2.2411 - val_accuracy: 0.4457 - 11s/epoch - 36ms/step
Epoch 19/30
313/313 - 12s - loss: 1.3215 - accuracy: 0.6185 - val_loss: 2.3227 - val_accuracy: 0.4362 - 12s/epoch - 37ms/step
Epoch 20/30
313/313 - 11s - loss: 1.2654 - accuracy: 0.6323 - val_loss: 2.3015 - val_accuracy: 0.4403 - 11s/epoch - 36ms/step
Epoch 21/30
313/313 - 11s - loss: 1.2092 - accuracy: 0.6478 - val_loss: 2.2746 - val_accuracy: 0.4444 - 11s/epoch - 36ms/step
Epoch 22/30
313/313 - 11s - loss: 1.1527 - accuracy: 0.6573 - val_loss: 2.3165 - val_accuracy: 0.4467 - 11s/epoch - 36ms/step
Epoch 23/30
313/313 - 11s - loss: 1.1081 - accuracy: 0.6709 - val_loss: 2.3376 - val_accuracy: 0.4478 - 11s/epoch - 36ms/step
Epoch 24/30
313/313 - 11s - loss: 1.0640 - accuracy: 0.6835 - val_loss: 2.3923 - val_accuracy: 0.4321 - 11s/epoch - 36ms/step
Epoch 25/30
313/313 - 11s - loss: 1.0223 - accuracy: 0.6936 - val_loss: 2.5082 - val_accuracy: 0.4384 - 11s/epoch - 36ms/step
Epoch 26/30
313/313 - 11s - loss: 0.9673 - accuracy: 0.7070 - val_loss: 2.4766 - val_accuracy: 0.4387 - 11s/epoch - 36ms/step
Epoch 27/30
313/313 - 11s - loss: 0.9430 - accuracy: 0.7146 - val_loss: 2.3742 - val_accuracy: 0.4426 - 11s/epoch - 36ms/step
Epoch 28/30
313/313 - 11s - loss: 0.9014 - accuracy: 0.7236 - val_loss: 2.4442 - val_accuracy: 0.4514 - 11s/epoch - 36ms/step
Epoch 29/30
313/313 - 11s - loss: 0.8834 - accuracy: 0.7289 - val_loss: 2.6080 - val_accuracy: 0.4310 - 11s/epoch - 36ms/step
Epoch 30/30
313/313 - 11s - loss: 0.8477 - accuracy: 0.7383 - val_loss: 2.5085 - val_accuracy: 0.4524 - 11s/epoch - 36ms/step
    user   system  elapsed 
2887.588  126.822  345.324 
plot(history)

5 Testing

Evaluate model performance on the test data:

score = model.evaluate(x_test, y_test, verbose = 0)
print("Test loss:", score[0])
Test loss: 2.5904178619384766
print("Test accuracy:", score[1])
Test accuracy: 0.43849998712539673

model %>% evaluate(x_test, y_test)
313/313 - 1s - loss: 2.4325 - accuracy: 0.4631 - 1s/epoch - 4ms/step
    loss accuracy 
2.432508 0.463100 

Generate predictions on new data:

model %>% predict(x_test) %>% k_argmax()
313/313 - 1s - 1s/epoch - 4ms/step
tf.Tensor([95 80 55 ... 51 42 26], shape=(10000), dtype=int64)
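
For reference, a Python equivalent of this prediction pipeline, as a minimal sketch:

# Predicted class = index of the largest softmax probability per image
pred_probs = model.predict(x_test)
pred_classes = pred_probs.argmax(axis = -1)
pred_classes[:5]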