# Cleaned Code
import os
import math
import zipfile
import urllib.request

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# =========================================================
# 1. TINY-IMAGENET DOWNLOAD + PREPARATION
# =========================================================

def prepare_tiny_imagenet():
    """
    Download and extract Tiny-ImageNet if it is not already present.

    The archive is downloaded to a temporary ``.part`` file and the
    dataset is extracted into a staging folder; each is moved to its
    final location only after the step finished. An interrupted run
    therefore never leaves a truncated archive or a half-extracted
    dataset that a later run would mistake for a complete one.

    The download is skipped when the extracted dataset folder already
    exists, since the archive is only needed for extraction.

    Returns:
        (train_dir, val_dir): paths of the ``train`` and ``val``
        sub-folders of the extracted dataset.

    Raises:
        urllib.error.URLError / OSError: if the download fails.
        zipfile.BadZipFile: if the archive on disk is not a valid zip.
    """
    import shutil  # local import: only used for the staging folder

    dataset_url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    zip_path = "./tiny-imagenet-200.zip"
    extract_path = "./tiny-imagenet-200"

    if not os.path.exists(extract_path):
        # -------------------------------------------------
        # Download dataset archive
        # -------------------------------------------------
        if not os.path.exists(zip_path):
            print(
                "Downloading Tiny-ImageNet (~230MB)... "
                "Please wait..."
            )
            partial_zip = zip_path + ".part"
            try:
                urllib.request.urlretrieve(dataset_url, partial_zip)
                os.replace(partial_zip, zip_path)
            finally:
                # Only still present when the download did not complete.
                if os.path.exists(partial_zip):
                    os.remove(partial_zip)
            print("Download complete!")

        # -------------------------------------------------
        # Extract dataset archive
        # -------------------------------------------------
        print("Extracting dataset...")
        staging_dir = extract_path + ".extracting"
        shutil.rmtree(staging_dir, ignore_errors=True)
        try:
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(staging_dir)
            # The archive is expected to contain a single top-level
            # folder named like extract_path (the original code relied
            # on the same layout when extracting into "./").
            os.replace(
                os.path.join(
                    staging_dir, os.path.basename(extract_path)
                ),
                extract_path,
            )
        finally:
            shutil.rmtree(staging_dir, ignore_errors=True)
        print("Extraction complete!")

    return (
        os.path.join(extract_path, "train"),
        os.path.join(extract_path, "val"),
    )


train_dir, val_dir = prepare_tiny_imagenet()
# =========================================================
# 2. VALIDATION FOLDER RESTRUCTURING
# =========================================================
#
# Tiny-ImageNet validation images are originally placed
# in a single shared folder.
#
# This section reorganizes them into class-specific
# folders so torchvision.datasets.ImageFolder can
# load them correctly.
#

# Derived from val_dir so the paths cannot drift apart from the
# location returned by prepare_tiny_imagenet().
val_img_dir = os.path.join(val_dir, "images")
val_annotations = os.path.join(val_dir, "val_annotations.txt")

if os.path.exists(val_img_dir):
    print(
        "Reorganizing Tiny-ImageNet validation "
        "folder structure..."
    )

    with open(val_annotations, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")

            # Skip blank or malformed rows instead of failing with
            # an IndexError; only the first two columns are used.
            if len(parts) < 2:
                continue

            img_name, class_name = parts[0], parts[1]

            class_dir = os.path.join(val_dir, class_name)
            os.makedirs(class_dir, exist_ok=True)

            src_path = os.path.join(val_img_dir, img_name)
            dst_path = os.path.join(class_dir, img_name)

            # Images already moved by an earlier, interrupted run
            # are simply absent from the shared folder.
            if os.path.exists(src_path):
                os.rename(src_path, dst_path)

    # os.rmdir raises if files not listed in the annotations remain.
    # NOTE(review): that failure looks intentional, since a leftover
    # "images" folder would presumably be picked up by ImageFolder as
    # an extra class; confirm before relaxing it.
    os.rmdir(val_img_dir)

    print(
        "Validation folder restructuring complete!"
    )


# =========================================================
# 3. DATA AUGMENTATION + NORMALIZATION
# =========================================================

# Per-channel normalization statistics, defined once so the train
# and validation pipelines cannot disagree.
TINY_IMAGENET_MEAN = (0.4802, 0.4481, 0.3975)
TINY_IMAGENET_STD = (0.2302, 0.2265, 0.2262)

transform_train = transforms.Compose([
    # Horizontal augmentation
    transforms.RandomHorizontalFlip(),
    # Mild rotational augmentation
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD),
])

# Validation uses no augmentation, only tensor conversion and the
# same normalization as training.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(TINY_IMAGENET_MEAN, TINY_IMAGENET_STD),
])


# =========================================================
# 4. DATASET + DATALOADER SETUP
# =========================================================

train_dataset = datasets.ImageFolder(
    root=train_dir,
    transform=transform_train,
)

val_dataset = datasets.ImageFolder(
    root=val_dir,
    transform=transform_val,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
# =========================================================
# 5. CORE RELATIONAL LAYER — LOOKTHEM LAYER
# =========================================================

class LookThemLayer(nn.Module):
    """
    Token-relational processing layer.

    Each token owns two independent two-layer micro-networks that
    reduce its feature vector to one scalar each (``m1`` and ``m2``).
    Every ordered token pair (i, j) is compared through the bounded
    ratio ``tanh(m1[i] / m2[j])``. The ratio maps are passed through a
    per-token affine transform (``trans_w`` / ``trans_b``, indexed by
    the second token of the pair) and used as weights to mix token
    features; self-pairs are excluded and the result is averaged over
    the remaining ``num_tokens - 1`` partners.

    Args:
        num_tokens: number of tokens N expected on dim 1 of the input.
        in_features: feature size expected on dim 2 of the input.
        hidden_dim: hidden width of each per-token micro-network.

    Shapes:
        input:  [B, num_tokens, in_features]
        output: [B, num_tokens, in_features]
    """

    # Smallest magnitude allowed for a ratio denominator.
    _EPS = 1e-5

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features

        # Weights are allocated uninitialised and filled by
        # _init_weights(); biases start at zero. Registration order is
        # kept stable so existing state dicts and optimizer states
        # still line up.

        # =================================================
        # BRANCH 1 PARAMETERS
        # =================================================
        self.mod1_w1 = nn.Parameter(
            torch.empty(num_tokens, in_features, hidden_dim)
        )
        self.mod1_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod1_w2 = nn.Parameter(
            torch.empty(num_tokens, hidden_dim, 1)
        )
        self.mod1_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # =================================================
        # BRANCH 2 PARAMETERS
        # =================================================
        self.mod2_w1 = nn.Parameter(
            torch.empty(num_tokens, in_features, hidden_dim)
        )
        self.mod2_b1 = nn.Parameter(
            torch.zeros(num_tokens, hidden_dim)
        )
        self.mod2_w2 = nn.Parameter(
            torch.empty(num_tokens, hidden_dim, 1)
        )
        self.mod2_b2 = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        # =================================================
        # RELATIONAL TRANSFORMATION PARAMETERS
        # =================================================
        self.trans_w = nn.Parameter(
            torch.empty(num_tokens, 1, 1)
        )
        self.trans_b = nn.Parameter(
            torch.zeros(num_tokens, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """
        Initialise every projection matrix uniformly in
        ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.

        This is the bound ``kaiming_uniform_(a=sqrt(5))`` produces, but
        with ``fan_in`` taken as the per-token input width (dim 1).
        Calling ``kaiming_uniform_`` directly on the stacked
        ``[tokens, in, out]`` tensors treats the trailing dim as a
        receptive field and uses ``in * out`` as fan-in, which makes
        the first-layer weights too small by ``sqrt(out)``.
        """
        for w in (
            self.mod1_w1,
            self.mod2_w1,
            self.mod1_w2,
            self.mod2_w2,
            self.trans_w,
        ):
            bound = 1.0 / math.sqrt(w.size(1))
            nn.init.uniform_(w, -bound, bound)

    @staticmethod
    def _token_mlp(x, w1, b1, w2, b2):
        """Apply one per-token two-layer GELU network: [B,N,F] -> [B,N,1]."""
        hidden = torch.einsum('bti,tij->btj', x, w1) + b1
        return torch.einsum('btj,tjk->btk', F.gelu(hidden), w2) + b2

    def forward(self, x):
        """
        Input shape:
            [B, Tokens, Features]

        Output shape:
            [B, Tokens, Features]
        """
        n = self.num_tokens

        # =================================================
        # BRANCH 1 / BRANCH 2 FORWARD PASS
        # =================================================
        out_m1 = self._token_mlp(
            x, self.mod1_w1, self.mod1_b1, self.mod1_w2, self.mod1_b2
        )
        out_m2 = self._token_mlp(
            x, self.mod2_w1, self.mod2_b1, self.mod2_w2, self.mod2_b2
        )

        # Numerical stabilization: push the denominator away from zero
        # while keeping its sign. Merely adding a constant would still
        # hit zero for out_m2 == -eps.
        eps = self._EPS
        denom = torch.where(
            out_m2 >= 0,
            out_m2.clamp(min=eps),
            out_m2.clamp(max=-eps),
        )

        # =================================================
        # PAIRWISE TOKEN RELATIONAL COMPARISON
        # =================================================
        # ratio[b, i, j] = tanh(m1[b, i] / m2[b, j]); the reverse
        # direction is simply its transpose over (i, j).
        ratio = torch.tanh(out_m1 / denom.transpose(1, 2))

        # =================================================
        # RELATIONAL MAP TRANSFORMATION
        # =================================================
        # Scale and shift are indexed by the second token of the pair.
        scale = self.trans_w.view(1, 1, n)
        shift = self.trans_b.view(1, 1, n)

        # Remove self-interaction
        off_diag = 1.0 - torch.eye(
            n, device=x.device, dtype=ratio.dtype
        )

        forward_map = (ratio * scale + shift) * off_diag
        reverse_map = (
            ratio.transpose(1, 2) * scale + shift
        ) * off_diag

        # =================================================
        # BIDIRECTIONAL INTERACTION FUSION
        # =================================================
        # Summing  forward_map[i, j] * x[i] + reverse_map[i, j] * x[j]
        # over j equals a row-sum times x[i] plus a batched matmul, so
        # the [B, N, N, F] interaction tensor is never materialised.
        aggregated = (
            x * forward_map.sum(dim=2, keepdim=True)
            + torch.bmm(reverse_map, x)
        )

        # Average the two directions and the N - 1 external partners;
        # with a single token there are no partners and the (all-zero)
        # aggregate is returned without dividing by zero.
        return aggregated / (2.0 * max(n - 1, 1))
# =========================================================
# 6. MAIN ARCHITECTURE — LOOKTHEM V5
# =========================================================

def _conv_bn_gelu(in_channels, out_channels):
    """Return the layers of one stride-2 3x3 conv + BatchNorm + GELU stage."""
    return (
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        ),
        nn.BatchNorm2d(out_channels),
        nn.GELU(),
    )


class LookThemV5(nn.Module):
    """
    Dual-stream asymmetric relational architecture.

    Stream A:
        Grayscale stream. Two stride-2 conv stages produce 32 channels;
        the flattened spatial positions (256 for 64x64 inputs) are
        compressed to 64 tokens by a linear bridge.

    Stream B:
        RGB stream. Three stride-2 conv stages produce 32 channels
        whose flattened spatial positions (64 for 64x64 inputs) are
        used directly as tokens.

    The two [B, 64, 32] token sets are concatenated along the feature
    axis, processed by LookThemLayer, flattened and classified into
    200 classes.

    Input:  [B, 3, 64, 64] (other spatial sizes do not match the fixed
            layer widths and fail in the bridge / relational layer).
    Output: [B, 200] class logits.
    """

    def __init__(self):
        super().__init__()

        # =================================================
        # RGB → GRAYSCALE CONVERSION WEIGHTS
        # =================================================
        # Fixed per-channel weights, stored as a buffer so they follow
        # the module across devices without being trained.
        self.register_buffer(
            'grayscale_weights',
            torch.tensor(
                [0.299, 0.587, 0.114]
            ).view(1, 3, 1, 1)
        )

        # =================================================
        # STREAM A — MACRO STRUCTURE STREAM
        # =================================================
        # Two stride-2 stages: 64x64 input → 16x16 feature map.
        self.stream_a = nn.Sequential(
            *_conv_bn_gelu(1, 16),
            *_conv_bn_gelu(16, 32),
        )

        # =================================================
        # TOKEN BRIDGE
        # =================================================
        # Compresses the spatial dimension (256 positions → 64 tokens)
        # while preserving the feature channels.
        self.token_bridge = nn.Linear(256, 64)

        # =================================================
        # STREAM B — COLOR ESSENCE STREAM
        # =================================================
        # Three stride-2 stages: 64x64 input → 8x8 feature map.
        self.stream_b = nn.Sequential(
            *_conv_bn_gelu(3, 16),
            *_conv_bn_gelu(16, 32),
            *_conv_bn_gelu(32, 32),
        )

        # =================================================
        # RELATIONAL COGNITION CORE
        # =================================================
        self.lookthem = LookThemLayer(
            num_tokens=64,
            in_features=64,
            hidden_dim=32
        )

        # =================================================
        # CLASSIFIER HEAD
        # =================================================
        # Flattened relational token representation followed by a
        # small dropout-regularised head.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 64, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 200)
        )

    def forward(self, x):
        """Map a batch of RGB images [B, 3, 64, 64] to logits [B, 200]."""
        # =================================================
        # STREAM A — GRAYSCALE MACRO EXTRACTION
        # =================================================
        # Weighted channel sum converts RGB into one grayscale plane.
        x_gray = torch.sum(
            x * self.grayscale_weights,
            dim=1,
            keepdim=True
        )

        # [B, 32, 16, 16] → [B, 32, 256]; flatten also works when the
        # conv output is not contiguous.
        feat_a_flat = self.stream_a(x_gray).flatten(2)

        # Spatial compression 256 → 64, then tokens first:
        # [B, 64 Tokens, 32 Features]
        feat_a_tokens = self.token_bridge(feat_a_flat).transpose(1, 2)

        # =================================================
        # STREAM B — RGB COLOR EXTRACTION
        # =================================================
        # [B, 32, 8, 8] → [B, 32, 64] → [B, 64 Tokens, 32 Features]
        feat_b_tokens = self.stream_b(x).flatten(2).transpose(1, 2)

        # =================================================
        # ASYMMETRIC FEATURE FUSION
        # =================================================
        # Token count stays fixed while the feature dimensionality is
        # doubled: [B, 64 Tokens, 64 Features]
        tokens_combined = torch.cat(
            [feat_a_tokens, feat_b_tokens],
            dim=2
        )

        # =================================================
        # RELATIONAL COGNITION + CLASSIFICATION
        # =================================================
        out_lookthem = self.lookthem(tokens_combined)

        return self.classifier(out_lookthem)


# =========================================================
# 7. TRAINING RUNTIME + CHECKPOINT SYSTEM
# =========================================================

device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

model = LookThemV5().to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-4
)

# One cosine cycle spanning the 20 training epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=20
)

start_epoch = 0
checkpoint_path = "lookthem_v5_checkpoint.pth"

# =========================================================
# CHECKPOINT RESUME
# =========================================================

if os.path.exists(checkpoint_path):
    print(
        "Checkpoint detected. "
        "Resuming previous experiment..."
    )

    # map_location lets a checkpoint written on a CUDA machine be
    # resumed on a CPU-only host (and vice versa).
    # NOTE: torch.load unpickles the file; only load checkpoints
    # produced by this script.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=device
    )

    model.load_state_dict(
        checkpoint['model_state_dict']
    )
    optimizer.load_state_dict(
        checkpoint['optimizer_state_dict']
    )
    scheduler.load_state_dict(
        checkpoint['scheduler_state_dict']
    )

    # The checkpoint stores the number of completed epochs, which is
    # also the zero-based index of the next epoch to run.
    start_epoch = checkpoint['epoch']

    print(
        f"Successfully resumed from "
        f"epoch {start_epoch + 1}"
    )

print(
    f"Starting LookThem V5 "
    f"(Asymmetric Fusion) on {device}..."
)
# =========================================================
# 8. TRAINING LOOP
# =========================================================

for epoch in range(start_epoch, 20):
    model.train()

    # Loss is accumulated per sample so the smaller final batch does
    # not get the same weight as a full batch in the epoch average.
    total_loss = 0.0
    correct = 0
    total = 0

    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        batch_count = target.size(0)
        total_loss += loss.item() * batch_count

        _, predicted = output.max(1)

        total += batch_count
        correct += predicted.eq(target).sum().item()

    # The cosine schedule advances once per epoch.
    scheduler.step()

    # max(..., 1) keeps the report well-defined for an empty loader.
    seen = max(total, 1)
    acc = 100. * correct / seen
    current_lr = optimizer.param_groups[0]['lr']

    print(
        f"Epoch {epoch+1:02d}/20 | "
        f"Train Loss: "
        f"{total_loss / seen:.4f} | "
        f"Train Acc: {acc:.2f}% | "
        f"LR: {current_lr:.6f}"
    )

    # -----------------------------------------------------
    # Periodic checkpoint save
    # -----------------------------------------------------
    if (epoch + 1) % 5 == 0:
        # 'epoch' holds the count of completed epochs, which the
        # resume logic uses as the next epoch index.
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, checkpoint_path)

        print(
            f"[CHECKPOINT] "
            f"Epoch {epoch+1} saved successfully."
        )


# =========================================================
# 9. FINAL VALIDATION
# =========================================================

model.eval()

test_loss = 0.0
test_correct = 0
test_total = 0

print("\nStarting final validation...")

with torch.no_grad():
    for data, target in val_loader:
        data = data.to(device)
        target = target.to(device)

        output = model(data)
        loss = criterion(output, target)

        # Weight each batch mean by its size (see training loop).
        batch_count = target.size(0)
        test_loss += loss.item() * batch_count

        _, predicted = output.max(1)

        test_total += batch_count
        test_correct += predicted.eq(target).sum().item()

evaluated = max(test_total, 1)

final_test_acc = (
    100. * test_correct / evaluated
)

print("=== FINAL LOOKTHEM V5 RESULTS ===")

print(
    f"Test Loss: "
    f"{test_loss / evaluated:.4f} | "
    f"Test Accuracy: {final_test_acc:.2f}%"
)

# Save final trained weights
torch.save(
    model.state_dict(),
    "LookThem_V5_Final.pth"
)

print(
    f"Training complete! "
    f"Final model size: "
    f"{os.path.getsize('LookThem_V5_Final.pth') / (1024*1024):.2f} MB"
)