Learn PyTorch Lightning: A Lightweight PyTorch Wrapper¶
PyTorch Lightning is a lightweight PyTorch wrapper that helps organize your PyTorch code, making it more readable and maintainable. A guide on converting existing PyTorch code is available at https://lightning.ai/docs/pytorch/stable/starter/converting.html. Lightning abstracts away much of the boilerplate, letting you focus on the core logic of your models. This tutorial walks you through the basics of PyTorch Lightning.
1. Introduction to PyTorch Lightning¶
PyTorch Lightning separates the research code from the engineering code, helping you write scalable and more readable code. It automates most of the training loop and other common functionalities, making it easier to replicate results and scale your projects.
2. Installing PyTorch Lightning¶
Before you begin, you need to have PyTorch installed. Then, install PyTorch Lightning via pip:
pip install pytorch-lightning
3. Creating a Lightning Module¶
A LightningModule is where you define your model, just like a standard PyTorch nn.Module, but you also define the training step, validation step, and so on. Here's a simple example:
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    # Uncomment to add a validation step
    # def validation_step(self, batch, batch_idx):
    #     x, y = batch
    #     logits = self(x)
    #     loss = F.cross_entropy(logits, y)
    #     # Log the validation loss
    #     self.log('val_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

model = LitModel()
4. Data Preparation¶
PyTorch Lightning works with the standard PyTorch DataLoader. Let's load the MNIST dataset as an example:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Define data loaders
train_loader = DataLoader(
    datasets.MNIST("", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=32,
    shuffle=True,
)
# Check data
x, y = next(iter(train_loader))
print(x.shape, y.shape)
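Note that Lightning accepts any PyTorch DataLoader, not just torchvision datasets. If you want to experiment without downloading MNIST, a synthetic TensorDataset with MNIST-shaped batches works the same way (the fake_images and fake_loader names here are just for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 256 random "images" with MNIST's shape (1x28x28) and labels in 0-9
fake_images = torch.randn(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
fake_loader = DataLoader(
    TensorDataset(fake_images, fake_labels), batch_size=32, shuffle=True
)

x, y = next(iter(fake_loader))
print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])
```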
5. Training the Model¶
Training a model with PyTorch Lightning is straightforward: initialize a Trainer and call its fit method.
# Initialize our model
model = LitModel()
print(model)
# Initialize a trainer
trainer = pl.Trainer(max_epochs=3)
# Train the model
trainer.fit(model, train_loader)
6. Validation and Testing¶
You can easily add validation and test steps to your LitModel. For validation, implement the validation_step method:
class LitModel(pl.LightningModule):
    # ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Log the validation loss
        self.log('val_loss', loss)
Use a similar approach for the test_step.
7. Logging and Callbacks¶
PyTorch Lightning comes with built-in support for logging and callbacks. TensorBoard logging is used by default, and other loggers such as MLflow and Comet are also supported.
from pytorch_lightning.callbacks import ModelCheckpoint
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="./my_model",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)
trainer = pl.Trainer(max_epochs=3, callbacks=[checkpoint_callback])
8. Advanced Features¶
PyTorch Lightning also supports distributed training, mixed-precision training, and more. These features are enabled through Trainer arguments (provided the required hardware and drivers are available). For example, the following starts training on two GPUs with mixed precision:
trainer = pl.Trainer(accelerator="gpu", devices=2, precision="16-mixed")
If you don't know how many GPUs are available, set devices=-1 (or devices="auto") and PyTorch Lightning will use all of them. (In Lightning versions before 2.0, the equivalent arguments were gpus=2 and precision=16.)
For more options on the Trainer, check out the documentation.
Conclusion¶
PyTorch Lightning is a powerful tool for organizing PyTorch code and making it more efficient and maintainable. It abstracts away the engineering details, allowing you to focus on the research part. This tutorial covered the basics, but there's a lot more to explore, including advanced features like distributed training, and integrations with other tools and libraries. Be sure to check out the official documentation for more information.