Keeping Up with PyTorch Lightning and Hydra — 2nd Edition

How I shrank my training script by 50% using the new features from PyTorch Lightning 1.1 and Hydra 1.0

Peter Yu
Towards Data Science

--

Short note on the 2nd edition: Back in August 2020, I wrote a story about how I used PyTorch Lightning 0.9.0 and the fourth release candidate of Hydra 1.0.0 to shrink my training script by 50%. Half a year later, in February 2021, we now have PyTorch Lightning 1.1 and Hydra 1.0. Hydra 1.0 did not change much from the release candidate I had been using, but PyTorch Lightning 1.0 did include major changes, such as the deprecation of the Result abstraction. So I decided to write this 2nd edition of my original post to “keep up” with PyTorch Lightning and Hydra. You can still read the 1st edition if you’d like, but the 2nd edition covers the latest changes as of February 2021. Enjoy!

Try to keep up! — Source

Introduction

PyTorch Lightning 1.1 and Hydra 1.0 were recently released, chock-full of new features and mostly final APIs. I thought it would be a good time to revisit my side project, Leela Zero PyTorch, to see how these new versions could be integrated into it. In this post, I’ll talk about some of the new features of the two libraries and how they helped Leela Zero PyTorch. I won’t go into much detail about Leela Zero PyTorch itself, so if you want more context on my side project, you can read my previous blog post about it here.

PyTorch Lightning 1.1

After months of hard work, the PyTorch Lightning team released 1.0 in October 2020, introducing a number of new features and a final, stable API. They then released 1.1 a couple of months later with exciting model parallelism support. We will focus on the final API introduced in 1.0 and dedicate a separate story to model parallelism in the future. Before we jump in, if you want to read more about these releases, check out the official blog posts: 1.0 and 1.1. If you want to learn more about PyTorch Lightning in general, check out the GitHub page as well as the official documentation.

Simplified Logging

Have you found yourself repetitively implementing *_epoch_end methods just so that you can aggregate results from your *_step methods? Have you found yourself getting tripped up by how to properly log the metrics calculated in your *_step and *_epoch_end methods? You’re not alone, and PyTorch Lightning 1.0 has introduced a new method, self.log(), to solve these very problems.

All you really have to do is call self.log() with the metrics you want to log, and it will handle the details of logging and aggregation for you, which can be customized via keyword arguments. Let’s take a look at how it’s used in my project:
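To make the pattern concrete, here is a minimal sketch of the kind of logging described below, loosely modeled on my training and validation steps. The loss terms and metric names are illustrative, and the forward pass and optimizer setup are omitted.

```python
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        planes, move_target, value_target = batch
        policy_logits, value = self(planes)
        ce_loss = F.cross_entropy(policy_logits, move_target)
        mse_loss = F.mse_loss(value.squeeze(), value_target)
        loss = ce_loss + mse_loss
        # inside training_step(), self.log() logs at every step by default
        self.log("train_loss", loss, prog_bar=True)
        self.log_dict({"train_ce_loss": ce_loss, "train_mse_loss": mse_loss})
        return loss

    def validation_step(self, batch, batch_idx):
        planes, move_target, value_target = batch
        policy_logits, value = self(planes)
        loss = F.cross_entropy(policy_logits, move_target) + F.mse_loss(
            value.squeeze(), value_target
        )
        # here self.log() aggregates and logs at the epoch level by default
        self.log("val_loss", loss)
```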

The new, simplified logging interface helps you not repeat yourself in metrics logging.

In training_step(), I calculate the overall loss and log it with prog_bar set to True so that it is displayed in the progress bar. Then I log the mean squared error loss, the cross entropy loss and the accuracy (calculated using PyTorch Lightning’s new metrics package, discussed shortly) using the convenience method log_dict(). When called in training_step(), self.log() logs only at the current step by default, not at the epoch level (on_step=True, on_epoch=False), but we can change this behavior if we want. In validation_step() and test_step(), self.log() behaves in the opposite way with regard to aggregation: it logs only at the epoch level (on_step=False, on_epoch=True). We don’t need to write any code to aggregate the metrics at the epoch level, since that is taken care of automatically. Again, this behavior can be customized via keyword arguments. You can read more about the details of how logging works in PyTorch Lightning here.

Metrics

Continuing their work in 0.8, the PyTorch Lightning team introduced even more metric implementations in 1.0. Every metric in PyTorch Lightning is implemented as a PyTorch Module and has a functional counterpart, making them extremely easy and flexible to use. The module implementations take care of aggregating metrics data across steps, while the functional ones are for simple on-the-fly calculations. For my project, I decided to use the module implementation of accuracy, as I am interested in seeing the accuracy at the epoch level for validation and testing. Let’s take a look at the code:
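Here is a sketch of the two invocation styles discussed below, using the Accuracy module that ships with PyTorch Lightning 1.1 (in later versions these metrics live in the separate torchmetrics package). The attribute and metric names are illustrative, and the forward pass is omitted.

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # separate metric objects so train/val state never gets mixed
        self.train_accuracy = pl.metrics.Accuracy()
        self.val_accuracy = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        planes, move_target = batch
        policy_logits = self(planes)
        # way 1: log the value returned by the metric call;
        # only the step-wise accuracy is logged, with no aggregation
        self.log(
            "train_acc",
            self.train_accuracy(policy_logits.softmax(dim=-1), move_target),
        )

    def validation_step(self, batch, batch_idx):
        planes, move_target = batch
        policy_logits = self(planes)
        # way 2: update the metric, then log the module itself so that
        # Lightning aggregates it and reports the epoch-level accuracy
        self.val_accuracy(policy_logits.softmax(dim=-1), move_target)
        self.log("val_acc", self.val_accuracy)
```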

Accuracy modules in action. Notice the two different ways of invoking them.

The first step is to initialize separate metrics modules as instance attributes of the LightningModule. Then it’s just a matter of calling them wherever you have data to calculate the metrics you want. However, you’ll notice two distinct ways of calling the metrics modules in my code. The first can be found in training_step(), where I log the return value of calling the metrics module. This is a more direct way of logging, where the metric calculated from the given inputs is logged without any automatic aggregation. As a result, only the step-wise accuracy is logged for training, and I could have used the functional implementation for the same end result.

The second way can be found in validation_step() and test_step(). I call the modules with the data on one line, then log the modules themselves. Here, automatic aggregation happens according to the log settings, i.e. the accuracy is aggregated and reported at the epoch level, since that’s the default logging behavior for validation and testing. There is no need to aggregate the data and calculate the metrics in separate *_epoch_end() methods, saving you lines of code and the headache of dealing with data aggregation yourself.

There are many other metric implementations included in PyTorch Lightning now, including advanced NLP metrics like the BLEU score. You can read more about them here.

LightningDataModule

Another pain point you may have had with PyTorch Lightning is handling various data sets. Up until 0.9, PyTorch Lightning had remained silent on how to organize your data processing code, beyond expecting you to use PyTorch’s Dataset and DataLoader. This certainly gave you a lot of freedom, but it made it hard to keep your data set implementation clean, maintainable and easily shareable with others. In 0.9, PyTorch Lightning introduced a new way of organizing data processing code, LightningDataModule, which officially became part of the stable API in 1.0. LightningDataModule encapsulates the most common steps of data processing. It has a simple interface with five methods: prepare_data(), setup(), train_dataloader(), val_dataloader() and test_dataloader(). Let’s go over how each of them is implemented in my project to understand its role.
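The gist of it is captured by the following condensed sketch. GoDataset, the file layout and the constructor arguments are illustrative stand-ins rather than the exact code from my project.

```python
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class GoDataset(Dataset):
    """Hypothetical Dataset that reads preprocessed Go positions from disk."""

    def __init__(self, path: str):
        self.path = path


class GoDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self):
        # one-time work in the main process (download/preprocess); set no state here
        pass

    def setup(self, stage=None):
        # runs in every process: build the actual Datasets and keep them as state
        self.train_set = GoDataset(f"{self.data_dir}/train")
        self.val_set = GoDataset(f"{self.data_dir}/val")
        self.test_set = GoDataset(f"{self.data_dir}/test")

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```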

LightningDataModule helps you organize your data processing code.
  • prepare_data(): This method is for anything that must be done in the main process before forking subprocesses for distributed training. Tasks such as downloading, preprocessing or saving to disk are good candidates. One thing to be wary of is that any state set here will not be carried over to the subprocesses in distributed training, so avoid setting state in this method. In my project, I rely on Leela Zero to preprocess the Go SGF files, so I decided to skip implementing this method, though I could technically move that preprocessing step here.
  • setup(): This method is for anything that must be done in each subprocess for distributed training. You should construct the actual PyTorch Datasets and set any necessary state here. In Leela Zero PyTorch, I initialize my Datasets, which read the data from disk and turn it into tensors, and save them as state.
  • *_dataloader(): This is where you initialize DataLoaders for training/validation/testing. In my case, I simply use the Datasets that were constructed in setup() as well as the configurations passed in during the initialization of the LightningDataModule to initialize the DataLoaders.

Now, it’s just a matter of passing the LightningDataModule into trainer.fit() and trainer.test(). You can also imagine a scenario where I implement another LightningDataModule for a different type of data set such as chess game data, and the trainer will accept it just the same. I can take it further and use Hydra’s object instantiation pattern and easily switch between various data modules.
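Concretely, the hand-off looks something like this sketch, where LitModel and GoDataModule are the hypothetical classes from the sketches above:

```python
import pytorch_lightning as pl

model = LitModel()                                           # any LightningModule
data_module = GoDataModule(data_dir="data", batch_size=32)   # sketched above

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=data_module)   # uses the train/val dataloaders
trainer.test(datamodule=data_module)         # uses the test dataloader
```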

Hydra 1.0

Hydra is a “framework for elegantly configuring complex applications.” As you probably already know, deep learning training scripts can quickly become complex, with lots of knobs and dials, and Hydra can help you handle this complexity in an elegant way.

Hydra released its official 1.0 back in September 2020 and is now on its way to the next release, 1.1. Before we jump in, if you want to learn more about Hydra in general, check out the official website as well as the official documentation!

@hydra.main()

You can add this decorator to any function that accepts OmegaConf’s DictConfig, and Hydra will automatically handle various aspects of your script. This is not a new feature per se, but one I originally decided not to use because it takes over the output directory structure as well as the working directory. Instead, I used Hydra’s experimental Compose API, which I will discuss later, to get around this. However, after talking to Omry, the creator of Hydra, I realized that not only is this not the recommended approach, but I also lose a number of cool features Hydra provides, such as automatic handling of the command line interface, automatic help messages and tab completion. Furthermore, after using it for some time, I’ve found Hydra’s output directory and working directory management quite useful, because I no longer have to manually set up the logging directory structure on PyTorch Lightning’s side. You can read more about this decorator in Hydra’s basic tutorial.
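A minimal sketch of the pattern; the conf/config.yaml layout and the entry-point name are assumptions for illustration.

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # by the time we get here, Hydra has parsed the command line, composed the
    # config, and switched the working directory to a per-run output directory
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()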

Package Directive

In Hydra 0.11, there was only one global namespace for the configurations, but in 1.0, you can organize your configurations in different namespaces using package directives. This allows you to keep your yaml configuration files flat and clean without unnecessary nesting. Let’s take a look at the network size configuration from Leela Zero PyTorch:
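Here is an illustrative reconstruction of such a file; the field names and values are stand-ins, and the package directive on the first line is the part that matters.

```yaml
# conf/network/small.yaml (illustrative)
# @package _group_
residual_channels: 32
residual_layers: 8
```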

“@package _group_” indicates this configuration should be under the current group, which in this case is “network”.
The network size configuration has been added under “network” as specified. Please note that “board_size” and “in_channels” come from the data configuration (composition!)

As you can see, package directives make your configuration more manageable. You can read more about package directives and their more advanced use cases here.

Instantiating Objects

Hydra provides a feature that lets you instantiate an object or call a function based on your configuration. This is extremely useful when you want your script to have a simple interface for switching between various implementations. This is not a new feature either, but its interface was vastly improved in 1.0. In my case, I use it to switch between network sizes, training loggers and data sets. Let’s take the network size configuration as an example.
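As a hedged sketch, the config files for two of the sizes might look like the following; the module path in _target_, the nested network_conf layout and the size values are all assumptions for illustration, not copied from the repository.

```yaml
# conf/network/small.yaml (illustrative)
# @package _group_
_target_: leela_zero_pytorch.network.NetworkLightningModule
network_conf:
  residual_channels: 32
  residual_layers: 8

# conf/network/big.yaml (illustrative)
# @package _group_
_target_: leela_zero_pytorch.network.NetworkLightningModule
network_conf:
  residual_channels: 128
  residual_layers: 19
```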

Configuration for the “big”, “huge” and “small” networks
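On the script side, the call might look like the following sketch, which passes cfg.train through as an extra keyword argument so that it lands on the train_conf parameter described below; the config path and names are illustrative.

```python
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # cfg.network is whichever config +network={small,big,huge} selected on the
    # command line; instantiate() builds the class named by its _target_ field,
    # forwarding the config's fields plus any extra arguments we pass ourselves
    model = instantiate(cfg.network, train_conf=cfg.train)
    ...


if __name__ == "__main__":
    main()
```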
Instantiate the network based on the selected configuration. Notice that you can pass in additional arguments to instantiate() as I did with cfg.train here.

NetworkLightningModule accepts two arguments for its __init__(), network_conf and train_conf. The former is passed in from the configuration, and the latter is passed in as an extra argument in instantiate() (cfg.train). All you have to do to select different network sizes is to pass in +network={small,big,huge} in the command line. You can even imagine selecting a totally different architecture by creating a new config with a different _target_, and passing in the config name in the command line. No need to pass in all the small details via the command line! You can read more about this pattern here.

Compose API

Although Hydra’s Compose API is not the recommended way to configure application scripts, it is still the recommended approach for writing unit tests, and a very useful one. I used it to write unit tests for my main training script. Again, this is not a new feature, but Hydra 1.0 does bring a cleaner interface for the Compose API using Python’s context manager (the with statement).
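A minimal sketch of such a test; note that in Hydra 1.0 the Compose API lives under hydra.experimental, and the config path and override below are assumptions for illustration.

```python
from hydra.experimental import compose, initialize


def test_compose_config():
    # initialize() points Hydra at the config directory (relative to this file);
    # compose() then builds the config the same way @hydra.main() would,
    # including any command-line style overrides
    with initialize(config_path="../conf"):
        cfg = compose(config_name="config", overrides=["+network=small"])
    assert "network" in cfg
    assert "train" in cfg
```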

You can easily compose a configuration dictionary using Hydra’s Compose API. It helps you keep your unit tests clean and easily debuggable.

You can read more about the Compose API here, and about how to use it in unit tests here.

Unused Features: Structured Configs and Variable Interpolation

There are many other features in Hydra 1.0 that I didn’t take advantage of, mostly because I haven’t had time to integrate them. I’ll go over the biggest one in this section: structured configs.

Structured configs are a major new feature introduced in 1.0 that utilize Python’s dataclasses to provide runtime and static type checking, which can be extremely useful as your application grows in complexity. I’ll probably integrate them in the future when I can find time, so please stay tuned for another blog post!
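For a flavor of what that looks like, here is a tiny hypothetical example (none of these names come from my project): a dataclass is registered with Hydra’s ConfigStore and can then be selected and validated like any yaml config.

```python
from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class NetworkSizeConfig:
    residual_channels: int = 32
    residual_layers: int = 8


cs = ConfigStore.instance()
# register the dataclass under the "network" group with the name "small";
# Hydra will now reject overrides whose types don't match the annotations
cs.store(group="network", name="small", node=NetworkSizeConfig)
```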

Conclusion

Since I wrote my first blog post about Leela Zero PyTorch, both Hydra and PyTorch Lightning have introduced a number of new features and abstractions that can help you greatly simplify your PyTorch scripts. Take a look at my training script to see how they helped:

My new and old training scripts. The line count went from 56 to 28. 50% reduction!

As you can see above, my main training script now consists of a mere 28 lines, compared to 56 lines before. Moreover, each part of the training pipeline (the neural network architecture, the data set and the logger) is modular and easily swappable. This enables faster iteration, easier maintenance and better reproducibility, allowing you to focus on the most fun and important parts of your projects. I hope this blog post has been helpful as you “keep up” with these two awesome libraries! You can find the code for Leela Zero PyTorch here.

--

PhD Student at UMich Researching NLP and Cognitive Architectures • Previously Real-time Distributed System Engineer turned NLP Research Engineer at ASAPP