Revision



Import

A basic PyTorch setup imports torch and torch.nn; the latter contains most of the building blocks used in deep learning (layers, activations, loss functions).

import torch
from torch import nn



Model Creation

A PyTorch model is created as a class that inherits from nn.Module:

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

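The forward method must be defined; it describes how an input is transformed into an output.
In this example from the PyTorch documentation, forward uses the attribute self.linear_relu_stack, which is created as an nn.Sequential (a container that applies its layers in order).

The backward pass does not need to be written: autograd derives it from the operations performed in forward and the PyTorch components used.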
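As a minimal sketch of how such a class is used, the snippet below redefines the same model, feeds it a random batch, and checks that calling backward on a scalar fills in the gradients of the parameters. The batch size of 3, the random input, and the use of logits.sum() as a stand-in loss are illustrative assumptions, not part of the original example.

```python
import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))


model = NeuralNetwork()

# Hypothetical batch of 3 random 28x28 "images"
X = torch.rand(3, 28, 28)

# Calling the model runs forward and returns one row of 10 logits per image
logits = model(X)
print(logits.shape)  # torch.Size([3, 10])

# Stand-in scalar loss: autograd walks back through the operations of forward
logits.sum().backward()

# Every parameter now holds a gradient of the same shape as itself
first_layer = model.linear_relu_stack[0]
print(first_layer.weight.grad.shape)  # torch.Size([512, 784])
```

Note that the model is called as model(X) rather than model.forward(X); calling the module is the usual way to run the forward pass.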



Training

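# Device used for computation (falls back to the CPU when no GPU is available)
device = "cuda" if torch.cuda.is_available() else "cpu"

def train(dataloader, model, loss_fn, optimizer):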
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

Training the model follows these steps:

  1. Set the model to training mode,
  2. Make predictions from the input data with the model (more specifically, with its forward method),
  3. Compute the loss from the predictions, the true labels and the chosen loss function,
  4. Reset the gradients to zero, then compute them by backpropagating the loss through the implicitly defined backward pass of the model,
  5. Update the parameters of the model (its weights) using the computed gradients and the chosen optimizer.

Steps 2 to 5 are repeated for each batch. The inputs and labels must also be transferred to the device used for computation (a GPU, for example), and the model must be on that same device.
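The steps above can be tried end to end with a rough sketch like the following. Everything specific here is an assumption made for illustration: the data is random rather than a real dataset, the model is a single linear layer, and the loss function (CrossEntropyLoss), optimizer (SGD), learning rate and batch size are arbitrary choices. The train function is a condensed variant that returns the per-batch losses instead of printing them.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

# Synthetic stand-in data: 256 random 28x28 "images" with labels in 0..9
X_all = torch.rand(256, 28, 28)
y_all = torch.randint(0, 10, (256,))
dataloader = DataLoader(TensorDataset(X_all, y_all), batch_size=64)

# Deliberately tiny model, moved to the same device as the data
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


def train(dataloader, model, loss_fn, optimizer):
    model.train()                              # step 1: training mode
    losses = []
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)      # move the batch to the device
        loss = loss_fn(model(X), y)            # steps 2 and 3: predict, compute loss
        optimizer.zero_grad()                  # step 4: reset gradients...
        loss.backward()                        # ...then backpropagate
        optimizer.step()                       # step 5: update the weights
        losses.append(loss.item())
    return losses


for epoch in range(2):
    losses = train(dataloader, model, loss_fn, optimizer)
    print(f"epoch {epoch}: mean loss {sum(losses) / len(losses):.4f}")
```

With 256 samples and a batch size of 64, each call to train performs four parameter updates. On random labels the loss is not expected to decrease meaningfully; the point is only to show the calls in order.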



Resources

See: