Train and Run Inference on MNIST
Get Endpoint via Float16
Spot mode
Step 1: Prepare Your Script
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
# Report which PyTorch build is running (useful when comparing environments).
print("PyTorch version:", torch.__version__)
# Define the neural network
class Net(nn.Module):
    """Small two-convolution CNN classifier with 10 output classes.

    ``fc1`` expects 9216 input features. That equals 64 x 12 x 12, which is what
    a (N, 1, 28, 28) batch becomes after two unpadded 3x3 convolutions and one
    2x2 max-pool. ``forward`` returns per-class log-probabilities of shape (N, 10).
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the 4-D (N, C, H, W) feature map.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the flattened (N, 128) activations.
        # Dropout2d is meant for 3-D/4-D inputs. On a 2-D tensor, recent PyTorch
        # warns that this use is deprecated and will become an error; plain
        # Dropout gives the same element-wise behavior. Dropout layers have no
        # parameters, so existing state_dicts still load unchanged.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map an image batch to log-probabilities over the 10 classes."""
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output
# Select the compute device: CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)
# Data loading
def load_data(data_path):
    """Create MNIST train/test DataLoaders from an already-downloaded copy.

    Args:
        data_path: root directory handed to ``datasets.MNIST``. Because
            ``download=False`` is used, the dataset files must already be
            present there.

    Returns:
        ``(train_loader, test_loader)``. Training batches hold 64 samples and
        are shuffled. Test batches hold 1000 samples in a fixed order.
    """
    print(f"load data : {data_path}")
    # Convert to tensors, then normalize with a fixed mean/std of 0.1307 / 0.3081
    # (presumably the MNIST training-set statistics).
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set, test_set = (
        datasets.MNIST(root=data_path, train=split, download=False, transform=preprocess)
        for split in (True, False)
    )
    return (
        DataLoader(train_set, batch_size=64, shuffle=True),
        DataLoader(test_set, batch_size=1000, shuffle=False),
    )
# Initialize the model, loss function, and optimizer
model = Net().to(device)
# Net.forward already ends in log_softmax, so the matching loss is NLLLoss.
# CrossEntropyLoss would apply a second (redundant) log_softmax to the
# log-probabilities. The loss value is mathematically the same, but the pairing
# is misleading and wastes a normalization pass.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Training schedule: total number of epochs, and how often (in epochs) to checkpoint.
num_epochs, save_interval = 10, 2
# Resumable checkpoint file, resolved relative to the current working directory.
checkpoint_path = 'mnist_checkpoint.pth'
# Function to save checkpoint
def save_checkpoint(epoch, model, optimizer, path=None):
    """Persist a resumable training checkpoint.

    Args:
        epoch: zero-based index of the epoch that just finished.
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        path: destination file. Defaults to the module-level ``checkpoint_path``.

    The data is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``, which is atomic on the same filesystem. If the process is
    killed mid-write (e.g. a spot instance is reclaimed), the previous
    checkpoint survives intact instead of being left truncated and unloadable.
    """
    if path is None:
        path = checkpoint_path
    tmp_path = f"{path}.tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, path)
# Function to load checkpoint
def load_checkpoint(model, optimizer, path=None):
    """Restore model/optimizer state from a checkpoint, if one exists.

    Args:
        model: module to load ``model_state_dict`` into.
        optimizer: optimizer to load ``optimizer_state_dict`` into.
        path: checkpoint file. Defaults to the module-level ``checkpoint_path``.

    Returns:
        The zero-based epoch index to resume from: the saved epoch + 1, or 0
        when no checkpoint file exists.
    """
    if path is None:
        path = checkpoint_path
    try:
        # map_location="cpu" allows a checkpoint written on a GPU host to be
        # read on a CPU-only host. load_state_dict then copies the tensors onto
        # whatever device the model/optimizer parameters already live on.
        checkpoint = torch.load(path, map_location="cpu")
    except FileNotFoundError:
        print("No checkpoint found. Starting training from scratch.")
        return 0
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    # Training logs number epochs 1-based (epoch + 1), so report the resume
    # point the same way to match the "Checkpoint saved at epoch N" message.
    print(f"Resuming training from epoch {start_epoch + 1}")
    return start_epoch
def _train_one_epoch(model, train_loader, epoch, num_epochs):
    """Run one optimization pass over ``train_loader``, logging every 100 batches."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')


def _evaluate(model, test_loader):
    """Return ``(mean per-sample loss, accuracy in %)`` of ``model`` on ``test_loader``.

    Assumes the module-level ``criterion`` averages over the batch (PyTorch's
    default ``reduction='mean'``). Each batch's mean is therefore re-weighted by
    its batch size before dividing by the dataset size.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Undo the per-batch averaging so that the division below yields a
            # true per-sample mean. Summing batch means and dividing by the
            # sample count would under-report the loss by a factor of about
            # the batch size.
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    num_samples = len(test_loader.dataset)
    return test_loss / num_samples, 100. * correct / num_samples


def train(model, train_loader, test_loader, num_epochs, save_interval):
    """Train ``model`` with checkpoint/resume support, evaluating after each epoch.

    Uses the module-level ``optimizer``, ``criterion`` and ``device``. The
    optimizer only updates the parameters it was constructed over, so ``model``
    is presumably the same module ``optimizer`` was built from.

    Args:
        model: network to train.
        train_loader: iterable of ``(data, target)`` training batches.
        test_loader: iterable of ``(data, target)`` test batches with a ``.dataset``.
        num_epochs: total number of epochs; training resumes from the checkpoint
            epoch if one exists.
        save_interval: write a checkpoint every this many epochs (must be non-zero).

    After the loop, the final weights are saved to ``mnist_model.pth``.
    """
    print("load checkpoint")
    start_epoch = load_checkpoint(model, optimizer)
    for epoch in range(start_epoch, num_epochs):
        _train_one_epoch(model, train_loader, epoch, num_epochs)
        # Checkpoint before evaluation so a preemption during eval loses nothing.
        if (epoch + 1) % save_interval == 0:
            save_checkpoint(epoch, model, optimizer)
            print(f"Checkpoint saved at epoch {epoch+1}")
        test_loss, accuracy = _evaluate(model, test_loader)
        print(f'Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
    # Save final model
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("Training completed. Final model saved.")
data_path = "../datasets/mnist-datasets" # Make sure this matches the path in download_mnist.py
train_loader, test_loader = load_data(data_path)
train(model, train_loader, test_loader, num_epochs, save_interval)Step 2 : Download and Upload datasets
Step 3: Train the MNIST Model
Resulting Files
Step 4: Run Inference with the MNIST Model
Explore More
Last updated