Train and Inference MNIST
Get Endpoint via Float16
This tutorial guides you through training an AI model and running inference with it using Float16's spot mode.
Spot mode
Spot mode is cost-effective for interruptible workloads such as training AI models, offline inference, and pre-processing.
Spot mode offers an 80% discount compared with run mode and server mode.
Step 1 : Prepare Your Script
https://github.com/float16-cloud/examples/tree/main/official/spot/torch-train-and-infernce-mnist
(train.py)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
# Report which PyTorch build this run is using.
print("PyTorch version:", torch.__version__)
# Define the convolutional neural network used to classify MNIST digits.
class Net(nn.Module):
    """Small CNN for single-channel 28x28 digit images.

    Layers: two 3x3 convolutions (1 -> 32 -> 64 channels) with ReLU, a 2x2 max-pool,
    dropout, and two fully connected layers (9216 -> 128 -> 10).

    ``forward`` returns per-class log-probabilities (``log_softmax``), so the matching
    training loss is ``nn.NLLLoss``; ``nn.CrossEntropyLoss`` expects raw logits instead.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D (N, C, H, W) convolutional feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the 2-D (N, 128) fc1 activations. nn.Dropout2d is meant
        # for 3-D/4-D inputs and warns (deprecated) when given a 2-D tensor.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: 28 -> 26 (conv1) -> 24 (conv2) -> 12 (max-pool).
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to log-probabilities over the 10 classes.

        Args:
            x: Image batch; assumed shape (N, 1, 28, 28), which is what the 9216-wide
                ``fc1`` input requires.

        Returns:
            Tensor of shape (N, 10) holding log-softmax scores.
        """
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output
# Pick the compute device: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)
# Data loading
def load_data(data_path, train_batch_size=64, test_batch_size=1000):
    """Create MNIST train/test DataLoaders from an already-downloaded dataset.

    Args:
        data_path: Root directory passed to ``torchvision.datasets.MNIST``. It must
            already contain the MNIST files because ``download=False``; nothing is
            fetched here.
        train_batch_size: Batch size of the shuffled training loader (default 64).
        test_batch_size: Batch size of the unshuffled test loader (default 1000).

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    print(f"load data : {data_path}")
    # Convert images to tensors, then normalize with fixed mean/std constants
    # (the widely used MNIST values), applied identically to train and test data.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_dataset = datasets.MNIST(root=data_path, train=True, download=False, transform=transform)
    test_dataset = datasets.MNIST(root=data_path, train=False, download=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
    return train_loader, test_loader
# Initialize the model, loss function, and optimizer.
model = Net().to(device)
# Net.forward already returns log-probabilities (log_softmax), so the matching loss is
# NLLLoss. CrossEntropyLoss expects raw logits and would apply log_softmax a second time
# (numerically equivalent, since log_softmax is idempotent, but redundant and misleading).
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training schedule: total number of epochs, and how often (in epochs) to checkpoint.
num_epochs, save_interval = 10, 2

# Resumable checkpoint file, resolved relative to the current working directory.
checkpoint_path = 'mnist_checkpoint.pth'
# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer):
    """Persist a resumable training checkpoint to ``checkpoint_path``.

    The state is first written to a sibling temporary file and then moved into place
    with ``os.replace``. As a result, an interruption in the middle of the write (e.g.
    an interrupted spot-mode job) cannot leave a truncated checkpoint that breaks
    resuming; the previous checkpoint stays intact until the new one is complete.

    Args:
        epoch: Zero-based index of the epoch that just finished.
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, tmp_path)
    # Same directory => same filesystem, so the rename is atomic on POSIX.
    os.replace(tmp_path, checkpoint_path)
# Function to load checkpoint
def load_checkpoint(model, optimizer):
    """Restore model and optimizer state from ``checkpoint_path`` when it exists.

    Args:
        model: Module to load ``model_state_dict`` into (modified in place).
        optimizer: Optimizer to load ``optimizer_state_dict`` into (modified in place).

    Returns:
        The epoch index to resume from (one past the saved epoch), or 0 when no
        checkpoint file is present.
    """
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Starting training from scratch.")
        return 0
    # map_location="cpu" allows resuming on a CPU-only host from a checkpoint written
    # on a GPU; load_state_dict then copies the tensors onto the model's own device.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")
    return start_epoch
def _evaluate(model, test_loader):
    """Return ``(avg_loss, accuracy_percent)`` of ``model`` over the whole test set."""
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # ``criterion`` returns the batch *mean* (PyTorch's default reduction).
            # Weight it by the batch size so that dividing by the dataset size below
            # gives a per-sample mean, and so a shorter last batch counts correctly.
            total_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    num_samples = len(test_loader.dataset)
    return total_loss / num_samples, 100. * correct / num_samples


def train(model, train_loader, test_loader, num_epochs, save_interval):
    """Train ``model``, resuming from the checkpoint file when one exists.

    Relies on the module-level ``optimizer``, ``criterion`` and ``device``. After
    every epoch, the model is evaluated on ``test_loader``. A checkpoint is written
    whenever the 1-based epoch number is a multiple of ``save_interval``. The final
    weights are saved to ``mnist_model.pth`` when the loop ends.

    Args:
        model: Network to train; expected to already be on ``device``.
        train_loader: Iterable of ``(data, target)`` training batches.
        test_loader: Test DataLoader; its ``dataset`` length is used for averaging.
        num_epochs: Total epoch count, including epochs completed before a resume.
        save_interval: Checkpoint frequency, in epochs.
    """
    # Resume from the last checkpoint, if any (returns 0 when starting fresh).
    print("load checkpoint")
    start_epoch = load_checkpoint(model, optimizer)

    for epoch in range(start_epoch, num_epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        # Save checkpoint
        if (epoch + 1) % save_interval == 0:
            save_checkpoint(epoch, model, optimizer)
            print(f"Checkpoint saved at epoch {epoch+1}")

        # Evaluate on test set
        test_loss, accuracy = _evaluate(model, test_loader)
        print(f'Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Save final model
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("Training completed. Final model saved.")
if __name__ == "__main__":
    # Must point at the directory holding the MNIST files; keep it in sync with the
    # path used by download-mnist-datasets.py.
    data_path = "../datasets/mnist-datasets"
    train_loader, test_loader = load_data(data_path)
    train(model, train_loader, test_loader, num_epochs, save_interval)
(inference.py)
(download-mnist-datasets.py)
Step 2 : Download and Upload datasets
Download the datasets to your local machine.
Upload MNIST datasets to the server
Step 3 : Training the MNIST model
Start training the MNIST model with spot mode.
Resulting Files
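mnist_checkpoint.pth: resumable training checkpoint (model and optimizer state), saved every 2 epochs
mnist_model.pth: final MNIST model weights, saved when training completes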
Step 4 : Run inference with the MNIST model
Make sure you have uploaded the test image to the server and changed the image path before running the inference.
Congratulations! You've successfully used spot mode on Float16's serverless GPU platform.
Explore More
Learn how to use the Float16 CLI for various use cases in our tutorials.
Happy coding with Float16 Serverless GPU!
Last updated