Train and Inference MNIST
This tutorial guides you through training an AI model and running inference using Float16's spot mode.
Prerequisites
Float16 CLI installed
Logged in to your Float16 account
VSCode or your preferred text editor (recommended)
Spot mode
Spot mode is cost-effective for interruptible workloads such as model training, offline inference, and data pre-processing.
Spot mode offers an 80% discount compared with run mode and server mode.
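Because a spot job can be interrupted at any time, your script should periodically save its state and be able to resume from it. The sketch below is a minimal illustration of that pattern; the file name checkpoint.pth is just an example, and the train.py in Step 1 implements the same idea with its save_checkpoint and load_checkpoint functions.

import os
import torch

CKPT = "checkpoint.pth"  # example path; train.py below uses mnist_checkpoint.pth

def save_state(epoch, model, optimizer):
    # Persist everything needed to resume: epoch counter, model and optimizer state
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, CKPT)

def resume_state(model, optimizer):
    # If a checkpoint exists, restore it and continue from the next epoch
    if os.path.exists(CKPT):
        ckpt = torch.load(CKPT)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        return ckpt["epoch"] + 1
    return 0  # no checkpoint: start from scratch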
Step 1 : Prepare Your Script
https://github.com/float16-cloud/examples/tree/main/official/spot/torch-train-and-infernce-mnist
(train.py)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

print(f"PyTorch version: {torch.__version__}")

# Define the neural network (same as before)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Data loading
def load_data(data_path):
    print(f"load data : {data_path}")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = datasets.MNIST(root=data_path, train=True, download=False, transform=transform)
    test_dataset = datasets.MNIST(root=data_path, train=False, download=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
    return train_loader, test_loader

# Initialize the model, loss function, and optimizer
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training parameters
num_epochs = 10
save_interval = 2  # Save checkpoint every 2 epochs

# Checkpoint file path
checkpoint_path = 'mnist_checkpoint.pth'

# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)

# Function to load checkpoint
def load_checkpoint(model, optimizer):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch
    else:
        print("No checkpoint found. Starting training from scratch.")
        return 0

def train(model, train_loader, test_loader, num_epochs, save_interval):
    # Load checkpoint if it exists
    print("load checkpoint")
    start_epoch = load_checkpoint(model, optimizer)

    # Training loop
    for epoch in range(start_epoch, num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            save_checkpoint(epoch, model, optimizer)
            print(f"Checkpoint saved at epoch {epoch+1}")

        # Evaluate on test set
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Save final model
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("Training completed. Final model saved.")

data_path = "../datasets/mnist-datasets"  # Make sure this matches the path in download_mnist.py
train_loader, test_loader = load_data(data_path)
train(model, train_loader, test_loader, num_epochs, save_interval)
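If a spot interruption stops the job, re-running the same command resumes from the last saved epoch, as long as mnist_checkpoint.pth is still present in the working directory, because load_checkpoint reads it on startup. As an optional sanity check (not part of the tutorial scripts), you can inspect what the checkpoint contains:

import torch

# Inspect the checkpoint written by save_checkpoint() in train.py
ckpt = torch.load("mnist_checkpoint.pth", map_location="cpu")
print(list(ckpt.keys()))            # ['epoch', 'model_state_dict', 'optimizer_state_dict']
print("last saved epoch:", ckpt["epoch"])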
(inference.py)
# inference_mnist.py
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np

# Define the same neural network architecture used for training
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output

# Function to preprocess the image
def preprocess_image(image_path):
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    image = Image.open(image_path).convert('L')  # Convert to grayscale
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image

# Function to load model and make prediction
def predict_digit(image_path, model_path):
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the model
    model = Net().to(device)
    print(f"Loading model from {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Preprocess the image
    image = preprocess_image(image_path).to(device)

    # Make prediction
    with torch.no_grad():
        output = model(image)
        prediction = output.argmax(dim=1, keepdim=True)

    return prediction.item()

# Example usage
model_path = "mnist_model.pth"  # Path to your saved model
image_path = "../datasets/mnist-datasets/test"  # Path to the image you want to classify

# If you want to test with multiple images
test_images = ["../datasets/mnist-datasets/test/image1.jpg", "../datasets/mnist-datasets/test/image2.jpg", "../datasets/mnist-datasets/test/image3.jpg"]
result = []
for img in test_images:
    digit = predict_digit(img, model_path)
    result.append((img, digit))

print(f"Results: : {result}")

with open("mnist-results.txt", "w") as f:
    for img, digit in result:
        f.write(f"{img}: {digit}\n")
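Note that MNIST training images are light digits on a dark background. If your test images are dark digits on a light background (common for photos or scans), predictions may be unreliable. A hedged sketch of an optional pre-step, assuming PIL is available, is to invert such images before they reach preprocess_image (for example by saving the inverted image to disk, or by adapting preprocess_image to accept a PIL image):

from PIL import Image, ImageOps

def load_as_mnist_style(image_path):
    # MNIST digits are light-on-dark; invert if the image is mostly bright (dark-on-light)
    image = Image.open(image_path).convert("L")
    pixels = list(image.getdata())
    if sum(pixels) / len(pixels) > 127:  # mostly bright background -> invert
        image = ImageOps.invert(image)
    return image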
(download-mnist-datasets.py)
import os
from torchvision import datasets, transforms

def download_mnist(data_path):
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Download training data
    train_dataset = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)

    # Download test data
    test_dataset = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

    print(f"MNIST dataset downloaded and saved to {data_path}")

if __name__ == "__main__":
    data_path = "../mnist-datasets"  # You can change this to your preferred location
    download_mnist(data_path)
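After the script finishes, torchvision stores the downloaded files under MNIST/raw inside the data path. An optional quick check, assuming the default ../mnist-datasets location from the script above:

import os

data_path = "../mnist-datasets"
# torchvision stores the downloaded files under <root>/MNIST/raw
raw_dir = os.path.join(data_path, "MNIST", "raw")
print(os.listdir(raw_dir) if os.path.isdir(raw_dir) else "MNIST data not found")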
Step 2 : Download and Upload datasets
Download the datasets to your local machine
python3 download-mnist-datasets.py
Upload the MNIST datasets to the server
float16 storage upload -f ./mnist-datasets -d datasets
The folder is uploaded to the remote datasets directory, which is why train.py and inference.py read from ../datasets/mnist-datasets.
Step 3 : Training the MNIST model
Start training the MNIST model in spot mode.
float16 run train.py --spot
Resulting Files
mnist_checkpoint.pth : training checkpoint (model and optimizer state), used to resume after an interruption
mnist_model.pth : final model weights, loaded by inference.py
Step 4 : Running Inference with the MNIST model
Make sure you have uploaded your test images to the server and updated the image paths in inference.py before running it.
float16 run inference.py
Results: : [
('../datasets/mnist-datasets/test/image1.jpg', 5),
('../datasets/mnist-datasets/test/image2.jpg', 7),
('../datasets/mnist-datasets/test/image3.jpg', 5)
]
Congratulations! You've successfully run your first spot mode job on Float16's serverless GPU platform.
Explore More
Learn how to use the Float16 CLI for various use cases in our tutorials.
Happy coding with Float16 Serverless GPU!