Train and Inference MNIST

Get Endpoint via Float16

This tutorial guides you through training an MNIST model and running inference with it using Float16's spot mode.

  • Float16 CLI installed

  • Logged into Float16 account

  • VSCode or preferred text editor recommended

Spot mode

Spot mode is cost-effective for interruptible workloads such as training AI models, offline inference, and pre-processing.

Spot mode offers an 80% discount compared with run mode and server mode.

Step 1 : Prepare Your Script

https://github.com/float16-cloud/examples/tree/main/official/spot/torch-train-and-infernce-mnist

(train.py)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

# Report the PyTorch version in use at the start of the run.
print(f"PyTorch version: {torch.__version__}")
# Define the neural network (inference.py must declare the same architecture)
class Net(nn.Module):
    """Small convolutional classifier producing 10 log-probabilities per image.

    Expects input of shape (N, 1, 28, 28): two 3x3 stride-1 convolutions
    (28 -> 26 -> 24) followed by a 2x2 max-pool (24 -> 12) leave
    64 * 12 * 12 = 9216 features, which is the input size of ``fc1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout, applied to the 4-D convolutional feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: this layer receives the flattened 2-D
        # activation, and Dropout2d on 2-D input is deprecated by PyTorch.
        # Dropout layers hold no parameters, so state_dict keys are unchanged.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-probabilities of shape (N, 10) for the image batch ``x``."""
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

# Pick the compute device: CUDA when PyTorch reports it available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")
# Data loading
def load_data(data_path):
    """Build the MNIST train and test loaders from files already on disk.

    Args:
        data_path: Root directory passed to ``datasets.MNIST``. Nothing is
            downloaded (``download=False``), so the files must already exist.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles
        and uses batches of 64; the test loader keeps order and uses
        batches of 1000.
    """
    print(f"load data : {data_path}")

    # Convert to tensors, then normalize with mean 0.1307 and std 0.3081.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    pipeline = transforms.Compose([to_tensor, normalize])

    def _split(is_train):
        # Same root and preprocessing for both splits; only the split flag differs.
        return datasets.MNIST(root=data_path, train=is_train, download=False, transform=pipeline)

    training_set = _split(True)
    evaluation_set = _split(False)

    training_batches = DataLoader(training_set, batch_size=64, shuffle=True)
    evaluation_batches = DataLoader(evaluation_set, batch_size=1000, shuffle=False)
    return training_batches, evaluation_batches

# Initialize the model, loss function, and optimizer
# (these module-level objects are read directly by train()).
model = Net().to(device)
# Net.forward already returns log-probabilities; CrossEntropyLoss applies
# log_softmax again, which leaves log-probabilities unchanged.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training parameters
# num_epochs is the total epoch count: a resumed run continues from the
# checkpointed epoch up to this value rather than running 10 more epochs.
num_epochs = 10
save_interval = 2  # Save checkpoint every 2 epochs

# Checkpoint file path (relative to the working directory); written by
# save_checkpoint() and read by load_checkpoint().
checkpoint_path = 'mnist_checkpoint.pth'

# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer):
    """Atomically write a resumable training checkpoint to ``checkpoint_path``.

    Args:
        epoch: Zero-based index of the epoch that has just finished.
        model: Module whose ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
    """
    # Write to a sibling temp file first and rename afterwards: if the job is
    # interrupted mid-write, the previous checkpoint is left intact instead of
    # being replaced by a truncated file.
    tmp_path = checkpoint_path + '.tmp'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, checkpoint_path)

# Function to load checkpoint
def load_checkpoint(model, optimizer):
    """Restore model and optimizer state from ``checkpoint_path`` if it exists.

    Args:
        model: Module to load the saved ``model_state_dict`` into.
        optimizer: Optimizer to load the saved ``optimizer_state_dict`` into.

    Returns:
        The zero-based epoch index to resume from: one past the epoch stored
        in the checkpoint, or 0 when no checkpoint file is present.
    """
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Starting training from scratch.")
        return 0

    # map_location remaps saved tensors onto the current device, so a
    # checkpoint written on a GPU host can still be resumed on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The checkpoint stores the last finished epoch; continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")
    return start_epoch

def train(model, train_loader, test_loader, num_epochs, save_interval):
    """Train ``model``, resuming from a checkpoint when one is present.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``.
    After every epoch the model is evaluated on ``test_loader``. When training
    finishes, the final weights are written to ``mnist_model.pth``.

    Args:
        model: Network to train (already moved to ``device``).
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Iterable of ``(data, target)`` evaluation batches; its
            ``dataset`` length is used to average the metrics.
        num_epochs: Total number of epochs, including any completed before
            a resume.
        save_interval: A checkpoint is saved every ``save_interval`` epochs.
    """
    # Load checkpoint if it exists
    print("load checkpoint")
    start_epoch = load_checkpoint(model, optimizer)

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            save_checkpoint(epoch, model, optimizer)
            print(f"Checkpoint saved at epoch {epoch+1}")

        # Evaluate on test set
        model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # criterion returns the mean loss of the batch; weight it by
                # the batch size so that dividing by the number of samples
                # below yields a true per-sample average.
                total_loss += criterion(output, target).item() * data.size(0)
                correct += (output.argmax(dim=1) == target).sum().item()

        # Guard against an empty test set to avoid a ZeroDivisionError.
        num_samples = max(len(test_loader.dataset), 1)
        test_loss = total_loss / num_samples
        accuracy = 100. * correct / num_samples
        print(f'Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Save final model
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("Training completed. Final model saved.")

# Dataset root read by the training job.
# NOTE(review): presumably this is where the Step 2 upload (destination
# `datasets`) places the files; download-mnist-datasets.py itself writes to
# "../mnist-datasets", so the two paths are not identical — confirm.
data_path = "../datasets/mnist-datasets"
# Build the loaders, then train (resuming from a checkpoint when one exists).
train_loader, test_loader = load_data(data_path)
train(model, train_loader, test_loader, num_epochs, save_interval)

(inference.py)

# inference.py

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np

# Define the same neural network architecture used for training
class Net(nn.Module):
    """Small convolutional classifier producing 10 log-probabilities per image.

    Expects input of shape (N, 1, 28, 28): two 3x3 stride-1 convolutions
    (28 -> 26 -> 24) followed by a 2x2 max-pool (24 -> 12) leave
    64 * 12 * 12 = 9216 features, which is the input size of ``fc1``.
    Attribute names must stay as they are so saved state_dict keys still load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout, applied to the 4-D convolutional feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: this layer receives the flattened 2-D
        # activation, and Dropout2d on 2-D input is deprecated by PyTorch.
        # Dropout layers hold no parameters, so saved weights load unchanged.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-probabilities of shape (N, 10) for the image batch ``x``."""
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

# Function to preprocess the image
def preprocess_image(image_path):
    """Load an image file and turn it into a single-image batch for the model.

    The image is converted to grayscale, resized to 28x28, converted to a
    tensor, normalized with mean 0.1307 and std 0.3081, and given a leading
    batch dimension.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The transformed image tensor with a batch dimension added at index 0.
    """
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # The context manager releases the file handle even if decoding fails;
    # convert() returns a fully loaded copy that stays valid after closing.
    with Image.open(image_path) as source:
        grayscale = source.convert('L')  # Convert to grayscale
    return transform(grayscale).unsqueeze(0)  # Add batch dimension

# Function to load model and make prediction
def predict_digit(image_path, model_path):
    """Classify one image file with weights loaded from ``model_path``.

    A fresh ``Net`` is built and its weights are loaded on every call.

    Args:
        image_path: Path of the image to classify.
        model_path: Path of the saved ``state_dict`` file.

    Returns:
        The index of the highest-scoring class as a Python number.
    """
    # Run on the GPU when available, otherwise on the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build the network, then load the trained weights into it.
    network = Net().to(device)
    print(f"Loading model from {model_path}")
    weights = torch.load(model_path, map_location=device)
    network.load_state_dict(weights)
    network.eval()  # disable dropout for inference

    batch = preprocess_image(image_path).to(device)

    # No gradients are needed for a forward-only pass.
    with torch.no_grad():
        scores = network(batch)
        predicted = scores.argmax(dim=1, keepdim=True)

    return predicted.item()

# Example usage
model_path = "mnist_model.pth"  # Path to your saved model
image_path = "../datasets/mnist-datasets/test"  # Path to the image you want to classify

# Images to classify in this run
test_images = [
    "../datasets/mnist-datasets/test/image1.jpg",
    "../datasets/mnist-datasets/test/image2.jpg",
    "../datasets/mnist-datasets/test/image3.jpg",
]

# Pair every image path with its predicted digit, preserving input order.
result = [(img, predict_digit(img, model_path)) for img in test_images]
print(f"Results: : {result}")

# Persist one "<path>: <digit>" line per image.
with open("mnist-results.txt", "w") as f:
    f.writelines(f"{img}: {digit}\n" for img, digit in result)

(download-mnist-datasets.py)

import os
from torchvision import datasets, transforms

def download_mnist(data_path):
    """Download the MNIST train and test splits into ``data_path``.

    Args:
        data_path: Directory used as the torchvision dataset root; created
            if it does not exist yet.
    """
    # exist_ok replaces the check-then-create sequence, which could fail if
    # the directory appeared between the check and the makedirs call.
    os.makedirs(data_path, exist_ok=True)

    # Only the files written to disk matter here, so no transform is attached
    # and the dataset objects are not kept.
    for is_train in (True, False):
        datasets.MNIST(root=data_path, train=is_train, download=True)

    print(f"MNIST dataset downloaded and saved to {data_path}")

if __name__ == "__main__":
    # Default download location; change the argument to use another directory.
    # NOTE(review): Step 2 uploads ./mnist-datasets while this writes to the
    # parent directory — confirm the two line up for your working directory.
    download_mnist("../mnist-datasets")

Step 2 : Download and Upload datasets

Download the datasets to your local machine

python3 download-mnist-datasets.py

Upload MNIST datasets to the server

float16 storage upload -f ./mnist-datasets -d datasets

Step 3 : Training the MNIST model

Start training the MNIST model with spot mode.

float16 run train.py --spot

Resulting Files

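  • mnist_checkpoint.pth: training checkpoint (model and optimizer state), saved every 2 epochs and used to resume after an interruption

  • mnist_model.pth: final MNIST model weights, loaded by inference.py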

Step 4 : Inference the MNIST model

float16 run inference.py
Results: : [
('../datasets/mnist-datasets/test/image1.jpg', 5), 
('../datasets/mnist-datasets/test/image2.jpg', 7), 
('../datasets/mnist-datasets/test/image3.jpg', 5)
]

Explore More

Learn how to use Float16 CLI for various use cases in our tutorials.

Happy coding with Float16 Serverless GPU!

Last updated