import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs('visualizations', exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
class VSSM(nn.Module):
    """Variational state-space model: a VAE-style encoder with a latent
    transition network, a decoder, and a classifier head."""

    def __init__(self, input_size=784, hidden_size=32, state_size=16, output_size=10):
        super(VSSM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.output_size = output_size
        # Encoder maps flattened pixels to a hidden representation.
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        # Heads producing the mean and log-variance of the latent Gaussian.
        self.fc_mu = nn.Linear(hidden_size, state_size)
        self.fc_logvar = nn.Linear(hidden_size, state_size)
        # Transition network advances the latent state (z -> z_next).
        self.transition = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, state_size)
        )
        # Decoder reconstructs the flattened image from the transitioned state.
        self.decoder = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_size)
        )
        # Classifier predicts the digit class from the sampled latent.
        self.classifier = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, output_size)
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def classify(self, z):
        return self.classifier(z)

    def forward(self, x):
        batch_size = x.size(0)
        x_flat = x.view(batch_size, -1)
        mu, logvar = self.encode(x_flat)
        z = self.reparameterize(mu, logvar)
        z_next = self.transition(z)        # advance the latent state
        recon_flat = self.decode(z_next)   # reconstruct from the transitioned state
        pred = self.classify(z)            # classify from the current state
        return recon_flat, pred, mu, logvar, z, x_flat
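# Optional sanity check (illustrative only, not part of the training pipeline):
# run a dummy batch through an untrained VSSM and confirm the output shapes
# implied by the architecture above (recon: [B, 784], pred: [B, 10],
# mu/logvar/z: [B, 16]). Uncomment to run.
# _dummy = torch.randn(4, 1, 28, 28)
# _recon, _pred, _mu, _logvar, _z, _xf = VSSM()(_dummy)
# assert _recon.shape == (4, 784) and _pred.shape == (4, 10) and _z.shape == (4, 16)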
def vssm_loss(recon_x, x, pred, target, mu, logvar, lambda_kl=0.1, lambda_cls=1.0):
    """Combined loss: reconstruction + weighted KL divergence + weighted classification."""
    batch_size = x.size(0)
    recon_loss = F.mse_loss(recon_x, x.view(batch_size, -1), reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the standard normal prior.
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    cls_loss = F.cross_entropy(pred, target, reduction='sum')
    total_loss = (recon_loss + lambda_kl * kl_loss + lambda_cls * cls_loss) / batch_size
    return (total_loss,
            recon_loss.item() / batch_size,
            kl_loss.item() / batch_size,
            cls_loss.item() / batch_size)
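# Illustrative check of the loss decomposition (not part of the original script):
# the scalar total should equal the weighted sum of the per-sample components
# returned alongside it. Uncomment to run on random tensors.
# _r, _x = torch.rand(8, 784), torch.rand(8, 1, 28, 28)
# _p, _t = torch.randn(8, 10), torch.randint(0, 10, (8,))
# _mu, _lv = torch.randn(8, 16), torch.randn(8, 16)
# _tot, _rec, _kl, _cls = vssm_loss(_r, _x, _p, _t, _mu, _lv)
# assert abs(_tot.item() - (_rec + 0.1 * _kl + 1.0 * _cls)) < 1e-3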
def pltLoss(train_losses, test_losses, epochs):
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), train_losses, 'b-', label='Training Loss')
    plt.plot(range(1, epochs + 1), test_losses, 'r-', label='Test Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Test Loss')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('loss_curve.png')
    plt.close()
def plotTest(model, test_loader, device, epoch):
    """Visualize the most confidently classified test sample along with its
    reconstruction and latent statistics."""
    model.eval()
    best_sample = None
    best_confidence = -1
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            recon_flat, pred, mu, logvar, z, x_flat = model(data)
            confidence = F.softmax(pred, dim=1).max(dim=1)[0]
            # Most confident sample in this batch (argmax, not argmin).
            max_idx = confidence.argmax().item()
            if confidence[max_idx] > best_confidence:
                best_confidence = confidence[max_idx].item()
                best_sample = {
                    'input': data[max_idx].cpu(),
                    'recon': recon_flat[max_idx].cpu().view(1, 28, 28),
                    'target': target[max_idx].cpu().item(),
                    'pred': pred[max_idx].argmax().cpu().item(),
                    'confidence': best_confidence,
                    'mu': mu[max_idx].cpu().numpy(),
                    'logvar': logvar[max_idx].cpu().numpy(),
                    'z': z[max_idx].cpu().numpy(),
                    'pred_dist': F.softmax(pred[max_idx], dim=0).cpu().numpy()
                }
            del data, target, recon_flat, pred, mu, logvar, z, x_flat, confidence, max_idx
        torch.cuda.empty_cache()
    if best_sample is not None:
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 3, 1)
        plt.title(f'Input Image (True: {best_sample["target"]})')
        plt.imshow(best_sample['input'].squeeze().numpy(), cmap='gray')
        plt.axis('off')
        plt.subplot(2, 3, 2)
        plt.title('Reconstructed Image')
        plt.imshow(best_sample['recon'].squeeze().numpy(), cmap='gray')
        plt.axis('off')
        plt.subplot(2, 3, 3)
        plt.title('Latent Mean (μ)')
        plt.bar(range(len(best_sample['mu'])), best_sample['mu'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        plt.subplot(2, 3, 4)
        plt.title('Latent Log Variance (log σ²)')
        plt.bar(range(len(best_sample['logvar'])), best_sample['logvar'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        plt.subplot(2, 3, 5)
        plt.title('Sampled Latent Variable (z)')
        plt.bar(range(len(best_sample['z'])), best_sample['z'])
        plt.xlabel('Dimension')
        plt.ylabel('Value')
        plt.subplot(2, 3, 6)
        plt.title(f'Prediction Distribution (Pred: {best_sample["pred"]}, Conf: {best_sample["confidence"]:.4f})')
        plt.bar(range(10), best_sample['pred_dist'])
        plt.xticks(range(10))
        plt.xlabel('Class')
        plt.ylabel('Probability')
        plt.tight_layout()
        plt.savefig(f'visualizations/epoch_{epoch}_best_sample.png')
        plt.close()
model = VSSM().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
def train(model, train_loader, optimizer, epoch, device):
    model.train()
    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    train_cls_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        recon, pred, mu, logvar, z, x_flat = model(data)
        loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_recon_loss += recon_loss
        train_kl_loss += kl_loss
        train_cls_loss += cls_loss
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    avg_loss = train_loss / len(train_loader)
    avg_recon_loss = train_recon_loss / len(train_loader)
    avg_kl_loss = train_kl_loss / len(train_loader)
    avg_cls_loss = train_cls_loss / len(train_loader)
    print(f'Epoch: {epoch} Average training loss: {avg_loss:.4f} '
          f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
    return avg_loss
def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0
    test_cls_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            recon, pred, mu, logvar, z, x_flat = model(data)
            loss, recon_loss, kl_loss, cls_loss = vssm_loss(recon, data, pred, target, mu, logvar)
            test_loss += loss.item()
            test_recon_loss += recon_loss
            test_kl_loss += kl_loss
            test_cls_loss += cls_loss
            pred_class = pred.argmax(dim=1, keepdim=True)
            correct += pred_class.eq(target.view_as(pred_class)).sum().item()
    avg_loss = test_loss / len(test_loader)
    avg_recon_loss = test_recon_loss / len(test_loader)
    avg_kl_loss = test_kl_loss / len(test_loader)
    avg_cls_loss = test_cls_loss / len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Average test loss: {avg_loss:.4f} '
          f'(Recon: {avg_recon_loss:.4f}, KL: {avg_kl_loss:.4f}, Cls: {avg_cls_loss:.4f})')
    print(f'Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return avg_loss, accuracy
epochs = 10
train_losses = []
test_losses = []
best_accuracy = 0.0
for epoch in range(1, epochs + 1):
    print(f'\nEpoch {epoch}/{epochs}')
    train_loss = train(model, train_loader, optimizer, epoch, device)
    train_losses.append(train_loss)
    test_loss, accuracy = test(model, test_loader, device)
    test_losses.append(test_loss)
    plotTest(model, test_loader, device, epoch)
    scheduler.step(test_loss)
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Best model saved with accuracy: {accuracy:.2f}%')
    pltLoss(train_losses, test_losses, epoch)
    torch.cuda.empty_cache()
print(f'\nTraining completed. Best accuracy: {best_accuracy:.2f}%')
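# To reuse the best checkpoint later (sketch; assumes 'best_model.pth' was
# written by the loop above):
# inference_model = VSSM().to(device)
# inference_model.load_state_dict(torch.load('best_model.pth', map_location=device))
# inference_model.eval()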