import sys
import torch
import random
import numpy as np
from tqdm import tqdm
from src.utils.model_utils import _print


class UnconditionalSampler:
    def __init__(self, tokenizer, model):
        self.model = model
        self.tokenizer = tokenizer

        self.device = self.model.device
        self.mask_id = self.tokenizer.mask_token_id
        self.seed_everything(seed=42)

    @torch.inference_mode()
    def sample_unconditional(self, xt, num_steps, tau=0.7, kappa_fn=lambda t: t, eta=1.0,
                             alpha=1.0, banned_token_ids=None, return_logits=False):
        """
        Stochastic remasking sampling method for iterative refinement of sequences.

        Args:
            xt (Tensor): Initial token tensor; positions equal to the mask token id are refined.
            num_steps (int): Number of refinement steps.
            tau (float): Temperature for Gumbel-perturbed sampling.
            kappa_fn (callable): Unmasking schedule; maps t in (0, 1] to the fraction of tokens kept unmasked.
            eta (float): Scaling factor applied to the scores of already-unmasked tokens.
            alpha (float): Weight between log-probability and negative entropy in the confidence score.
            banned_token_ids (list[int], optional): Token ids that are never sampled.
            return_logits (bool): If True, also return the logits from the final step.

        Returns:
            Tensor: Final sampled sequence, or a (sequence, logits) tuple if return_logits is True.
        """
        dt = 1 / num_steps
        fix_mask = xt != self.mask_id
        attention_mask = torch.ones_like(xt).to(self.device)

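        # Each step: re-predict every position, propose new tokens, score them, and remask
        # the lowest-scoring maskable positions according to the schedule kappa_fn.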
        for i in range(1, num_steps + 1):
            kappa_t = kappa_fn(i * dt)
            logits = self.model(input_ids=xt, attention_mask=attention_mask)
            last_mask = xt == self.mask_id
            unmask_t = ~last_mask & ~fix_mask

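            # Propose a token for every position via Gumbel-perturbed sampling at temperature
            # tau; logp is the log-softmax score of each chosen token.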
            x0, logp = self.stochastic_sample_from_categorical(
                logits, tau, banned_token_ids=banned_token_ids
            )

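            # Confidence score per position: a mix of the proposal's log-probability and the
            # negative predictive entropy; fixed (input) tokens get +inf so they are never
            # remasked, and positions unmasked in earlier steps are rescaled by eta.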
            entropy = torch.distributions.Categorical(logits=logits).entropy()
            score = alpha * logp + (1 - alpha) * -entropy
            score = score.masked_fill(fix_mask, float('inf'))
            score[unmask_t] = score[unmask_t] * eta

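            # Remasking schedule: keep a kappa_t fraction of the maskable positions unmasked;
            # the (1 - kappa_t) lowest-scoring positions are reset to the mask token.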
            num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
            lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)

            xt[lowest_k_mask] = self.mask_id
            mask_2_x0 = last_mask & ~lowest_k_mask
            xt[mask_2_x0] = x0[mask_2_x0]
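        # Any position still masked after the final step is filled with its last proposal.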
        xt[xt == self.mask_id] = x0[xt == self.mask_id]
        return (xt, logits) if return_logits else xt

    def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0, banned_token_ids=None):
        """
        Sample from a categorical distribution with optional temperature scaling and Gumbel noise.
        """
        logits = logits.double()

        if banned_token_ids is not None:
            banned_token_mask = torch.zeros_like(logits, dtype=torch.bool)
            for token_id in banned_token_ids:
                banned_token_mask[..., token_id] = True
            logits = logits.masked_fill(banned_token_mask, float('-inf'))

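        # Gumbel-max trick: with noise_scale == 1, adding Gumbel(0, 1) noise to the
        # temperature-scaled logits and taking the argmax draws a sample from the softmax
        # distribution.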
        if temperature != 0:
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
            logits = logits / temperature + noise_scale * gumbel_noise
        scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)

        return tokens, scores

    def topk_lowest_masking(self, scores, cutoff_len):
        """
        scores: [b, n]
        cutoff_len: [b, 1]
        returns:
            mask: [b, n], True where the token is among the cutoff_len lowest scores, False otherwise
        """
        sorted_scores = scores.sort(-1)[0]
        cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)
        return scores < cutoff

    def seed_everything(self, seed):
        """
        Set the seed for reproducibility across various libraries.
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
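

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). The toy
# tokenizer and model below are hypothetical stand-ins; UnconditionalSampler
# only assumes that model(input_ids=..., attention_mask=...) returns per-token
# logits, that the model exposes .device, and that the tokenizer exposes
# mask_token_id. Swap in a real masked-language/diffusion model with the same
# interface.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class _ToyTokenizer:
        mask_token_id = 0  # hypothetical mask id

    class _ToyModel(nn.Module):
        """Random-logit stand-in exposing the interface the sampler expects."""

        def __init__(self, vocab_size=100, hidden=16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)
            self.head = nn.Linear(hidden, vocab_size)

        @property
        def device(self):
            return next(self.parameters()).device

        def forward(self, input_ids, attention_mask=None):
            return self.head(self.embed(input_ids))  # [batch, seq_len, vocab]

    sampler = UnconditionalSampler(_ToyTokenizer(), _ToyModel())
    xt = torch.full((1, 32), sampler.mask_id, dtype=torch.long, device=sampler.device)
    sample = sampler.sample_unconditional(xt, num_steps=8, tau=0.7)
    print(sample.shape)  # expected: torch.Size([1, 32])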