| | import copy |
| | import random |
| |
|
| | import numpy as np |
| | import torch |
| | from torch_geometric.data import Batch |
| | from torch_geometric.loader import DataLoader |
| |
|
| | from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch |
| | from utils.torsion import modify_conformer_torsion_angles |
| | from scipy.spatial.transform import Rotation as R |
| | from utils.utils import crop_beyond |
| | from utils.logging_utils import get_logger |
| |
|
| |
|
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
                       initial_noise_std_proportion=-1.0, choose_residue=False):
    """Randomise the ligand pose of every complex graph in ``data_list`` in place.

    Each ligand optionally gets random torsion angles. It is then rigidly rotated by a
    random rotation about its centroid and placed with its centroid on a pocket centre.
    Finally, unless ``no_random``, it is shifted by a random translation.

    Args:
        data_list: complex graphs exposing ``'ligand'`` and ``'receptor'`` stores. Each
            ``['ligand'].pos`` is overwritten. The pocket centre is computed from
            ``data_list[0]`` only.
        no_torsion: if True, skip the torsion-angle randomisation.
        no_random: if True, skip the final random translation.
        tr_sigma_max: translation scale used when ``initial_noise_std_proportion < 0``.
        pocket_knowledge: if True, centre on the receptor nodes lying within
            ``pocket_cutoff`` of the reference ligand pose
            (``['ligand'].orig_pos[0] - original_center``) instead of the receptor mean.
        pocket_cutoff: distance cutoff used by ``pocket_knowledge``.
        initial_noise_std_proportion: if >= 0, the translation std is this fraction of the
            receptor's root-mean-square coordinate norm divided by 1.73 (presumably ~sqrt(3),
            spreading it over three axes). If < 0, the std is
            ``-initial_noise_std_proportion * tr_sigma_max``.
        choose_residue: if True, translate by the position of a randomly chosen receptor
            node (with 0.01 std jitter) instead.
    """
    reference_graph = data_list[0]
    center_pocket = reference_graph['receptor'].pos.mean(dim=0)
    if pocket_knowledge:
        distances = torch.cdist(
            reference_graph['receptor'].pos,
            torch.from_numpy(reference_graph['ligand'].orig_pos[0]).float() - reference_graph.original_center,
        )
        in_pocket = torch.any(distances < pocket_cutoff, dim=1)
        if torch.any(in_pocket):
            center_pocket = reference_graph['receptor'].pos[in_pocket].mean(dim=0)
        else:
            # No receptor node is close enough: fall back to the single closest receptor node.
            get_logger().warning(f"No pocket residue below minimum distance {pocket_cutoff}, "
                                 f"taking closest at {torch.min(distances).item()}")
            center_pocket = reference_graph['receptor'].pos[torch.argmin(torch.min(distances, dim=1)[0])]

    if not no_torsion:
        # Draw a uniform angle in [-pi, pi) for every masked (rotatable) ligand bond.
        for complex_graph in data_list:
            torsion_updates = np.random.uniform(low=-np.pi, high=np.pi, size=complex_graph['ligand'].edge_mask.sum())
            complex_graph['ligand'].pos = modify_conformer_torsion_angles(
                complex_graph['ligand'].pos,
                complex_graph['ligand', 'ligand'].edge_index.T[complex_graph['ligand'].edge_mask],
                complex_graph['ligand'].mask_rotate[0],
                torsion_updates,
            )

    for complex_graph in data_list:
        # Random rigid rotation about the ligand centroid, then centroid onto the pocket centre.
        molecule_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
        random_rotation = torch.from_numpy(R.random().as_matrix()).float()
        complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T + center_pocket

        if no_random:
            continue
        if choose_residue:
            idx = random.randint(0, len(complex_graph['receptor'].pos) - 1)
            tr_update = torch.normal(mean=complex_graph['receptor'].pos[idx:idx + 1], std=0.01)
        elif initial_noise_std_proportion >= 0.0:
            std_rec = torch.sqrt(torch.mean(torch.sum(complex_graph['receptor'].pos ** 2, dim=1)))
            tr_update = torch.normal(mean=0, std=std_rec * initial_noise_std_proportion / 1.73, size=(1, 3))
        else:
            # A negative proportion means "scale tr_sigma_max by its magnitude".
            tr_update = torch.normal(mean=0, std=-initial_noise_std_proportion * tr_sigma_max, size=(1, 3))
        complex_graph['ligand'].pos += tr_update
| |
|
| |
|
def is_iterable(arr):
    """Return True if ``iter(arr)`` succeeds and False if it raises ``TypeError``.

    Strings, lists, tuples, generators, etc. count as iterable; plain numbers and ``None``
    do not.
    """
    try:
        iter(arr)
    except TypeError:
        return False
    return True
| |
|
| |
|
def _broadcast_to_components(value):
    """Return ``value`` as a 3-entry sequence (translation, rotation, torsion).

    Non-iterable scalars are repeated three times. Iterables are returned unchanged and
    must already hold exactly three entries.
    """
    if not np.iterable(value):
        value = [value] * 3
    assert len(value) == 3
    return value


def _temperature_perturb(g, dt, sigma, score, z, temp, psi, sigma_data_frac, sigma_min, sigma_max):
    """Return the temperature-reweighted reverse-SDE update for one degree of freedom.

    ``sigma_data`` is interpolated log-linearly between ``sigma_min`` (frac 0) and
    ``sigma_max`` (frac 1). ``temp`` rescales the score term through ``lam``. ``psi`` adds
    an extra ``temp * psi / 2`` drift factor and inflates the noise by ``sqrt(1 + psi)``.
    """
    sigma_data = np.exp(sigma_data_frac * np.log(sigma_max) + (1 - sigma_data_frac) * np.log(sigma_min))
    lam = (sigma_data + sigma) / (sigma_data + sigma / temp)
    return g ** 2 * dt * (lam + temp * psi / 2) * score + g * np.sqrt(dt * (1 + psi)) * z


def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_schedule, device, t_to_sigma, model_args,
             no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None,
             confidence_model_args=None, t_schedule=None, batch_size=32, no_final_step_noise=False, pivot=None,
             return_full_trajectory=False, temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5,
             return_features=False):
    """Run reverse-diffusion sampling over ligand translation, rotation and (optionally) torsions.

    All of ``data_list`` must fit in one batch (``batch_size >= len(data_list)``). After
    sampling, every ``data_list[i]['ligand'].pos`` is overwritten with the final pose. These
    are slices of the batch tensor on ``device``.

    Args:
        data_list: complex graphs to sample. Ligand masks are read from ``data_list[0]``.
        model: score model. Its first three outputs are the translation, rotation and
            torsion scores.
        inference_steps: number of denoising steps.
        tr_schedule, rot_schedule, tor_schedule: per-step times, indexed by step. The step
            size is ``schedule[i] - schedule[i + 1]``, and the last step uses ``schedule[i]``.
        device: device the batches are moved to.
        t_to_sigma: maps ``(t_tr, t_rot, t_tor)`` to ``(tr_sigma, rot_sigma, tor_sigma)``.
        model_args: namespace with the ``*_sigma_min``/``*_sigma_max`` bounds and
            ``no_torsion``, plus optional ``crop_beyond`` and ``all_atoms``.
        no_random: if True, use zero noise in every SDE step.
        ode: if True, use the probability-flow ODE update instead of the SDE.
        visualization_list: optional per-complex objects whose ``add`` receives intermediate
            and final poses (shifted by ``original_center``).
        confidence_model, confidence_data_list, confidence_model_args: optional confidence
            scoring of the final poses.
        t_schedule: optional per-step global time passed to ``set_time``.
        batch_size: batch size; must be >= ``len(data_list)``.
        no_final_step_noise: if True, use zero noise on the last step only.
        temp_sampling, temp_psi, temp_sigma_data: scalars or 3-entry sequences
            (tr, rot, tor) controlling temperature sampling. A component whose
            ``temp_sampling`` is 1.0 uses the plain SDE update.
        pivot, return_full_trajectory, return_features: not supported; must be falsy.

    Returns:
        ``(data_list, confidence)``. ``confidence`` is the concatenated confidence-model output
        with NaNs replaced by -1000, or ``None`` when no confidence model is given.

    Raises:
        AssertionError: if the batch-size or unsupported-option checks fail, or a
            temperature setting does not have three entries.
        NotImplementedError: if ``ode`` is combined with an active temperature
            (``temp_sampling != 1.0``).
    """
    N = len(data_list)
    trajectory = []
    if return_features:
        lig_features, rec_features = [], []
    assert batch_size >= N, "Not implemented yet"
    assert not (return_full_trajectory or return_features or pivot), "Not implemented yet in new inference version"

    # Temperature settings do not change between steps: normalise them once up front.
    temp_sampling = _broadcast_to_components(temp_sampling)
    temp_psi = _broadcast_to_components(temp_psi)
    temp_sigma_data = _broadcast_to_components(temp_sigma_data)
    if ode and (temp_sampling[0] != 1.0 or temp_sampling[1] != 1.0
                or (temp_sampling[2] != 1.0 and not model_args.no_torsion)):
        # The temperature update needs SDE noise samples, which the ODE path never draws.
        raise NotImplementedError("Temperature sampling (temp_sampling != 1.0) is not supported with ode=True")

    logger = get_logger()
    loader = DataLoader(data_list, batch_size=batch_size)
    mask_rotate = torch.from_numpy(data_list[0]['ligand'].mask_rotate[0]).to(device)

    confidence = None
    confidence_loader = None
    if confidence_model is not None:
        if confidence_data_list is not None:
            confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))
        confidence = []

    with torch.no_grad():
        for batch_id, complex_graph_batch in enumerate(loader):
            b = complex_graph_batch.num_graphs
            # Ligand atoms per graph; the slicing below assumes all ligands in the batch have equal size.
            n = len(complex_graph_batch['ligand'].pos) // b
            complex_graph_batch = complex_graph_batch.to(device)

            for t_idx in range(inference_steps):
                last_step = t_idx == inference_steps - 1
                t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
                # The last step integrates all the way down to t = 0.
                dt_tr = t_tr if last_step else t_tr - tr_schedule[t_idx + 1]
                dt_rot = t_rot if last_step else t_rot - rot_schedule[t_idx + 1]
                dt_tor = t_tor if last_step else t_tor - tor_schedule[t_idx + 1]

                tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)

                if hasattr(model_args, 'crop_beyond') and model_args.crop_beyond is not None:
                    # The crop radius depends on the current tr_sigma, so crop a deep copy and keep
                    # the uncropped batch for later steps.
                    mod_complex_graph_batch = copy.deepcopy(complex_graph_batch).to_data_list()
                    for batch in mod_complex_graph_batch:
                        crop_beyond(batch, tr_sigma * 3 + model_args.crop_beyond, model_args.all_atoms)
                    mod_complex_graph_batch = Batch.from_data_list(mod_complex_graph_batch)
                else:
                    mod_complex_graph_batch = complex_graph_batch

                set_time(mod_complex_graph_batch, t_schedule[t_idx] if t_schedule is not None else None,
                         t_tr, t_rot, t_tor, b, 'all_atoms' in model_args and model_args.all_atoms, device)

                tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]
                mean_scores = torch.mean(tr_score, dim=-1)
                num_nans = torch.sum(torch.isnan(mean_scores))
                if num_nans > 0:
                    name = complex_graph_batch['name']
                    if isinstance(name, list):
                        name = name[0]
                    logger.warning(f"Complex {name} Batch {batch_id+1} Inference Iteration {t_idx}: "
                                   f"{num_nans} / {mean_scores.numel()} samples failed")

                # In place: NaN and +inf -> eps, -inf -> -eps, where eps is 1% of the
                # NaN-ignoring mean absolute score.
                tr_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tr_score.abs())), posinf=eps, neginf=-eps)
                rot_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(rot_score.abs())), posinf=eps, neginf=-eps)
                tor_score.nan_to_num_(nan=(eps := 0.01*torch.nanmean(tor_score.abs())), posinf=eps, neginf=-eps)
                del eps

                tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
                rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
                # Noise is zero everywhere with no_random, or only on the last step with no_final_step_noise.
                zero_noise = no_random or (no_final_step_noise and last_step)

                if ode:
                    tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score)
                    rot_perturb = (0.5 * rot_score * dt_rot * rot_g ** 2)
                else:
                    tr_z = torch.zeros((min(batch_size, N), 3), device=device) if zero_noise \
                        else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
                    tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z)

                    rot_z = torch.zeros((min(batch_size, N), 3), device=device) if zero_noise \
                        else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
                    rot_perturb = (rot_score * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z)

                if not model_args.no_torsion:
                    tor_g = tor_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))
                    if ode:
                        tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score)
                    else:
                        tor_z = torch.zeros(tor_score.shape, device=device) if zero_noise \
                            else torch.normal(mean=0, std=1, size=tor_score.shape, device=device)
                        tor_perturb = (tor_g ** 2 * dt_tor * tor_score + tor_g * np.sqrt(dt_tor) * tor_z)
                else:
                    tor_perturb = None

                # Temperature sampling replaces the plain SDE update per component. It is only
                # reachable when not ode (checked above), so the *_z noise tensors exist.
                if temp_sampling[0] != 1.0:
                    tr_perturb = _temperature_perturb(tr_g, dt_tr, tr_sigma, tr_score, tr_z,
                                                      temp_sampling[0], temp_psi[0], temp_sigma_data[0],
                                                      model_args.tr_sigma_min, model_args.tr_sigma_max)
                if temp_sampling[1] != 1.0:
                    rot_perturb = _temperature_perturb(rot_g, dt_rot, rot_sigma, rot_score, rot_z,
                                                       temp_sampling[1], temp_psi[1], temp_sigma_data[1],
                                                       model_args.rot_sigma_min, model_args.rot_sigma_max)
                # A rigid (no_torsion) model has no torsion update to reweight.
                if temp_sampling[2] != 1.0 and not model_args.no_torsion:
                    tor_perturb = _temperature_perturb(tor_g, dt_tor, tor_sigma, tor_score, tor_z,
                                                       temp_sampling[2], temp_psi[2], temp_sigma_data[2],
                                                       model_args.tor_sigma_min, model_args.tor_sigma_max)

                complex_graph_batch['ligand'].pos = \
                    modify_conformer_batch(complex_graph_batch['ligand'].pos, complex_graph_batch, tr_perturb, rot_perturb,
                                           tor_perturb, mask_rotate)

                if visualization_list is not None:
                    for idx_b in range(b):
                        visualization_list[batch_id * batch_size + idx_b].add((
                            complex_graph_batch['ligand'].pos[idx_b*n:n*(idx_b+1)].detach().cpu() +
                            data_list[batch_id * batch_size + idx_b].original_center.detach().cpu()),
                            part=1, order=t_idx + 2)

            for i in range(b):
                data_list[batch_id * batch_size + i]['ligand'].pos = complex_graph_batch['ligand'].pos[i*n:n*(i+1)]

            if visualization_list is not None:
                for idx, visualization in enumerate(visualization_list):
                    visualization.add((data_list[idx]['ligand'].pos.detach().cpu() + data_list[idx].original_center.detach().cpu()),
                                      part=1, order=2)

            if confidence_model is not None:
                if confidence_loader is not None:
                    # Score the final poses on the confidence model's own graphs.
                    confidence_complex_graph_batch = next(confidence_loader)
                    confidence_complex_graph_batch['ligand'].pos = complex_graph_batch['ligand'].pos.cpu()

                    if hasattr(confidence_model_args, 'crop_beyond') and confidence_model_args.crop_beyond is not None:
                        confidence_complex_graph_batch = confidence_complex_graph_batch.to_data_list()
                        for batch in confidence_complex_graph_batch:
                            crop_beyond(batch, confidence_model_args.crop_beyond, confidence_model_args.all_atoms)
                        confidence_complex_graph_batch = Batch.from_data_list(confidence_complex_graph_batch)

                    confidence_complex_graph_batch = confidence_complex_graph_batch.to(device)
                    set_time(confidence_complex_graph_batch, 0, 0, 0, 0, b, confidence_model_args.all_atoms, device)
                    out = confidence_model(confidence_complex_graph_batch)
                else:
                    out = confidence_model(complex_graph_batch)

                if type(out) is tuple:
                    out = out[0]
                confidence.append(out)

    if confidence_model is not None:
        confidence = torch.cat(confidence, dim=0)
        confidence = torch.nan_to_num(confidence, nan=-1000)

    if return_full_trajectory:
        return data_list, confidence, trajectory
    elif return_features:
        lig_features = torch.cat(lig_features, dim=0)
        rec_features = torch.cat(rec_features, dim=0)
        return data_list, confidence, lig_features, rec_features

    return data_list, confidence
| |
|
| |
|
def compute_affinity(data_list, affinity_model, affinity_data_list, device, parallel, all_atoms, include_miscellaneous_atoms):
    """Score the first ``parallel`` sampled poses with ``affinity_model`` in one forward pass.

    The ligand of ``affinity_data_list[0]`` is tiled ``parallel`` times: node features, edge
    mask, offset edge indices and edge attributes. It is then given the ligand positions of
    the first ``parallel`` graphs of ``data_list``, and the single resulting graph is passed
    to ``affinity_model``.

    Args:
        data_list: sampled complex graphs; needs at least ``parallel`` entries.
        affinity_model: model whose second output is returned, or ``None``.
        affinity_data_list: graphs prepared for the affinity model; only entry 0 is used.
            It is not modified.
        device: device the batches are moved to.
        parallel: number of poses to score.
        all_atoms, include_miscellaneous_atoms: forwarded to ``set_time``.

    Returns:
        The second output of ``affinity_model``, or ``None`` when ``affinity_model`` is None.
    """
    with torch.no_grad():
        if affinity_model is not None:
            assert parallel <= len(data_list)
            loader = DataLoader(data_list, batch_size=parallel)
            complex_graph_batch = next(iter(loader)).to(device)
            positions = complex_graph_batch['ligand'].pos

            assert affinity_data_list is not None
            # Tile a copy: tiling the caller's graph in place would corrupt affinity_data_list[0]
            # and compound the repetition on every further call.
            complex_graph = copy.deepcopy(affinity_data_list[0])
            N = complex_graph['ligand'].num_nodes
            complex_graph['ligand'].x = complex_graph['ligand'].x.repeat(parallel, 1)
            complex_graph['ligand'].edge_mask = complex_graph['ligand'].edge_mask.repeat(parallel)
            # Copy i of the ligand occupies node indices [N * i, N * (i + 1)).
            complex_graph['ligand', 'ligand'].edge_index = torch.cat(
                [N * i + complex_graph['ligand', 'ligand'].edge_index for i in range(parallel)], dim=1)
            complex_graph['ligand', 'ligand'].edge_attr = complex_graph['ligand', 'ligand'].edge_attr.repeat(parallel, 1)
            complex_graph['ligand'].pos = positions

            affinity_loader = DataLoader([complex_graph], batch_size=1)
            affinity_batch = next(iter(affinity_loader)).to(device)
            set_time(affinity_batch, 0, 0, 0, 0, 1, all_atoms, device, include_miscellaneous_atoms=include_miscellaneous_atoms)
            _, affinity = affinity_model(affinity_batch)
        else:
            affinity = None

    return affinity