Keeping Up with PyTorch Lightning and Hydra — 2nd Edition
How I shrank my training script by 50% using the new features from PyTorch Lightning 1.1 and Hydra 1.0
Short note on the 2nd edition: Back in August 2020, I wrote a story about how I used PyTorch Lightning 0.9.0 and Hydra’s fourth release candidate for 1.0.0 to shrink my training script by 50%. Half a year later in February 2021, we now have PyTorch Lightning 1.1 and Hydra 1.0. No major changes were introduced in Hydra 1.0, but PyTorch Lightning 1.0 did include major changes such as the deprecation of the Result abstraction. So I decided to write this 2nd edition of my original post to “keep up” with PyTorch Lightning and Hydra. You can still read the 1st edition if you’d like, but the 2nd edition should cover the latest changes as of February 2021. Enjoy!
Introduction
PyTorch Lightning 1.1 and Hydra 1.0 were recently released, chock-full of new features and mostly final APIs. I thought it’d be a good time to revisit my side project, Leela Zero PyTorch, to see how these new versions could be integrated into it. In this post, I’ll talk about some of the new features of the two libraries and how they helped Leela Zero PyTorch. I won’t go into the details of Leela Zero PyTorch itself too much here, so if you want more context on my side project, you can read my previous blog post about it here.
PyTorch Lightning 1.1
After months of hard work, the PyTorch Lightning team released 1.0 in October 2020, introducing a number of new features and a final stable API. They then released 1.1 a couple of months later with exciting model parallelism support. We will focus on the final API introduced in 1.0 and dedicate a separate story to model parallelism in the future. Before we jump in, if you want to read more about these releases, check out the official blog posts: 1.0 and 1.1. If you want to learn more about PyTorch Lightning in general, check out the GitHub page as well as the official documentation.
Simplified Logging
Have you found yourself repetitively implementing `*_epoch_end` methods just to aggregate results from your `*_step` methods? Have you found yourself tripped up by how to properly log the metrics calculated in your `*_step` and `*_epoch_end` methods? You’re not alone, and PyTorch Lightning 1.0 introduced a new method, `self.log()`, to solve these very problems.
All you really have to do is call `self.log()` with the metrics you want to log, and it handles the details of logging and aggregation for you, all of which can be customized via keyword arguments. Let’s take a look at how it’s used in my project:
In `training_step()`, I calculate the overall loss and log it with `prog_bar` set to `True` so that it will be displayed in the progress bar. Then I log the mean squared error loss, cross entropy loss and accuracy (calculated using PyTorch Lightning’s new metrics package, which will be discussed shortly) using the convenience method `log_dict()`. By default, `self.log()` logs only at the current step, not at the epoch level, when called in `training_step()` (`on_step=True`, `on_epoch=False`), but we can change this behavior if we want. In `validation_step()` and `test_step()`, `self.log()` behaves the opposite way with regard to aggregation: it logs only at the epoch level (`on_step=False`, `on_epoch=True`). We don’t need to write any code to aggregate metrics at the epoch level, since that’s taken care of automatically. Again, this behavior can be customized via keyword arguments. You can read more about the details of how logging is done in PyTorch Lightning here.
Metrics
Continuing the work started in 0.8, the PyTorch Lightning team introduced even more metric implementations in 1.0. Every metric in PyTorch Lightning is a PyTorch Module and has a functional counterpart, making them extremely easy and flexible to use. The module implementations take care of aggregating metric data across steps, while the functional ones are for simple on-the-fly calculations. For my project, I decided to use the module implementation of accuracy, as I am interested in seeing the accuracy at the epoch level for validation and testing. Let’s take a look at the code:
The first step is to initialize the metrics modules as instance attributes of the LightningModule. Then, it’s just a matter of calling them wherever you have data to calculate the metrics from. However, you’ll notice two distinct ways of calling the metrics modules in my code. The first can be found in `training_step()`, where I log the return value of calling the metrics module. This is a more direct way of logging, where the metric calculated from the given inputs is logged without any automatic aggregation. As a result, only the step-wise accuracy is logged for training, and I could have used the functional implementation for the same end result.
The second way can be found in `validation_step()` and `test_step()`. I call the modules with the data on one line, then log the modules themselves. Here, automatic aggregation happens according to the log settings: the accuracy metrics are aggregated and reported at the epoch level, since that is the default logging behavior for validation and testing. There is no need to aggregate the data and calculate the metrics in separate `*_epoch_end()` methods, saving you lines of code and the headache of dealing with data aggregation yourself.
There are many other metrics implementations included in PyTorch Lightning now, including advanced NLP metrics like the BLEU score. You can read more about it here.
LightningDataModule
Another pain point you may have had with PyTorch Lightning is handling various data sets. Up until 0.9, PyTorch Lightning remained silent on how to organize your data processing code, other than that you use PyTorch’s Dataset and DataLoader. This gave you a lot of freedom, but made it hard to keep data set implementations clean, maintainable and easy to share with others. In 0.9, PyTorch Lightning introduced `LightningDataModule`, a new way of organizing data processing code, and it officially became part of the stable API in 1.0. `LightningDataModule` encapsulates the most common steps of data processing. It has a simple interface with five methods: `prepare_data()`, `setup()`, `train_dataloader()`, `val_dataloader()` and `test_dataloader()`. Let’s go over how each of them is implemented in my project to understand its role.
- `prepare_data()`: This method is for anything that must be done in the main process before subprocesses are forked for distributed training. Tasks such as downloading, preprocessing or saving to disk are good candidates. One thing to be wary of is that any state set here will not be carried over to the subprocesses in distributed training, so don’t rely on state set in this method. In my project, I rely on Leela Zero to preprocess Go SGF files, so I decided to skip implementing this method, though I could technically move that preprocessing step here.
- `setup()`: This method is for anything that must be done in each subprocess for distributed training. You should construct the actual PyTorch Datasets and set any necessary state here. In Leela Zero PyTorch, I initialize my Datasets, which read the data from disk and turn it into tensors, and save them as state.
- `*_dataloader()`: This is where you initialize DataLoaders for training/validation/testing. In my case, I simply use the Datasets constructed in `setup()`, along with the configurations passed in when the `LightningDataModule` was initialized, to initialize the DataLoaders.
Now, it’s just a matter of passing the `LightningDataModule` into `trainer.fit()` and `trainer.test()`. You can also imagine implementing another `LightningDataModule` for a different type of data set, such as chess game data, and the trainer would accept it just the same. I could take this further and use Hydra’s object instantiation pattern to easily switch between various data modules.
Hydra 1.0
Hydra is a “framework for elegantly configuring complex applications.” As you’d probably know already, deep learning training scripts can quickly become complex with lots of knobs and dials. Hydra can help you handle this complexity in an elegant way.
Hydra released its official 1.0 back in September 2020, and is now on its way to its next release, 1.1. Before we jump in, if you want to learn more about Hydra in general, check out the official website as well as the official documentation!
@hydra.main()
You can add this decorator to any function that accepts OmegaConf’s `DictConfig`, and Hydra will automatically handle various aspects of your script. This is not a new feature per se, but one I originally decided not to use, because it takes over the output directory structure as well as the working directory. I had actually used Hydra’s experimental Compose API, which I’ll discuss later, to get around this. However, after talking to Omry, the creator of Hydra, I realized that not only is this not the recommended approach, but I’d also lose a number of cool features Hydra provides, such as automatic handling of the command line interface, automatic help messages and tab completion. Furthermore, after using it for some time, I’ve found Hydra’s output and working directory management quite useful, because I don’t have to manually set up the logging directory structure on PyTorch Lightning’s side. You can read more about this decorator in Hydra’s basic tutorial.
Package Directive
In Hydra 0.11, there was only one global namespace for the configurations, but in 1.0, you can organize your configurations in different namespaces using package directives. This allows you to keep your yaml configuration files flat and clean without unnecessary nesting. Let’s take a look at the network size configuration from Leela Zero PyTorch:
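For illustration, a config group file along these lines (the field names are hypothetical, not the project’s exact ones) can place itself under the `network` package with a single directive at the top:

```yaml
# @package _group_
# conf/network/small.yaml -- "_group_" resolves to this file's config
# group, i.e. "network", so these values land under cfg.network
# without any yaml nesting.
residual_channels: 32
residual_layers: 8
```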
As you can see, package directives make your configuration more manageable. You can read more about package directives and their more advanced use cases here.
Instantiating Objects
Hydra provides a feature where you can instantiate an object or call a function based on configurations. This is extremely useful when you want your script to have a simple interface to switch between various implementations. This is not a new feature either, but its interface has vastly improved in 1.0. In my case, I use it to switch between network sizes, training loggers and datasets. Let’s take the network size configuration as an example.
`NetworkLightningModule` accepts two arguments in its `__init__()`: `network_conf` and `train_conf`. The former is passed in from the configuration, and the latter is passed in as an extra argument to `instantiate()` (`cfg.train`). All you have to do to select a different network size is pass `+network={small,big,huge}` on the command line. You can even imagine selecting a totally different architecture by creating a new config with a different `_target_` and passing in that config name on the command line. No need to pass in all the small details via the command line! You can read more about this pattern here.
Compose API
Although Hydra’s Compose API is not the recommended way to configure an application’s entry point, it is still the recommended approach for unit tests, and I used it to write unit tests for the main training script. Again, this is not a new feature, but Hydra 1.0 does bring a cleaner interface for the Compose API using Python’s context manager (the `with` statement).
You can read more about the Compose API here, and how to use it in unit tests here.
Unused Features: Structured Configs and Variable Interpolation
There are many other features in Hydra 1.0 I didn’t take advantage of, mostly due to the fact that I haven’t had enough time to integrate them. I’ll go over the biggest one in this section — structured configs.
Structured configs are a major new feature introduced in 1.0 that utilize Python’s dataclasses to provide runtime and static type checking, which can be extremely useful as your application grows in complexity. I’ll probably integrate them in the future when I can find time, so please stay tuned for another blog post!
Conclusion
Since I wrote my first blog post about Leela Zero PyTorch, both Hydra and PyTorch Lightning have introduced a number of new features and abstractions that can help you greatly simplify your PyTorch scripts. Take a look at my training script to see how they helped:
As you can see above, my main training script now consists of a mere 28 lines, compared to 56 lines before. Moreover, each part of the training pipeline (the neural network architecture, data set and logger) is modular and easily swappable. This enables faster iteration, easier maintenance and better reproducibility, allowing you to focus on the most fun and important parts of your projects. I hope this blog post has been helpful as you “keep up” with these two awesome libraries! You can find the code for Leela Zero PyTorch here.