Accelerate your PyTorch model development with ZenML
Seamlessly integrate PyTorch, a powerful deep learning framework, with ZenML to streamline your model development and experimentation process. By leveraging ZenML's model-agnostic pipelines and PyTorch's flexibility, you can rapidly iterate on models, track experiments, and deploy production-ready solutions with ease.
Features with ZenML
- Seamless PyTorch Integration: Effortlessly incorporate PyTorch models and training logic into ZenML pipelines for a unified workflow.
- Reproducible Experiments: Track and version PyTorch data objects and models using ZenML, ensuring reproducibility and facilitating collaboration.
- Effortless Handling of PyTorch Data Artifacts and Models: ZenML knows how to serialize PyTorch artifacts like `DataLoader` and `Module` and allows you to use them across steps in different environments.
- Streamlined Deployment: Seamlessly transition PyTorch models from experimentation to production using ZenML's deployment integrations.
Main Features
- Flexible and expressive deep learning framework.
- Extensive ecosystem of pre-trained models and extensions.
- Optimizers, loss functions and other pre-defined helper classes to use out of the box.
- Strong community support and comprehensive documentation.
- Interoperability with popular data science tools and libraries.
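Optimizers and loss functions ship with PyTorch and can be used without writing any custom code. Here is a minimal sketch in plain PyTorch. It assumes a toy regression problem (fitting y = 2x + 1 on synthetic data) that is not part of the original example:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic, noise-free data for the illustrative target y = 2x + 1
x = torch.rand(256, 1) * 2 - 1
y = 2 * x + 1

# Built-in layer, loss function and optimizer, used as-is
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Full-batch gradient descent
for epoch in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

weight = model.weight.item()
bias = model.bias.item()
print(f"weight={weight:.3f}, bias={bias:.3f}")
```

Swapping in `torch.optim.Adam` or another loss class only changes the lines where they are constructed. The training loop stays the same.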
How to use ZenML with PyTorch
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from zenml import pipeline, step

# Train on the GPU when one is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class NeuralNetwork(nn.Module):
    """Minimal stand-in classifier for 28x28 Fashion-MNIST images (illustrative)."""

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(self.flatten(x))


@step(enable_cache=False)
def trainer(train_dataloader: DataLoader) -> nn.Module:
    """Trains on the train dataloader."""
    model = NeuralNetwork().to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    size = len(train_dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(DEVICE), y.to(DEVICE)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log progress every 100 batches
        if batch % 100 == 0:
            current = batch * len(X)
            print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")
    return model


@pipeline()
def fashion_mnist_pipeline():
    """Link all the steps and artifacts together."""
    # importer_mnist and evaluator are further @step functions not shown here
    train_dataloader, test_dataloader = importer_mnist()
    model = trainer(train_dataloader)
    evaluator(test_dataloader=test_dataloader, model=model)
```
This code example demonstrates a simple ZenML pipeline that incorporates a PyTorch model training step.
The data loader step (`importer_mnist`) returns a PyTorch `DataLoader` object that is serialized by ZenML and made available to the `trainer` step.
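The body of `importer_mnist` is not part of the example. The following minimal sketch shows the kind of object such a step hands to `trainer`. It assumes random tensors shaped like Fashion-MNIST (1x28x28 images, 10 classes) in place of the real dataset, and `make_dataloaders` is a hypothetical name. In a pipeline the function would carry the `@step` decorator. That decorator is left off here so the sketch runs with PyTorch alone.

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataloaders(
    n_train: int = 512, n_test: int = 128, batch_size: int = 64
) -> Tuple[DataLoader, DataLoader]:
    """Hypothetical stand-in for importer_mnist using random image-shaped tensors."""
    generator = torch.Generator().manual_seed(0)

    def random_split(n: int) -> TensorDataset:
        # Random "images" in [0, 1) and integer labels 0-9
        images = torch.rand(n, 1, 28, 28, generator=generator)
        labels = torch.randint(0, 10, (n,), generator=generator)
        return TensorDataset(images, labels)

    train_loader = DataLoader(random_split(n_train), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(random_split(n_test), batch_size=batch_size)
    return train_loader, test_loader


train_loader, test_loader = make_dataloaders()
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)
```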
ZenML automatically tracks and versions your `DataLoader` and `Module` objects on every pipeline run. This helps you establish a lineage and makes reproducing runs easier.
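Versioning a `Module` means persisting its state so that a later step or run can restore it. ZenML's PyTorch integration handles this for you, and its exact storage format is not described here. The sketch below only illustrates the underlying PyTorch mechanism. It assumes a `state_dict` round trip through an in-memory buffer, and `roundtrip_state_dict` is a hypothetical helper.

```python
import io

import torch
from torch import nn


def roundtrip_state_dict(model: nn.Module, fresh: nn.Module) -> nn.Module:
    """Hypothetical helper: save model's weights to a buffer and load them into `fresh`."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)  # rewind before reading back
    fresh.load_state_dict(torch.load(buffer, weights_only=True))
    return fresh


torch.manual_seed(0)
original = nn.Linear(4, 2)
restored = roundtrip_state_dict(original, nn.Linear(4, 2))

x = torch.randn(3, 4)
same = torch.allclose(original(x), restored(x))
print(same)
```

Saving the `state_dict` instead of pickling the whole module means only tensors are stored. The receiving side must therefore be able to construct the same architecture before it loads the weights.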
The `trainer` step builds a PyTorch neural network module and trains it on the batches provided by the `DataLoader`. The loss function (`nn.CrossEntropyLoss`) and the optimizer (`torch.optim.SGD`) come directly from PyTorch. ZenML's PyTorch integration knows how to serialize the returned `Module` and can load it from your ZenML artifact store in later steps, such as `evaluator`.
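The `evaluator` step from the pipeline is not shown in the example. Here is a minimal sketch of what its body might do with the loaded `Module`. It assumes classification accuracy as the metric and CPU execution. `evaluate` is a hypothetical name, and the `@step` decorator is omitted so the code runs with PyTorch alone.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model: nn.Module, test_dataloader: DataLoader) -> float:
    """Hypothetical evaluator body: fraction of test samples classified correctly."""
    model.eval()  # disable training-only behaviour such as dropout
    correct = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for X, y in test_dataloader:
            pred = model(X)
            correct += (pred.argmax(dim=1) == y).sum().item()
    return correct / len(test_dataloader.dataset)


# Usage: a model rigged to always predict class 0
model = nn.Linear(4, 3)
with torch.no_grad():
    model.weight.zero_()
    model.bias.copy_(torch.tensor([1.0, 0.0, 0.0]))

loader = DataLoader(
    TensorDataset(torch.randn(4, 4), torch.tensor([0, 0, 1, 2])), batch_size=2
)
accuracy = evaluate(model, loader)
print(accuracy)  # 0.5: two of the four labels are class 0
```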
Additional Resources
LLM LitGPT finetuning project that uses PyTorch
ZenML PyTorch Integration Code Docs