# A standard PyTorch training loop structure.
# Hyperparameters
EPOCHS = 10

# Standard PyTorch training loop: for each epoch, iterate over mini-batches,
# run forward/backward passes, and step the optimizer.
# Assumes `model`, `optimizer`, `criterion`, and `train_loader` are defined
# elsewhere in the file/session.
model.train()  # put layers such as dropout/batchnorm into training mode
for epoch in range(EPOCHS):
    running_loss = 0.0
    # Iterate over batches
    for batch_idx, (data, target) in enumerate(train_loader):
        # 1. Zero gradients left over from the previous step
        optimizer.zero_grad()
        # 2. Forward pass
        output = model(data)
        # 3. Compute the loss for this batch
        loss = criterion(output, target)
        # 4. Backward pass & parameter update
        loss.backward()
        optimizer.step()
        # Bug fix: running_loss was initialized but never accumulated.
        # .item() detaches the scalar so we don't retain the graph.
        running_loss += loss.item()
    # Report the mean loss over the epoch (guard against an empty loader).
    avg_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch} complete! Average loss: {avg_loss:.4f}")