| import pandas as pd |
| from tqdm import tqdm |
| from rdkit import Chem, RDLogger |
| from datasets import load_dataset |
| from multiprocessing import Pool, cpu_count |
| import os |
|
|
| |
# Silence every RDKit application log channel ("rdApp.*"), e.g. the messages
# RDKit emits when a SMILES string fails to parse. This runs at import time,
# so pool workers that re-import this module (the "spawn" start method)
# execute it as well.
RDLogger.DisableLog("rdApp.*")
|
|
class SmilesEnumerator:
    """
    Stateless helper that turns a SMILES string into a randomized SMILES
    string for the same molecule.

    The class keeps no instance state, so creating a new instance wherever
    one is needed (as the pool worker function does) is cheap.
    """

    def randomize_smiles(self, smiles):
        """
        Return a randomized, non-canonical SMILES string for ``smiles``.

        RDKit parses ``smiles`` and writes it back with a random atom ordering
        (``doRandom=True``) and canonicalization disabled. Repeated calls on
        the same input can therefore yield different strings that all encode
        the same molecule.

        This is best-effort: if RDKit cannot parse the input
        (``MolFromSmiles`` returns ``None``) or raises an exception, the
        input is returned unchanged rather than failing the whole batch.

        Args:
            smiles: The SMILES string to randomize.

        Returns:
            A randomized SMILES string, or ``smiles`` itself on failure.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # Unparseable SMILES: keep the original string.
                return smiles
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        except Exception:
            # Deliberate best-effort fallback. The handler is limited to
            # Exception (not a bare ``except:``) so that KeyboardInterrupt and
            # SystemExit still stop the worker instead of being swallowed.
            return smiles
|
|
def create_augmented_pair(smiles_string):
    """
    Produce two independently randomized SMILES strings for one molecule.

    This is the worker function mapped over the dataset by the process pool
    in ``main``. ``SmilesEnumerator.randomize_smiles`` is called twice on the
    same input, and the results are returned in call order as
    ``(smiles_1, smiles_2)``.

    The two draws are independent, so they are not guaranteed to differ.
    For example, a molecule may have only one possible spelling. If the input
    cannot be parsed, both entries are the input string unchanged.
    """
    randomize = SmilesEnumerator().randomize_smiles
    first, second = (randomize(smiles_string) for _ in range(2))
    return first, second
|
|
def main(
    dataset_name='jablonkagroup/pubchem-smiles-molecular-formula',
    smiles_column_name='smiles',
    output_path='data/pubchem_2_epoch_50M',
    max_samples=50_000_000,
    num_workers=None,
    chunksize=1_000,
):
    """
    Build a parquet file of randomized SMILES pairs from a Hugging Face dataset.

    The function takes up to ``max_samples`` rows from the start of the
    dataset's ``train`` split and sends each SMILES string through
    ``create_augmented_pair`` in a process pool. The resulting pairs are
    written to ``output_path`` as parquet with columns ``smiles_1`` and
    ``smiles_2``, in the same order as the input rows. Progress is printed
    to stdout.

    Every parameter defaults to the value this script was originally
    hard-coded with, so ``main()`` with no arguments keeps its behavior.

    Args:
        dataset_name: Hugging Face dataset identifier passed to
            ``load_dataset``.
        smiles_column_name: Name of the column that holds SMILES strings.
        output_path: Destination file for the parquet output. Missing parent
            directories are created.
        max_samples: Maximum number of rows to take from the start of the
            train split. It is clamped to the split size; ``None`` means use
            every row.
        num_workers: Number of worker processes. ``None`` means
            ``cpu_count()``.
        chunksize: Number of SMILES strings sent to a worker per task.
    """
    print(f"Loading dataset '{dataset_name}'...")
    dataset = load_dataset(dataset_name, split='train')

    # Dataset.select raises for out-of-range indices. Without this clamp, a
    # split smaller than max_samples would abort the run instead of simply
    # using every available row.
    if max_samples is not None:
        if max_samples < len(dataset):
            dataset = dataset.select(range(max_samples))
        else:
            print(
                f"Requested {max_samples:,} rows but the split has only "
                f"{len(dataset):,}; using all of them."
            )

    smiles_list = dataset[smiles_column_name]
    print(f"Successfully fetched {len(smiles_list)} SMILES strings.")

    if num_workers is None:
        num_workers = cpu_count()
    print(f"Starting SMILES augmentation with {num_workers} worker processes...")

    with Pool(num_workers) as pool:
        # Pool.imap defaults to chunksize=1, i.e. one inter-process round trip
        # per molecule. Batching removes that overhead, and imap still yields
        # results in input order.
        results = list(
            tqdm(
                pool.imap(create_augmented_pair, smiles_list, chunksize=chunksize),
                total=len(smiles_list),
                desc="Augmenting Pairs",
            )
        )

    print("Processing complete. Converting to DataFrame...")
    df = pd.DataFrame(results, columns=['smiles_1', 'smiles_2'])
    # Release the Python-level copies before parquet serialization. That step
    # converts the frame to an Arrow table, and dropping these lowers the
    # peak memory use.
    del results, smiles_list

    output_dir = os.path.dirname(output_path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Saving augmented pairs to '{output_path}'...")
    df.to_parquet(output_path)

    print("All done. Your pre-computed dataset is ready!")
|
|
if __name__ == "__main__":
    # Entry-point guard, which multiprocessing requires: pool workers started
    # with the "spawn" method re-import this module, and they must not
    # launch another full run of main().
    main()
|
|
|
|