| | |
| |
|
| | """ |
| | Learnable preprocessing components for the block-based autoencoder. |
| | Extracted from modeling_autoencoder.py to a dedicated module. |
| | """ |
| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple |
| |
|
| | import torch |
| | from typing import Tuple |
| |
|
| | try: |
| | from .blocks import BaseBlock |
| | except Exception: |
| | from blocks import BaseBlock |
| |
|
| | import torch.nn as nn |
| |
|
| | try: |
| | from .configuration_autoencoder import AutoencoderConfig |
| | except Exception: |
| | from configuration_autoencoder import AutoencoderConfig |
| |
|
| |
|
class NeuralScaler(nn.Module):
    """Learnable alternative to StandardScaler using neural networks.

    During training the input is standardized with batch statistics that are
    refined by two small MLPs; an exponential moving average of the raw batch
    statistics is tracked (BatchNorm-style) and used at eval time. A learnable
    per-feature affine transform is applied after normalization.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        # MLP producing an additive correction to the batch mean.
        self.mean_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim)
        )
        # Softplus keeps the learned std correction non-negative.
        self.std_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus()
        )

        # Post-normalization per-feature affine parameters.
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        # Running statistics used at eval time and for the inverse transform.
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_std", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standardize ``x`` and apply the learned affine head.

        Args:
            x: Tensor of shape (batch, features) or (batch, seq, features).
            inverse: When True, undo the transform instead (see
                ``_inverse_transform``).

        Returns:
            Tuple of (transformed tensor with the input's shape, scalar
            regularization loss).
        """
        if inverse:
            return self._inverse_transform(x)
        original_shape = x.shape
        if x.dim() == 3:
            # Fold the sequence dimension into the batch; reshape (unlike view)
            # also accepts non-contiguous inputs.
            x = x.reshape(-1, x.size(-1))
        if self.training:
            batch_mean = x.mean(dim=0, keepdim=True)
            # std() is NaN for a single-row batch and 0 for constant features;
            # sanitize so the division below stays finite (consistent with the
            # clamp_min used by the other scalers in this module).
            batch_std = torch.nan_to_num(x.std(dim=0, keepdim=True)).clamp_min(1e-8)
            learned_mean_adj = self.mean_estimator(batch_mean)
            learned_std_adj = self.std_estimator(batch_std)
            effective_mean = batch_mean + learned_mean_adj
            effective_std = batch_std + learned_std_adj + 1e-8
            with torch.no_grad():
                # First batch seeds the running stats; later batches EMA-update.
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_mean.copy_(batch_mean.squeeze())
                    self.running_std.copy_(batch_std.squeeze())
                else:
                    self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
                    self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
        else:
            effective_mean = self.running_mean.unsqueeze(0)
            effective_std = self.running_std.unsqueeze(0) + 1e-8
        normalized = (x - effective_mean) / effective_std
        transformed = normalized * self.weight + self.bias
        if len(original_shape) == 3:
            transformed = transformed.reshape(original_shape)
        # Penalize the affine parameters drifting apart across features.
        reg_loss = 0.01 * (self.weight.var() + self.bias.var())
        return transformed, reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Approximately invert ``forward`` using the running statistics.

        Only the affine head and the running mean/std are undone; the learned
        batch-dependent corrections are not invertible and are ignored.
        Returns the restored tensor and a zero regularization loss.
        """
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        original_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.size(-1))
        x = (x - self.bias) / (self.weight + 1e-8)
        effective_mean = self.running_mean.unsqueeze(0)
        effective_std = self.running_std.unsqueeze(0) + 1e-8
        x = x * effective_std + effective_mean
        if len(original_shape) == 3:
            x = x.reshape(original_shape)
        return x, torch.tensor(0.0, device=x.device)
| |
|
| |
|
class LearnableMinMaxScaler(nn.Module):
    """Learnable MinMax scaler that adapts bounds during training.

    Rescales each feature by batch min/range (with MLP-learned corrections
    during training, EMA buffers at eval time), then applies a learnable
    per-feature affine transform.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        feat_dim = config.input_dim
        width = config.preprocessing_hidden_dim
        # MLP producing an additive correction to the batch minimum.
        self.min_estimator = nn.Sequential(
            nn.Linear(feat_dim, width), nn.ReLU(), nn.Linear(width, width), nn.ReLU(), nn.Linear(width, feat_dim)
        )
        # Softplus keeps the learned range correction non-negative.
        self.range_estimator = nn.Sequential(
            nn.Linear(feat_dim, width), nn.ReLU(), nn.Linear(width, width), nn.ReLU(), nn.Linear(width, feat_dim), nn.Softplus()
        )
        self.weight = nn.Parameter(torch.ones(feat_dim))
        self.bias = nn.Parameter(torch.zeros(feat_dim))
        self.register_buffer("running_min", torch.zeros(feat_dim))
        self.register_buffer("running_range", torch.ones(feat_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Min-max scale ``x``; returns (transformed tensor, regularization loss).

        Accepts (batch, features) or (batch, seq, features); 3-D input is
        flattened over the sequence dimension and restored on output.
        """
        if inverse:
            return self._inverse_transform(x)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        eps = 1e-8
        if self.training:
            lo = flat.amin(dim=0, keepdim=True)
            hi = flat.amax(dim=0, keepdim=True)
            span = (hi - lo).clamp_min(eps)
            effective_min = lo + self.min_estimator(lo)
            effective_range = span + self.range_estimator(span) + eps
            with torch.no_grad():
                # First batch seeds the running stats; later batches EMA-update.
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_min.copy_(lo.squeeze())
                    self.running_range.copy_(span.squeeze())
                else:
                    self.running_min.mul_(1 - self.momentum).add_(lo.squeeze(), alpha=self.momentum)
                    self.running_range.mul_(1 - self.momentum).add_(span.squeeze(), alpha=self.momentum)
        else:
            effective_min = self.running_min.unsqueeze(0)
            effective_range = self.running_range.unsqueeze(0)
        out = ((flat - effective_min) / effective_range) * self.weight + self.bias
        if len(shape_in) == 3:
            out = out.view(shape_in)
        penalty = 0.01 * (self.weight.var() + self.bias.var())
        if self.training:
            # Discourage degenerate (tiny) ranges that would blow up scaling.
            penalty = penalty + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
        return out, penalty

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head and the running min/range scaling."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        unscaled = (flat - self.bias) / (self.weight + 1e-8)
        restored = unscaled * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
        if len(shape_in) == 3:
            restored = restored.view(shape_in)
        return restored, torch.tensor(0.0, device=x.device)
| |
|
| |
|
class LearnableRobustScaler(nn.Module):
    """Learnable Robust scaler using median and IQR with learnable adjustments.

    Centers by the (MLP-adjusted) batch median and divides by the (adjusted)
    interquartile range, then applies a learnable per-feature affine
    transform. Running statistics are tracked for eval-time use.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        feat_dim = config.input_dim
        width = config.preprocessing_hidden_dim
        # MLP producing an additive correction to the batch median.
        self.median_estimator = nn.Sequential(
            nn.Linear(feat_dim, width), nn.ReLU(), nn.Linear(width, width), nn.ReLU(), nn.Linear(width, feat_dim)
        )
        # Softplus keeps the IQR correction non-negative.
        self.iqr_estimator = nn.Sequential(
            nn.Linear(feat_dim, width), nn.ReLU(), nn.Linear(width, width), nn.ReLU(), nn.Linear(width, feat_dim), nn.Softplus()
        )
        self.weight = nn.Parameter(torch.ones(feat_dim))
        self.bias = nn.Parameter(torch.zeros(feat_dim))
        self.register_buffer("running_median", torch.zeros(feat_dim))
        self.register_buffer("running_iqr", torch.ones(feat_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Robust-scale ``x``; returns (transformed tensor, regularization loss).

        Accepts (batch, features) or (batch, seq, features); 3-D input is
        flattened over the sequence dimension and restored on output.
        """
        if inverse:
            return self._inverse_transform(x)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        eps = 1e-8
        if self.training:
            # Per-feature quartiles of the flattened batch.
            qs = torch.quantile(flat, torch.tensor([0.25, 0.5, 0.75], device=flat.device), dim=0)
            q_low, center, q_high = qs[0:1, :], qs[1:2, :], qs[2:3, :]
            spread = (q_high - q_low).clamp_min(eps)
            effective_median = center + self.median_estimator(center)
            effective_iqr = spread + self.iqr_estimator(spread) + eps
            with torch.no_grad():
                # First batch seeds the running stats; later batches EMA-update.
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_median.copy_(center.squeeze())
                    self.running_iqr.copy_(spread.squeeze())
                else:
                    self.running_median.mul_(1 - self.momentum).add_(center.squeeze(), alpha=self.momentum)
                    self.running_iqr.mul_(1 - self.momentum).add_(spread.squeeze(), alpha=self.momentum)
        else:
            effective_median = self.running_median.unsqueeze(0)
            effective_iqr = self.running_iqr.unsqueeze(0)
        out = ((flat - effective_median) / effective_iqr) * self.weight + self.bias
        if len(shape_in) == 3:
            out = out.view(shape_in)
        penalty = 0.01 * (self.weight.var() + self.bias.var())
        if self.training:
            # Discourage degenerate (tiny) IQRs that would blow up scaling.
            penalty = penalty + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
        return out, penalty

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head and the running median/IQR scaling."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        shape_in = x.shape
        flat = x.view(-1, x.size(-1)) if x.dim() == 3 else x
        centered = (flat - self.bias) / (self.weight + 1e-8)
        restored = centered * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
        if len(shape_in) == 3:
            restored = restored.view(shape_in)
        return restored, torch.tensor(0.0, device=x.device)
| |
|
| |
|
class LearnableYeoJohnsonPreprocessor(nn.Module):
    """Learnable Yeo-Johnson power transform with per-feature lambda and affine head.

    Applies Yeo-Johnson with a trainable per-feature exponent ``lmbda``,
    standardizes the result (batch statistics during training, EMA buffers at
    eval time), then applies a learnable per-feature affine transform.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        # lmbda == 1 is the identity region of Yeo-Johnson, so the transform
        # starts out as (approximately) a no-op.
        self.lmbda = nn.Parameter(torch.ones(input_dim))
        # Post-normalization per-feature affine parameters.
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))
        # EMA statistics of the transformed values, used at eval time and in
        # the inverse transform.
        self.register_buffer("running_mean", torch.zeros(input_dim))
        self.register_buffer("running_std", torch.ones(input_dim))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.momentum = 0.1

    def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Elementwise Yeo-Johnson transform with per-feature ``lmbda``.

        Note: ``torch.where`` evaluates both branches, so each branch is
        clamped to stay finite even where it is not selected.
        """
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = x >= 0
        # Branch for x >= 0: power form, with log fallback near lmbda == 0.
        if_part = torch.where(torch.abs(lmbda) > eps, ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda, torch.log((x + 1.0).clamp_min(eps)))
        two_minus_lambda = 2.0 - lmbda
        # Branch for x < 0: power form in (2 - lmbda), log fallback near 2.
        else_part = torch.where(torch.abs(two_minus_lambda) > eps, -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda, -torch.log((1.0 - x).clamp_min(eps)))
        return torch.where(pos, if_part, else_part)

    def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
        """Elementwise inverse of :meth:`_yeo_johnson` (same branch structure)."""
        eps = 1e-6
        lmbda = lmbda.unsqueeze(0)
        pos = y >= 0
        x_pos = torch.where(torch.abs(lmbda) > eps, (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0, torch.exp(y) - 1.0)
        two_minus_lambda = 2.0 - lmbda
        x_neg = torch.where(torch.abs(two_minus_lambda) > eps, 1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda), 1.0 - torch.exp(-y))
        return torch.where(pos, x_pos, x_neg)

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transform, standardize, and affine-scale ``x``.

        Accepts (batch, features) or (batch, seq, features); returns the
        transformed tensor plus a regularization loss that pulls lmbda toward
        the identity (1.0) and the affine parameters toward uniformity.
        """
        if inverse:
            return self._inverse_transform(x)
        orig_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))
        y = self._yeo_johnson(x, self.lmbda)
        if self.training:
            batch_mean = y.mean(dim=0, keepdim=True)
            # clamp_min guards against zero/NaN std of constant features.
            batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
            with torch.no_grad():
                # First batch seeds the running stats; later batches EMA-update.
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_mean.copy_(batch_mean.squeeze())
                    self.running_std.copy_(batch_std.squeeze())
                else:
                    self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
                    self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
            mean = batch_mean
            std = batch_std
        else:
            mean = self.running_mean.unsqueeze(0)
            std = self.running_std.unsqueeze(0)
        y_norm = (y - mean) / std
        out = y_norm * self.weight + self.bias
        if len(orig_shape) == 3:
            out = out.view(orig_shape)
        reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
        return out, reg

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the affine head, the running-stat normalization, and Yeo-Johnson."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)
        orig_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))
        y = (x - self.bias) / (self.weight + 1e-8)
        y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
        out = self._yeo_johnson_inverse(y, self.lmbda)
        if len(orig_shape) == 3:
            out = out.view(orig_shape)
        return out, torch.tensor(0.0, device=x.device)
| |
|
| |
|
| |
|
class PreprocessingBlock(BaseBlock):
    """Adapter exposing a LearnablePreprocessor through the BaseBlock interface.

    The forward pass returns only the transformed tensor; the auxiliary
    regularization term is stashed on ``self.reg_loss`` for the caller to
    collect. Whether the forward or inverse transform is applied is fixed at
    construction time so no extra kwargs leak to sibling blocks.
    """

    def __init__(self, config: AutoencoderConfig, inverse: bool = False, proc: Optional[LearnablePreprocessor] = None):
        super().__init__()
        # Share an existing preprocessor when provided (e.g. a tied
        # forward/inverse pair); otherwise build a fresh one from the config.
        self.proc = LearnablePreprocessor(config) if proc is None else proc
        self._output_dim = config.input_dim
        self.inverse = inverse
        self.reg_loss: torch.Tensor = torch.tensor(0.0)

    @property
    def output_dim(self) -> int:
        """Dimensionality of the emitted tensors (unchanged from the input)."""
        return self._output_dim

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Apply the (possibly inverse) transform, caching its reg loss."""
        transformed, penalty = self.proc(x, inverse=self.inverse)
        self.reg_loss = penalty
        return transformed
| |
|
class CouplingLayer(nn.Module):
    """Coupling layer for normalizing flows.

    Features selected by the binary mask pass through unchanged and
    parameterize an affine transform of the remaining features, keeping the
    Jacobian triangular so the log-determinant is a simple sum.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        if mask_type == "alternating":
            # Odd-indexed features form the pass-through set.
            self.register_buffer("mask", torch.arange(input_dim) % 2)
        elif mask_type == "half":
            # First half of the features forms the pass-through set.
            half_mask = torch.zeros(input_dim)
            half_mask[: input_dim // 2] = 1
            self.register_buffer("mask", half_mask)
        else:
            raise ValueError(f"Unknown mask type: {mask_type}")
        n_passthrough = int(self.mask.sum().item())
        n_transformed = input_dim - n_passthrough
        # Tanh bounds the log-scale so exp(s) stays numerically tame.
        self.scale_net = nn.Sequential(
            nn.Linear(n_passthrough, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_transformed), nn.Tanh()
        )
        self.translate_net = nn.Sequential(
            nn.Linear(n_passthrough, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_transformed)
        )

    def forward(self, x: torch.Tensor, inverse: bool = False):
        """Apply the coupling transform; returns (output, per-sample log-det)."""
        keep = self.mask.bool()
        frozen = x[:, keep]
        active = x[:, ~keep]
        log_scale = self.scale_net(frozen)
        shift = self.translate_net(frozen)
        if inverse:
            mapped = (active - shift) * torch.exp(-log_scale)
            log_det = -log_scale.sum(dim=1)
        else:
            mapped = active * torch.exp(log_scale) + shift
            log_det = log_scale.sum(dim=1)
        out = torch.zeros_like(x)
        out[:, keep] = frozen
        out[:, ~keep] = mapped
        return out, log_det
| |
|
| |
|
class NormalizingFlowPreprocessor(nn.Module):
    """Normalizing flow for learnable data preprocessing.

    Stacks affine coupling layers with alternating mask patterns, optionally
    interleaving BatchNorm1d between consecutive layers.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim
        num_layers = config.flow_coupling_layers
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Alternate mask patterns so successive layers transform
            # complementary feature subsets.
            mask_type = "alternating" if i % 2 == 0 else "half"
            self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
        if config.use_batch_norm:
            # One BatchNorm between each pair of consecutive coupling layers.
            self.batch_norms = nn.ModuleList([nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)])
        else:
            self.batch_norms = None

    def forward(self, x: torch.Tensor, inverse: bool = False):
        """Run the flow; returns (output, regularization loss).

        Accepts (batch, features) or (batch, seq, features) input; 3-D input
        is flattened over the sequence dimension and restored on output.
        """
        original_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))
        log_det_total = torch.zeros(x.size(0), device=x.device)
        if not inverse:
            for i, layer in enumerate(self.layers):
                x, log_det = layer(x, inverse=False)
                log_det_total += log_det
                if self.batch_norms and i < len(self.layers) - 1:
                    x = self.batch_norms[i](x)
        else:
            # NOTE(review): this path applies BatchNorm *forward* rather than
            # analytically inverting it, and its effect is excluded from
            # log_det_total — so with use_batch_norm enabled, the inverse is
            # only an approximation of the forward pass. Confirm intended.
            for i, layer in enumerate(reversed(self.layers)):
                if self.batch_norms and i > 0:
                    bn_idx = len(self.layers) - 1 - i
                    x = self.batch_norms[bn_idx](x)
                x, log_det = layer(x, inverse=True)
                log_det_total += log_det
        if len(original_shape) == 3:
            x = x.view(original_shape)
        # Penalize large volume change, keeping the flow near volume-preserving.
        reg_loss = 0.01 * log_det_total.abs().mean()
        return x, reg_loss
| |
|
| |
|
class LearnablePreprocessor(nn.Module):
    """Unified interface for learnable preprocessing methods.

    Selects the concrete preprocessor from the config flags (first match
    wins); when preprocessing is disabled an Identity passthrough is used.
    """

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        if not config.has_preprocessing:
            self.preprocessor = nn.Identity()
            return
        # (flag, constructor) pairs, checked in priority order.
        candidates = (
            (config.is_neural_scaler, NeuralScaler),
            (config.is_normalizing_flow, NormalizingFlowPreprocessor),
            (getattr(config, "is_minmax_scaler", False), LearnableMinMaxScaler),
            (getattr(config, "is_robust_scaler", False), LearnableRobustScaler),
            (getattr(config, "is_yeo_johnson", False), LearnableYeoJohnsonPreprocessor),
        )
        for selected, build in candidates:
            if selected:
                self.preprocessor = build(config)
                return
        raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")

    def forward(self, x: torch.Tensor, inverse: bool = False):
        """Delegate to the chosen preprocessor; Identity returns x unchanged."""
        if not isinstance(self.preprocessor, nn.Identity):
            return self.preprocessor(x, inverse=inverse)
        return x, torch.tensor(0.0, device=x.device)
| |
|
| |
|
| |
|
# Explicit public API for ``from <module> import *``.
__all__ = [
    "NeuralScaler",
    "LearnableMinMaxScaler",
    "LearnableRobustScaler",
    "LearnableYeoJohnsonPreprocessor",
    "CouplingLayer",
    "NormalizingFlowPreprocessor",
    "LearnablePreprocessor",
    "PreprocessingBlock",
]
| |
|