Accelerate your PyTorch model development with ZenML
Category: Modeling

Seamlessly integrate PyTorch, a powerful deep learning framework, with ZenML to streamline your model development and experimentation process. By leveraging ZenML's model-agnostic pipelines and PyTorch's flexibility, you can rapidly iterate on models, track experiments, and deploy production-ready solutions with ease.

Features with ZenML

  • Seamless PyTorch Integration:
    Effortlessly incorporate PyTorch models and training logic into ZenML pipelines for a unified workflow.
  • Reproducible Experiments:
    Track and version PyTorch data objects and models using ZenML, ensuring reproducibility and facilitating collaboration.
  • Effortless Handling of PyTorch Data Artifacts and Models:
    ZenML knows how to serialize PyTorch artifacts like DataLoader and Module and allows you to use them across steps in different environments.
  • Streamlined Deployment:
    Seamlessly transition PyTorch models from experimentation to production using ZenML's deployment integrations.
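
To make "serializing a Module so it can be used in another step" concrete, here is a minimal sketch in plain PyTorch. It only illustrates the general idea of writing weights to storage and restoring them elsewhere. It is not ZenML's materializer code, and the helper name roundtrip_state_dict and the file name model_weights.pt are made up for this example.

```python
import os
import tempfile

import torch
from torch import nn


def roundtrip_state_dict(model: nn.Module, fresh: nn.Module) -> nn.Module:
    """Save the model's weights to disk and load them into a fresh instance."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Hypothetical file name; an artifact store would pick its own path.
        path = os.path.join(tmp_dir, "model_weights.pt")
        torch.save(model.state_dict(), path)
        fresh.load_state_dict(torch.load(path))
    return fresh


torch.manual_seed(0)
original = nn.Linear(4, 2)
# A second, independently initialized layer stands in for "another step".
restored = roundtrip_state_dict(original, nn.Linear(4, 2))

x = torch.randn(3, 4)
same_output = torch.equal(original(x), restored(x))
```

With ZenML the equivalent save and load happens behind the scenes whenever a step returns a Module and a later step accepts it as input, so this code is not something you would normally write yourself.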

Main Features

  • Flexible and expressive deep learning framework.
  • Extensive ecosystem of pre-trained models and extensions.
  • Optimizers, loss functions and other pre-defined helper classes to use out of the box.
  • Strong community support and comprehensive documentation.
  • Interoperability with popular data science tools and libraries.

How to use ZenML with PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader
from zenml import pipeline, step

# Train on a GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@step(enable_cache=False)
def trainer(train_dataloader: DataLoader) -> nn.Module:
    """Trains on the train dataloader."""
    model = NeuralNetwork().to(DEVICE)  # NeuralNetwork extends nn.Module
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    size = len(train_dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(DEVICE), y.to(DEVICE)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    return model

# importer_mnist and evaluator are ZenML steps defined elsewhere (not shown here).
@pipeline()
def fashion_mnist_pipeline():
    """Link all the steps and artifacts together."""
    train_dataloader, test_dataloader = importer_mnist()
    model = trainer(train_dataloader)
    evaluator(test_dataloader=test_dataloader, model=model)

This code example demonstrates a simple ZenML pipeline that incorporates a PyTorch model training step.

The data loader step (importer_mnist) returns two PyTorch DataLoader objects, one for training and one for testing. ZenML serializes them and makes them available to the trainer and evaluator steps.
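
The importer step itself is not shown on this page. As a rough sketch of what such a function could return, the following builds two DataLoader objects over random tensors shaped like 28x28 grayscale images with 10 classes. The function name make_dataloaders, the sizes, and the random data are all assumptions for illustration. A real importer would load the actual dataset and would be decorated with @step.

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataloaders(
    n_train: int = 256, n_test: int = 64, batch_size: int = 32
) -> Tuple[DataLoader, DataLoader]:
    """Build train/test DataLoaders over random 28x28 'images' with 10 classes."""
    generator = torch.Generator().manual_seed(0)

    def random_dataset(n: int) -> TensorDataset:
        # Random stand-in data, not a real dataset.
        images = torch.rand(n, 1, 28, 28, generator=generator)
        labels = torch.randint(0, 10, (n,), generator=generator)
        return TensorDataset(images, labels)

    train = DataLoader(random_dataset(n_train), batch_size=batch_size, shuffle=True)
    test = DataLoader(random_dataset(n_test), batch_size=batch_size)
    return train, test


train_dataloader, test_dataloader = make_dataloaders()
first_images, first_labels = next(iter(train_dataloader))
```

Each batch is a pair of tensors, which matches the (X, y) unpacking used in the trainer loop above.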

ZenML automatically tracks and versions your DataLoader and Module objects on every pipeline run. This helps you establish a lineage and makes reproducing runs easier.

The trainer step instantiates a PyTorch neural network module and trains it on the batches from the training DataLoader. The loss function (CrossEntropyLoss) and the optimizer (SGD) come directly from PyTorch. The ZenML PyTorch integration knows how to serialize the Module class, so the trained model can be loaded in later steps from your ZenML artifact store.
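
The pipeline code references a NeuralNetwork class and an evaluator step without defining them. One possible shape for both, assuming 28x28 single-channel inputs and 10 output classes, is sketched below. The layer sizes and the plain evaluate function are assumptions rather than the definitions used by the original example, and the random data at the end exists only to exercise the code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class NeuralNetwork(nn.Module):
    """Small fully connected classifier for 28x28 single-channel images."""

    def __init__(self) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(self.flatten(x))


def evaluate(test_dataloader: DataLoader, model: nn.Module) -> float:
    """Return the classification accuracy of the model on the test data."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for X, y in test_dataloader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            correct += (model(X).argmax(dim=1) == y).sum().item()
    return correct / len(test_dataloader.dataset)


# Random stand-in data, only to exercise the code path.
torch.manual_seed(0)
data = TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,)))
model = NeuralNetwork().to(DEVICE)
accuracy = evaluate(DataLoader(data, batch_size=16), model)
```

In a ZenML pipeline, evaluate would be wrapped with @step and receive the trained model as an input artifact from the trainer step.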

Additional Resources
LLM LitGPT finetuning project that uses PyTorch
ZenML PyTorch Integration Code Docs
