Train and Inference MNIST

Get Endpoint via Float16


Last updated 1 month ago

This tutorial walks you through training an MNIST model and running inference with it using Float16's spot mode. Before you begin, make sure you have:

  • Float16 CLI installed

  • Logged into Float16 account

  • VSCode or preferred text editor recommended

Spot mode

Spot mode is cost-effective for interruptible workloads such as model training, offline inference, and pre-processing.

Spot mode offers an 80% discount compared with run mode and server mode.
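
To make the discount concrete, here is a minimal sketch of the arithmetic. The hourly rate and job duration are hypothetical placeholders, not Float16 prices, and `spot_cost` is an illustrative helper rather than part of the Float16 CLI.

```python
def spot_cost(on_demand_hourly_rate: float, hours: float, discount: float = 0.80) -> float:
    """Estimate spot-mode cost from an on-demand rate and a fractional discount."""
    if not 0.0 <= discount <= 1.0:
        raise ValueError("discount must be between 0 and 1")
    return on_demand_hourly_rate * hours * (1.0 - discount)


# Hypothetical numbers for illustration only; check the pricing page for real rates.
hypothetical_rate = 1.50  # assumed cost per GPU-hour in run mode
training_hours = 4.0      # assumed job duration

run_mode_cost = hypothetical_rate * training_hours
spot_mode_cost = spot_cost(hypothetical_rate, training_hours)
print(f"run mode: {run_mode_cost:.2f}, spot mode: {spot_mode_cost:.2f}")
```

The saving only holds if the job can tolerate being interrupted, which is why the training script in Step 1 saves and reloads checkpoints.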

Step 1 : Prepare Your Script

(train.py)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

print(f"PyTorch version: {torch.__version__}")
# Define the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)  # applied after fc1 on 2-D activations, where Dropout2d is deprecated
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Data loading
def load_data(data_path):
    print(f"load data : {data_path}")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root=data_path, train=True, download=False, transform=transform)
    test_dataset = datasets.MNIST(root=data_path, train=False, download=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    return train_loader, test_loader

# Initialize the model, loss function, and optimizer
model = Net().to(device)
criterion = nn.NLLLoss()  # the model already returns log-probabilities (log_softmax)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training parameters
num_epochs = 10
save_interval = 2  # Save checkpoint every 2 epochs

# Checkpoint file path
checkpoint_path = 'mnist_checkpoint.pth'

# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)

# Function to load checkpoint
def load_checkpoint(model, optimizer):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)  # load onto whichever device this run uses
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch
    else:
        print("No checkpoint found. Starting training from scratch.")
        return 0

def train(model, train_loader, test_loader, num_epochs, save_interval):
    # Load checkpoint if it exists
    print(f"load checkpoint")
    start_epoch = load_checkpoint(model, optimizer)

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
        
        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            save_checkpoint(epoch, model, optimizer)
            print(f"Checkpoint saved at epoch {epoch+1}")

        # Evaluate on test set
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader)  # criterion returns a per-batch mean, so average over the number of batches
        accuracy = 100. * correct / len(test_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Save final model
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("Training completed. Final model saved.")

data_path = "../datasets/mnist-datasets"  # Must match the location the dataset was uploaded to in Step 2
train_loader, test_loader = load_data(data_path)
train(model, train_loader, test_loader, num_epochs, save_interval)
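
The `9216` input size of `fc1` follows from the layer shapes in `Net`. The sketch below reproduces that arithmetic in plain Python, using the standard output-size formula for unpadded convolution and pooling; `conv_out` is an illustrative helper, not a PyTorch function.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                   # MNIST images are 28x28
side = conv_out(side, kernel=3)             # conv1 (3x3, stride 1): 28 -> 26
side = conv_out(side, kernel=3)             # conv2 (3x3, stride 1): 26 -> 24
side = conv_out(side, kernel=2, stride=2)   # max_pool2d(x, 2): 24 -> 12

channels = 64                               # out_channels of conv2
flattened = channels * side * side          # size of the tensor fed to fc1
print(flattened)
```

If you change a kernel size, add padding, or feed images of a different resolution, recompute this value and update `fc1` to match.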

(inference.py)

# inference.py

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

# Define the same neural network architecture used for training
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)  # applied after fc1 on 2-D activations, where Dropout2d is deprecated
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

# Function to preprocess the image
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    image = Image.open(image_path).convert('L')  # Convert to grayscale
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

# Function to load model and make prediction
def predict_digit(image_path, model_path):
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model
    model = Net().to(device)
    print(f"Loading model from {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Preprocess the image
    image = preprocess_image(image_path).to(device)

    # Make prediction
    with torch.no_grad():
        output = model(image)
        prediction = output.argmax(dim=1, keepdim=True)

    return prediction.item()

# Example usage
model_path = "mnist_model.pth"  # Path to your saved model

# Images to classify; change these paths to the test images you uploaded
test_images = ["../datasets/mnist-datasets/test/image1.jpg", "../datasets/mnist-datasets/test/image2.jpg", "../datasets/mnist-datasets/test/image3.jpg"]
result = []
for img in test_images:
    digit = predict_digit(img, model_path)
    result.append((img, digit))
print(f"Results: : {result}")
with open("mnist-results.txt", "w") as f:
    for img, digit in result:
        f.write(f"{img}: {digit}\n")
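
Instead of hard-coding each image path, the list can be built from the contents of a folder. This is a minimal sketch using only the standard library; `list_test_images` and the extension set are assumptions for illustration, not part of the original script.

```python
from pathlib import Path
from typing import List

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumed set; extend as needed


def list_test_images(folder: str) -> List[str]:
    """Return sorted paths of the image files directly inside `folder`."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Test image folder not found: {folder}")
    return sorted(
        str(p)
        for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
```

In `inference.py`, `test_images = list_test_images("../datasets/mnist-datasets/test")` would then stand in for the literal list.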

(download-mnist-datasets.py)

import os
from torchvision import datasets, transforms

def download_mnist(data_path):
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Download training data
    train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
    
    # Download test data
    test_dataset = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

    print(f"MNIST dataset downloaded and saved to {data_path}")

if __name__ == "__main__":
    data_path = "../mnist-datasets"  # You can change this to your preferred location
    download_mnist(data_path)

Step 2 : Download and Upload datasets

Download the dataset to your local machine.

python3 download-mnist-datasets.py
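
Before uploading, it can help to confirm that the download completed. The sketch below assumes the folder layout used by recent torchvision releases (`<root>/MNIST/raw/` containing four uncompressed IDX files; older releases stored the data differently). `missing_mnist_files` is a hypothetical helper, not part of torchvision or the Float16 CLI.

```python
from pathlib import Path
from typing import List

# File names the torchvision MNIST loader is assumed to read from <root>/MNIST/raw.
EXPECTED_RAW_FILES = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]


def missing_mnist_files(data_path: str) -> List[str]:
    """Return the expected raw MNIST files that are absent under `data_path`."""
    raw_dir = Path(data_path) / "MNIST" / "raw"
    return [name for name in EXPECTED_RAW_FILES if not (raw_dir / name).is_file()]
```

An empty list from `missing_mnist_files("../mnist-datasets")` suggests the folder is ready to upload, since `train.py` loads the data with `download=False` and cannot fetch anything that is missing.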

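Upload the MNIST dataset to the server. The path passed to -f must point at the folder the download script created (../mnist-datasets by default), so adjust one of the two if they differ.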

float16 storage upload -f ./mnist-datasets -d datasets

Step 3 : Training the MNIST model

Start training the MNIST model with spot mode.

float16 run train.py --spot
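
Because an interruptible job may stop while a checkpoint is being written, one common precaution is to write to a temporary file and rename it into place. The following is a sketch of that pattern, not something the original script does; `atomic_save` is a hypothetical helper, and it assumes the job's working directory behaves like a regular local filesystem. The default writer uses `pickle` only so the sketch runs without PyTorch.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path, save_fn=None):
    """Write `obj` to `path` so a partially written file never replaces a good one."""
    if save_fn is None:
        # Stand-in writer so this sketch runs without PyTorch installed.
        def save_fn(o, f):
            pickle.dump(o, f)

    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            save_fn(obj, f)
            f.flush()
            os.fsync(f.fileno())    # push the bytes to disk before renaming
        os.replace(tmp_path, path)  # rename within one directory replaces the old file in one step
    except BaseException:
        # Leave the previous checkpoint untouched and clean up the temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

In `train.py`, `save_checkpoint` could pass its dictionary through `atomic_save(state, checkpoint_path, save_fn=torch.save)`, since `torch.save` accepts an open binary file object.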

Resulting Files

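  • mnist_checkpoint.pth: training checkpoint (model and optimizer state), saved every 2 epochs and used to resume after an interruption

  • mnist_model.pth: final MNIST model weights, saved when training completes and loaded by inference.py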
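
How much work is repeated after an interruption depends on the checkpoint interval. The sketch below mirrors the save and resume logic in `train.py` (a checkpoint after every `save_interval` completed epochs, resuming at the stored epoch plus one). `resume_point` is an illustrative helper, and it assumes the interruption happens before the interrupted epoch's own checkpoint is written.

```python
def resume_point(interrupted_epoch_index: int, save_interval: int) -> int:
    """0-based epoch index that training restarts from after an interruption.

    `interrupted_epoch_index` is the 0-based epoch that was running when the
    job stopped, so that many epochs had fully completed before it.
    """
    completed_epochs = interrupted_epoch_index
    # The last checkpoint covers the largest multiple of save_interval completed.
    return (completed_epochs // save_interval) * save_interval


save_interval = 2
interrupted = 5                      # stopped during the 6th epoch (index 5)
restart = resume_point(interrupted, save_interval)
repeated = interrupted - restart     # completed epochs that must be redone
print(f"restart at epoch index {restart}, repeating {repeated} completed epoch(s)")
```

A smaller `save_interval` reduces repeated work at the cost of more frequent checkpoint writes.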

Step 4 : Run inference with the MNIST model

Make sure you have uploaded the test images to the server and updated the image paths in inference.py before running inference.

float16 run inference.py
Results: : [
('../datasets/mnist-datasets/test/image1.jpg', 5), 
('../datasets/mnist-datasets/test/image2.jpg', 7), 
('../datasets/mnist-datasets/test/image3.jpg', 5)
]
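
`inference.py` also writes each prediction to `mnist-results.txt` as one `path: digit` line. Once that file is on your machine (see the "Copy output from remote" tutorials), it can be read back as in the sketch below. `parse_results` is a hypothetical helper, and the sample lines are taken from the output shown above.

```python
from collections import Counter


def parse_results(lines):
    """Turn 'path: digit' lines into a list of (path, digit) tuples."""
    results = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Split on the last ': ' so colons inside the path are preserved.
        path, digit = line.rsplit(": ", 1)
        results.append((path, int(digit)))
    return results


sample_lines = [
    "../datasets/mnist-datasets/test/image1.jpg: 5\n",
    "../datasets/mnist-datasets/test/image2.jpg: 7\n",
    "../datasets/mnist-datasets/test/image3.jpg: 5\n",
]
parsed = parse_results(sample_lines)
digit_counts = Counter(digit for _, digit in parsed)
print(parsed)
print(digit_counts)
```

With a real file, pass the open file object instead: `with open("mnist-results.txt") as f: parsed = parse_results(f)`.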

Congratulations! You've successfully trained a model in spot mode and run inference with it on Float16's serverless GPU platform.

Explore More

Learn how to use Float16 CLI for various use cases in our tutorials.

Happy coding with Float16 Serverless GPU!

🚀
https://github.com/float16-cloud/examples/tree/main/official/spot/torch-train-and-infernce-mnist

Hello World

Launch your first serverless GPU function and kickstart your journey.

Install new library

Enhance your toolkit by adding new libraries tailored to your project needs.

Copy output from remote

Efficiently transfer computation results from remote to your local storage.

Deploy FastAPI Helloworld

Quick start for deploying FastAPI without changing your code.

Upload and Download via CLI and Website

Upload and download files directly to and from the server.

More examples

Open-source examples from the community and the Float16 team.