| import torch |
| import numpy as np |
| from scipy.linalg import sqrtm |
|
|
def dna_to_tensor(seq):
    """Encode a nucleotide sequence as a 1-D ``torch.long`` tensor of indices.

    Each base is mapped as A -> 0, C -> 1, G -> 2, T -> 3. Any other
    character (including lowercase bases or ``N``) raises ``KeyError``.
    An empty sequence yields an empty long tensor.
    """
    base_index = dict(zip("ACGT", range(4)))
    encoded = [base_index[nt] for nt in seq]
    return torch.tensor(encoded, dtype=torch.long)
|
|
|
|
def _as_embedding_matrix(embeds):
    """Convert a batch of embeddings into a 2-D float64 NumPy array (n_samples, n_features)."""
    if isinstance(embeds, torch.Tensor):
        # Detach so tensors that require grad (or live on GPU) can become NumPy arrays.
        embeds = embeds.detach().cpu().double().numpy()
    arr = np.asarray(embeds, dtype=np.float64)
    if arr.ndim < 2:
        # A 1-D result is n scalar embeddings, matching np.cov(..., rowvar=False).
        arr = arr.reshape(-1, 1)
    return arr


def compute_fbd(true_seqs, gen_seqs, score_model, eps=1e-6):
    """Compute the Frechet Biological Distance (FBD) between two sequence sets.

    FBD is the Frechet (2-Wasserstein) distance between two Gaussians fitted
    to the embeddings of the true and the generated sequences:

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        true_seqs: Reference sequences, passed unchanged to ``score_model``.
        gen_seqs: Generated sequences, passed unchanged to ``score_model``.
        score_model: Callable mapping a batch of sequences to embeddings of
            shape ``(n_samples, n_features)``. It may return a NumPy array, a
            torch tensor, or anything ``np.asarray`` accepts. A 1-D result is
            treated as ``n`` scalar embeddings.
        eps: Diagonal offset added to both covariance matrices when the
            matrix square root of their product is not finite, which happens
            for near-singular covariances.

    Returns:
        The distance as a float. Returns ``nan`` when either embedding set has
        fewer than two samples (sample covariance undefined) or contains
        non-finite values.

    Raises:
        ValueError: If the two embedding sets have different feature dimensions.
    """
    embeds1 = _as_embedding_matrix(score_model(true_seqs))
    embeds2 = _as_embedding_matrix(score_model(gen_seqs))

    for embeds in (embeds1, embeds2):
        # Covariance with ddof=1 needs >= 2 rows; NaN/Inf would poison sqrtm.
        if embeds.shape[0] < 2 or not np.isfinite(embeds).all():
            return float("nan")
    if embeds1.shape[1] != embeds2.shape[1]:
        raise ValueError(
            f"embedding dimensions differ: {embeds1.shape[1]} vs {embeds2.shape[1]}"
        )

    mu1 = embeds1.mean(axis=0)
    mu2 = embeds2.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array for a single feature, which sqrtm rejects.
    sigma1 = np.atleast_2d(np.cov(embeds1, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(embeds2, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: regularise both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Round-off can leave tiny imaginary parts in sqrtm's result; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return float(dist)
|
|