
Building deep learning models can often consume hours or even days before tangible results appear. You might find yourself needing to pause model training to adjust the learning rate, save training logs for future reference, or visualize the training progress in TensorBoard. The workload involved in performing these fundamental tasks can be considerable, which is precisely where TensorFlow callbacks become valuable.
This article will explore the specifics, implementation, and examples of TensorFlow callbacks. Here’s a brief overview of what we will cover:
- Understanding callback functions
- Events that trigger callbacks
- Available callbacks in TensorFlow 2.0
- Conclusion
Understanding Callback Functions
In simple terms, callbacks are specialized functions that execute at various stages during the training process. They assist in preventing overfitting, visualizing training processes, debugging code, saving checkpoints, and generating logs, among other tasks. TensorFlow provides numerous predefined callbacks, and multiple callbacks can be used simultaneously. We will examine some of the available callbacks along with their practical applications.
When Callbacks are Triggered
Callbacks are activated by specific events. There are several event types during training that can trigger a callback, such as:
- `on_epoch_begin`: triggered at the start of a new epoch.
- `on_epoch_end`: triggered at the conclusion of an epoch.
- `on_batch_begin`: triggered when a new batch is ready for training.
- `on_batch_end`: triggered when a batch finishes training.
- `on_train_begin`: triggered when training starts.
- `on_train_end`: triggered when training concludes.
To utilize a callback in model training, simply pass the callback object to the `callbacks` argument of `model.fit`; for instance:

```python
model.fit(x, y, callbacks=list_of_callbacks)
```
Available Callbacks in TensorFlow 2.0
Let’s examine the callbacks available in the `tf.keras.callbacks` module.
1. EarlyStopping
This callback is widely used to monitor specific metrics and halt model training when improvements stagnate. For instance, if you want to halt training when accuracy fails to improve by 0.05, this callback is your solution. It helps mitigate overfitting to a certain degree.
```python
tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0,
                                 patience=0,
                                 verbose=0,
                                 mode='auto',
                                 baseline=None,
                                 restore_best_weights=False)
```
- `monitor`: the metric to observe.
- `min_delta`: the minimum change in the monitored metric that counts as an improvement.
- `patience`: the number of epochs with no improvement after which training stops.
- `verbose`: whether to print additional logs.
- `mode`: whether the monitored metric should be increasing, decreasing, or inferred from its name; possible values are `'min'`, `'max'`, or `'auto'`.
- `baseline`: the baseline value the monitored metric must reach; training stops if the model shows no improvement over it.
- `restore_best_weights`: if set to `True`, the model keeps the weights from the epoch with the best monitored metric; otherwise, it keeps the weights from the last epoch.
The `EarlyStopping` callback is triggered at the `on_epoch_end` event during training.
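To make the interaction of `min_delta` and `patience` concrete, here is a rough, pure-Python sketch of the stopping rule the callback applies. This is an illustration only, not TensorFlow's actual implementation; the function name is hypothetical.

```python
# Rough sketch of EarlyStopping's stopping rule (illustration only --
# not TensorFlow's actual implementation).
def early_stopping_epoch(losses, min_delta=0.0, patience=0):
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:  # counts as an improvement
            best = loss
            wait = 0
        else:
            wait += 1  # epochs since the last improvement
            if wait >= patience:
                return epoch
    return None
```

For example, with `min_delta=0.05` a drop from 0.79 to 0.78 does not count as an improvement, so the patience counter keeps growing.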
2. ModelCheckpoint
This callback is particularly useful for saving the model frequently during the training process. It’s especially beneficial when training deep learning models that take extensive time to train. The callback monitors the training and saves model checkpoints regularly, based on specific metrics.
```python
tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss',
                                   verbose=0,
                                   save_best_only=False,
                                   save_weights_only=False,
                                   mode='auto',
                                   save_freq='epoch')
```
- `filepath`: the path at which to save the model; it can contain formatting options such as `model-{epoch:02d}-{val_loss:0.2f}` so that filenames include the epoch number and metric values.
- `monitor`: the metric to monitor.
- `save_best_only`: if `True`, only the model with the best monitored metric is kept; checkpoints that do not improve on it are skipped.
- `mode`: whether the monitored metric should be increasing, decreasing, or inferred from its name; possible values are `'min'`, `'max'`, or `'auto'`.
- `save_weights_only`: if `True`, only the weights are saved rather than the complete model.
- `save_freq`: if set to `'epoch'`, the model is saved after each epoch; if an integer is provided, it is saved after that many batches.
The `ModelCheckpoint` callback is also executed at the `on_epoch_end` trigger during training.
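The `filepath` template is filled in with the epoch number and the values from the logs dict using `str.format`-style fields, so each checkpoint gets a distinct name. A small sketch of how such a pattern expands (the `.h5` extension and the sample values are illustrative):

```python
# How a ModelCheckpoint filepath pattern expands: {epoch:02d} is the
# zero-padded epoch number, {val_loss:0.2f} the metric from the logs.
template = "model-{epoch:02d}-{val_loss:0.2f}.h5"
filename = template.format(epoch=3, val_loss=0.1234)
print(filename)  # model-03-0.12.h5
```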
3. TensorBoard
This callback is exceptional for visualizing your model’s training summary. It generates logs for TensorBoard, enabling you to visualize your training progress later. We will delve deeper into TensorBoard in another article.
```python
tf.keras.callbacks.TensorBoard(log_dir='logs', histogram_freq=0,
                               write_graph=True,
                               write_images=False,
                               update_freq='epoch',
                               profile_batch=2,
                               embeddings_freq=0,
                               embeddings_metadata=None,
                               **kwargs)
```
For now, let’s focus on the `log_dir` parameter, which specifies the folder where the logs will be stored. To launch TensorBoard, execute the command:

```
tensorboard --logdir=path_to_your_logs
```
You can initiate TensorBoard before or after beginning your training process.
The TensorBoard callback is also triggered at `on_epoch_end`.
4. LearningRateScheduler
This callback is invaluable when you need to adjust the learning rate as training progresses. For example, you might want to decrease the learning rate after a certain number of epochs. The `LearningRateScheduler` allows exactly this kind of modification.
```python
tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
```
- `schedule`: a function that receives the epoch index and the current learning rate and returns the new learning rate.
- `verbose`: whether to print additional logs.
A typical use is to keep the learning rate constant for the first three epochs and then lower it from the fourth epoch onward. Setting `verbose` to `1` makes the callback print the learning rate at each epoch, so you can watch it decrease.
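As a sketch, a function along these lines could be supplied to the `schedule` parameter; the three-epoch cutoff and the exponential decay constant of 0.1 are arbitrary choices for illustration.

```python
import math

# Example schedule for LearningRateScheduler: keep the initial rate for
# the first three epochs, then decay it exponentially each epoch.
def scheduler(epoch, lr):
    if epoch < 3:
        return lr
    return lr * math.exp(-0.1)
```

It would be passed as `tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)`.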
This callback is also triggered at the `on_epoch_end` event.
5. CSVLogger
This callback is designed to log training details into a CSV file. The logged parameters include `epoch`, `accuracy`, `loss`, `val_accuracy`, and `val_loss`. It’s essential to include `accuracy` as a metric when compiling the model; otherwise, an execution error may occur.
```python
tf.keras.callbacks.CSVLogger(filename, separator=',',
                             append=False)
```
The logger accepts the parameters `filename`, `separator`, and `append`. The `append` option defines whether to append to an existing file or overwrite it with a new one.
The `CSVLogger` callback is executed at each `on_epoch_end` event; when an epoch ends, its logs are written to the file.
6. LambdaCallback
This callback is handy when you need to trigger custom functions for various events and the built-in callbacks are not sufficient. For instance, it can be used to save logs into a database.
```python
tf.keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None,
                                  on_batch_begin=None,
                                  on_batch_end=None,
                                  on_train_begin=None,
                                  on_train_end=None,
                                  **kwargs)
```
Each parameter of this callback expects a function that receives the corresponding arguments:
- `on_epoch_begin` and `on_epoch_end`: `epoch`, `logs`
- `on_batch_begin` and `on_batch_end`: `batch`, `logs`
- `on_train_begin` and `on_train_end`: `logs`
A common use is a function passed to `on_batch_end` that records the batch logs to a file after each batch finishes. This callback can hook into any of the events listed above, executing whichever custom functions you supply.
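As an illustration, a function of this kind might look as follows; the function name and the log file path are hypothetical choices for this example.

```python
import json

LOG_PATH = "batch_logs.txt"  # hypothetical output file for this example

# Function to supply to LambdaCallback's on_batch_end parameter: append
# the batch index and the batch logs to a file, one JSON line per batch.
def write_batch_logs(batch, logs):
    with open(LOG_PATH, "a") as f:
        f.write(json.dumps({"batch": batch, **(logs or {})}) + "\n")
```

It would be wired up as `tf.keras.callbacks.LambdaCallback(on_batch_end=write_batch_logs)`.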
7. ReduceLROnPlateau
This callback adjusts the learning rate when the monitored metric has stopped improving. Unlike the `LearningRateScheduler`, it reduces the learning rate based on the metric rather than the epoch number.
```python
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                     patience=10,
                                     verbose=0,
                                     mode='auto',
                                     min_delta=0.0001,
                                     cooldown=0,
                                     min_lr=0,
                                     **kwargs)
```
Several parameters resemble those of the `EarlyStopping` callback:
- `monitor`, `patience`, `verbose`, `mode`, `min_delta`: same meaning as in `EarlyStopping`.
- `factor`: the factor by which the learning rate is reduced (new learning rate = old learning rate * factor).
- `cooldown`: the number of epochs to wait after a reduction before monitoring the metric again.
- `min_lr`: the lower bound on the learning rate (it will not be reduced below this value).
This callback is also activated at the `on_epoch_end` event.
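The update the callback performs once the metric plateaus can be sketched in one line; the function name here is hypothetical, for illustration only.

```python
# Sketch of the reduction ReduceLROnPlateau applies once the monitored
# metric plateaus: scale the rate by `factor`, but never below `min_lr`.
def reduced_lr(old_lr, factor=0.1, min_lr=0.0):
    return max(old_lr * factor, min_lr)
```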
8. RemoteMonitor
This callback proves useful when you need to send logs to an API. Its functionality can also be replicated with the `LambdaCallback`.
```python
tf.keras.callbacks.RemoteMonitor(root='http://localhost:9000',
                                 path='/publish/epoch/end/',
                                 field='data',
                                 headers=None,
                                 send_as_json=False)
```
- `root`: the base URL of the target server.
- `path`: the endpoint path relative to the root.
- `field`: the key under which the logs will be sent.
- `headers`: any HTTP headers to send with the request.
- `send_as_json`: if `True`, the data will be sent in JSON format; otherwise, as a form-encoded field.
To observe this callback working, you need an HTTP endpoint listening at the configured root (with the defaults above, `localhost:9000`) that accepts POST requests at the given path; any small server, for example in Node.js or Python, will do. Once training runs, you’ll see the epoch logs arrive at the endpoint after each epoch. If the server isn’t active, a warning message will appear instead.
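A minimal endpoint can be sketched with Python’s standard library alone; this is an illustrative stand-in, assuming the default root `http://localhost:9000` and path `/publish/epoch/end/`, not a production server.

```python
from http.server import BaseHTTPRequestHandler, HTTPServer

# Minimal sketch of an endpoint for RemoteMonitor: print each POSTed
# payload (form-encoded under `field`, or JSON when send_as_json=True)
# and reply with HTTP 200.
class PublishHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length).decode("utf-8", errors="replace")
        print(self.path, body)
        self.send_response(200)
        self.end_headers()

    def log_message(self, *args):
        # Silence the default per-request console logging.
        pass

def run(port=9000):
    """Start the endpoint; blocks until interrupted."""
    HTTPServer(("localhost", port), PublishHandler).serve_forever()
```

Calling `run()` before training starts lets you watch the logs arrive after every epoch.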
This callback is also executed at the `on_epoch_end` event.
9. BaseLogger & History
These callbacks are applied automatically to all Keras models. The `history` object returned by `model.fit` holds a dictionary mapping each metric, such as accuracy and loss, to its value at every epoch. Its `params` property contains the training parameters (`epochs`, `steps`, `verbose`). If a callback alters the learning rate during training, the learning rate is also recorded in the history object.
`BaseLogger` aggregates the averages of your metrics across batches, so the metrics displayed at the end of an epoch are averages over all of its batches.
10. TerminateOnNaN
This callback is designed to halt training if the loss becomes `NaN`.

```python
tf.keras.callbacks.TerminateOnNaN()
```
Conclusion
You can choose from these callbacks as per your requirements. It is often beneficial (or even necessary) to utilize multiple callbacks, such as `TensorBoard` for tracking progress, `EarlyStopping` or `LearningRateScheduler` to avoid overfitting, and `ModelCheckpoint` to save your training progress.
You can execute code for any of the callbacks in `tensorflow.keras`. We hope this information aids you in your model training journey.
Happy Deep Learning!
Thank you for learning with the DigitalOcean Community. Explore our offerings for compute, storage, networking, and managed databases.