| from collections import Counter |
| import torchvision.datasets as dset |
| from torch.utils.data import Dataset |
| import torch |
| from torch.utils.data import DataLoader |
| import glob |
| import os |
| from torch.utils.data import Dataset, DataLoader, random_split |
| from shutil import copyfile |
| import subprocess |
| import youtokentome as yttm |
| import re |
| import time |
| from tqdm import trange, tqdm |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import inspect |
|
|
| |
|
|
# Device for all cached dataset tensors; change to "cuda" to keep data on GPU.
DEVICE = "cpu"
|
|
|
|
class BPEModelManager:
    """Manages a YouTokenToMe BPE tokenizer stored under *root_dir*.

    On construction, tries to load ``bpe_model.model`` from *root_dir*.
    If loading fails (yttm raises ValueError for a missing/corrupt model)
    or the stored vocabulary size differs from *vocab_size*, the old model
    is backed up and a new one is trained from ``data/corpus.txt``.
    """

    def __init__(self, root_dir, vocab_size=5000):
        self.root_dir = root_dir
        self.vocab_size = vocab_size
        self.model_path = os.path.join(root_dir, "bpe_model.model")

        try:
            self.bpe = yttm.BPE(model=self.model_path)
            if self.bpe.vocab_size() != vocab_size:
                print(
                    f"Vocab size mismatch: Expected {vocab_size}, got {self.bpe.vocab_size()}. Retraining model."
                )
                self._backup_model()
                # Deliberately funnel the mismatch into the same retrain
                # path as a failed load.
                raise ValueError
        except ValueError:
            self._train_bpe_model()
            self.bpe = yttm.BPE(model=self.model_path)

    def _backup_model(self):
        """Copy the current model file to ``bpe_model.model.old``."""
        backup_path = os.path.join(self.root_dir, "bpe_model.model.old")
        copyfile(self.model_path, backup_path)

    def _train_bpe_model(self):
        """Preprocess ``data/corpus.txt`` and train a fresh BPE model.

        Writes the preprocessed corpus to ``data/corpus_processed.txt``.
        """
        data_path = os.path.join(self.root_dir, "data/corpus.txt")
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")

        # Explicit utf-8 keeps behaviour platform-independent and matches
        # how the code-oriented subclass reads/writes the corpus.
        with open(data_path, "r", errors="ignore", encoding="utf-8") as reader:
            raw_text = reader.read()

        processed_text = self.preprocess_text(raw_text)

        with open(processed_path, "w", encoding="utf-8") as writer:
            writer.write(processed_text)

        yttm.BPE.train(
            data=processed_path,
            vocab_size=self.vocab_size,
            model=self.model_path,
            coverage=0.9999,
        )

    def preprocess_text(self, text):
        """Corpus preprocessing hook; the base manager just lower-cases."""
        return text.lower()

    def encode(self, text: str):
        """Encode *text*; returns a list containing one list of token ids."""
        return self.bpe.encode([text], output_type=yttm.OutputType.ID)

    def decode(self, ids):
        """Decode a tensor of token ids back into a string."""
        return self.bpe.decode(ids.tolist())[0]

    @staticmethod
    def attention_mask(encoded_sequence, mask_token_ids=(0, 1, 2, 3)):
        """Return an int mask over *encoded_sequence*: 1 for real tokens,
        0 where the id is one of *mask_token_ids* (special/pad ids).

        Fixed: the default was a shared mutable list; an immutable tuple
        accepts the same values without the mutable-default pitfall.
        """
        mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int).to(
            encoded_sequence.device
        )
        # (seq_len, 1) != (n_mask,) broadcasts to (seq_len, n_mask);
        # a position is kept only if it differs from every mask id.
        return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int()
|
|
|
|
class CodeBPEModelManager(BPEModelManager):
    """BPE manager specialised for source code.

    Whitespace that matters to code (indent unit, newline) is rewritten to
    sentinel words before training/encoding and restored on decode.

    Changes vs. previous revision: the pure-delegation ``__init__`` override
    was removed (the inherited one is identical), temp-file I/O now uses
    explicit utf-8 like ``_train_bpe_model`` already did, and the ambiguous
    local ``l`` in ``decode`` was renamed.
    """

    # Applied in dict order on encode; inverted (on the stripped sentinel)
    # on decode.
    mapping_dict = {
        " ": " <INDENT> ",
        "\n": " <NEWLINE> ",
    }

    def preprocess_text(self, text):
        """Format *text* with black/autopep8, then insert whitespace sentinels."""
        print("Formatting....")
        processed_text = self.format_code(text)

        for key, value in CodeBPEModelManager.mapping_dict.items():
            processed_text = processed_text.replace(key, value)

        return processed_text

    def encode(self, text: str):
        """Encode a single string to a flat list of token ids.

        Note: unlike the base class, this returns the inner id list, not a
        list-of-lists — callers depend on that.
        """
        processed_text = text
        for key, value in CodeBPEModelManager.mapping_dict.items():
            processed_text = processed_text.replace(key, value)

        return self.bpe.encode([processed_text], output_type=yttm.OutputType.ID)[0]

    def decode(self, ids):
        """Decode *ids* (tensor, int or list of ints) and restore whitespace."""
        id_list = ids
        if isinstance(id_list, torch.Tensor):
            id_list = ids.tolist()
        if isinstance(id_list, int):
            id_list = [id_list]

        result = self.bpe.decode(id_list)[0]

        # Sentinels come back without their surrounding spaces, so match the
        # stripped form when undoing the mapping.
        for key, value in CodeBPEModelManager.mapping_dict.items():
            result = result.replace(value.strip(), key)

        return result

    def raw_decode(self, id: int):
        """Decode one token id without whitespace restoration."""
        return self.bpe.decode([id])[0]

    def _train_bpe_model(self):
        """Optionally re-preprocess the corpus, then train the BPE model.

        Preprocessing (black + autopep8 over the whole corpus) is slow, so
        it is skipped unless confirmed interactively; the previously written
        ``corpus_processed.txt`` is reused in that case.
        """
        print("Training (1)....")
        data_path = os.path.join(self.root_dir, "data/corpus.txt")
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")

        if input("Reformat? Will take time [y/N]") == "y":
            with open(data_path, "r", errors="ignore", encoding="utf-8") as reader:
                raw_text = reader.read()

            processed_text = self.preprocess_text(raw_text)

            with open(processed_path, "w", encoding="utf-8") as writer:
                writer.write(processed_text)

            print("removing temp file...")
            temp_file = os.path.join(self.root_dir, "temp_code.py")
            os.remove(temp_file)

        print("Training....")
        yttm.BPE.train(
            data=processed_path,
            vocab_size=self.vocab_size,
            model=self.model_path,
            coverage=1,
        )

    def format_code(self, code):
        """Run black then autopep8 on *code* via a temp file.

        Best-effort: on any failure (formatter missing, syntax error, ...)
        the original code is returned unchanged.
        """
        try:
            temp_file = os.path.join(self.root_dir, "temp_code.py")
            # utf-8 matches how the corpus files are handled elsewhere.
            with open(temp_file, "w", encoding="utf-8") as file:
                file.write(
                    code.replace("\t", " ")
                )

            subprocess.run(["black", temp_file, "--quiet"], check=True)
            subprocess.run(
                ["autopep8", "--in-place", "--ignore=E402", temp_file], check=True
            )

            with open(temp_file, "r", encoding="utf-8") as file:
                formatted_code = file.read()

            return formatted_code
        except Exception as e:
            print(f"Error during code formatting: {e}.")
            return code
|
|
|
|
class CodeCustomTokenizerManager(BPEModelManager):
    """Word/symbol-level tokenizer for source code with an explicit vocabulary.

    Unlike the BPE managers, this splits preprocessed code on whitespace,
    keywords, symbols and case/underscore boundaries, then maps each distinct
    token to an integer id. Id 0 is reserved for "<PAD>"; rare tokens are
    replaced by "<UNK>" during preprocessing.

    NOTE(review): although this subclasses BPEModelManager, it never calls
    super().__init__() and never trains a BPE model — it replaces the whole
    vocabulary pipeline.
    """

    # Keywords are padded with spaces during preprocessing so each always
    # forms its own token.
    reserved_keywords = [
        "false",
        "await",
        "else",
        "import",
        "pass",
        "none",
        "break",
        "except",
        "in",
        "raise",
        "true",
        "class",
        "finally",
        "is",
        "return",
        "and",
        "continue",
        "for",
        "lambda",
        "try",
        "as",
        "def",
        "from",
        "nonlocal",
        "while",
        "assert",
        "del",
        "global",
        "not",
        "with",
        "async",
        "elif",
        "if",
        "or",
        "yield",
    ]
    # Characters split out as standalone tokens.
    # NOTE(review): the multi-char entries ("==", "!=", "<=", ">=", "**",
    # "//", "0x", ...) can never match: the single-char entries earlier in
    # the list are replaced first, so e.g. "==" has already become " = = ".
    symbols = [
        "(",
        ")",
        "[",
        "]",
        "{",
        "}",
        ".",
        ",",
        ":",
        ";",
        "+",
        "-",
        "*",
        "/",
        "%",
        "=",
        "<",
        ">",
        "&",
        "|",
        "^",
        "~",
        "!",
        "==",
        "!=",
        "<=",
        ">=",
        "**",
        "//",
        "@",
        "#",
        "\\",
        "'",
        '"',
        "`",
        "0",
        "1",
        "2",
        "3",
        "4",
        "5",
        "6",
        "7",
        "8",
        "9",
        "0x",
        "0d",
        "0o",
    ]

    def __init__(
        self,
        root_dir,
        vocab_size=5000,
        cutoff_thresh=0.1,
        use_vocab_size_instead=False,
        use_whitespace=True,
    ):
        """Load the vocab from ``custom_tokens_vocab.txt`` under *root_dir*,
        or build it from ``data/corpus.txt`` if that file is missing.

        vocab_size is only honoured when use_vocab_size_instead is True
        (a path that currently aborts — see preprocess_text).
        """
        self.root_dir = root_dir

        # token -> id mapping; id 0 is always <PAD>.
        self.token_to_id = {"<PAD>": 0}
        # Inverse mapping, built lazily in get_rarity_score.
        self.id_to_token = None

        # token -> corpus frequency, filled by preprocess_text or lazily by
        # make_token_freqs.
        self._token_freqs = {}
        self.total_num_tokens = 0
        print("This is CodeCustomTokenizerManager, vocab size will be disregarded.")

        print(f"Cutoff threshold: {cutoff_thresh}")
        self.cutoff_thresh = cutoff_thresh

        self.use_whitespace = use_whitespace

        if not use_whitespace:
            print("Not using whitespace! Important I guess")

        if use_vocab_size_instead:
            print("Nevermind! Using vocab size instead, no cutoff thresh")

        self.use_vocab_size_instead = use_vocab_size_instead

        self.vocab_size = vocab_size

        vocab_path = os.path.join(self.root_dir, "custom_tokens_vocab.txt")
        try:
            self.load_vocab(vocab_path)
        except FileNotFoundError:
            print("Making vocab!")
            self.make_vocab()
            self.save_vocab(vocab_path)

        print(f"Vocab size: {len(self.token_to_id)}")

    def make_vocab(self):
        """Preprocess the corpus, persist the token stream, and assign ids
        to every distinct token in order of first appearance."""
        data_path = os.path.join(self.root_dir, "data/corpus.txt")
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")

        with open(data_path, "r", errors="ignore") as reader:
            raw_text = reader.read()

        # preprocess_text returns a list of token strings here.
        processed_text = self.preprocess_text(raw_text)

        with open(processed_path, "w") as writer:
            writer.write(" ".join(processed_text))

        for token in processed_text:
            if token not in self.token_to_id:
                # NOTE(review): dead guard — token_to_id is initialised with
                # {"<PAD>": 0} in __init__, so it is never empty here.
                if len(self.token_to_id) == 0:
                    self.token_to_id = {"<PAD>": 0}

                self.token_to_id[token] = len(self.token_to_id)

        print(f"Number of tokens: {len(self.token_to_id)}")

    def make_token_freqs(self):
        """Recompute token frequencies from the cached processed corpus.

        Used lazily by get_rarity_score when _token_freqs is empty (e.g.
        after the vocab was loaded from disk rather than rebuilt).
        """
        processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt")
        with open(processed_path, "r", errors="ignore") as reader:
            raw_text = reader.read()
        tokens = raw_text.split(" ")

        token_freqs = {"<PAD>": 0}

        for token in tqdm(tokens, leave=False):
            if token not in token_freqs:
                token_freqs[token] = 1
            else:
                token_freqs[token] += 1

        self._token_freqs = token_freqs
        self.total_num_tokens = len(tokens)

    def preprocess_text(self, code):
        """Turn raw source *code* into a list of lower-cased tokens.

        Pipeline (order matters): strip comments/docstrings -> drop
        non-ASCII -> pad keywords and symbols with spaces -> map whitespace
        to <TAB>/<NEWLINE> sentinels -> split camelCase/underscores ->
        count frequencies -> replace rare tokens with <UNK>.
        """
        print("Preprocessing text...", code[:20])

        # Keep file-separator markers, then drop ordinary comments and
        # triple-quoted strings.
        code = code.replace("# <FILESEP>", "<FILESEP>")
        code = re.sub(r"#.*", "", code)
        code = re.sub(r'"""(.*?)"""', "", code, flags=re.DOTALL)
        code = re.sub(r"'''(.*?)'''", "", code, flags=re.DOTALL)

        # NOTE(review): as written this substitution replaces a space with a
        # space — effectively a no-op; possibly meant to collapse runs of
        # whitespace. Verify against the original intent.
        code = re.sub(r" ", " ", code)

        print("Filtered comments")

        # Keep only printable ASCII and whitespace.
        code = re.sub(r"[^ -~\s]+", "", code)

        print("Filtered non-ascii")

        # Pad keywords so they tokenize as standalone words.
        for word in self.reserved_keywords:
            code = re.sub(rf"\b{word}\b", f" {word} ", code)

        print("Reserved words")
        for symbol in self.symbols:
            code = code.replace(symbol, f" {symbol} ")

        print("Symbols")

        def split_token(token):
            # Sentinel tokens like <TAB>/<NEWLINE>/<FILESEP> pass through
            # whole (lower-cased); everything else is split on camelCase,
            # underscores/hyphens and non-letter characters.
            if token.startswith("<") and token.endswith(
                ">"
            ):
                return [token.lower()]
            result = re.sub(r"([a-z])([A-Z])", r"\1 \2", token)
            result = re.sub(r"([_-])", r" \1 ", result)
            result = re.sub(r"([^a-zA-Z])", r" \1 ", result)
            return [part.lower() for part in result.split() if part.strip()]

        code = code.replace(" ", " <TAB> ").replace("\n", " <NEWLINE> ")
        if not self.use_whitespace:
            code = code.replace("<TAB>", "").replace("<NEWLINE>", "")
        print("Tabs + newlines")

        tokens = []
        for token in tqdm(code.split(" "), leave=False):
            if token.strip():
                tokens.extend(split_token(token))

        tokens = [tok.lower() for tok in tokens if tok.strip()]

        print("Split tokens")
        token_freqs = {"<PAD>": 0}
        for token in tqdm(tokens, leave=False):
            if token not in token_freqs:
                token_freqs[token] = 1
            else:
                token_freqs[token] += 1
        print("Counted freqs")

        total_num_tokens = len(tokens)

        # Diagnostics: frequency distribution of the raw token stream.
        counter = Counter(list(token_freqs.values()))
        num_ones = counter[1]
        print(
            f"Number of tokens that appear only once: {num_ones}. Percentage: {num_ones / total_num_tokens}"
        )

        print(f"Mean token count: {np.mean(list(token_freqs.values()))}")
        print(f"Median token count: {np.median(list(token_freqs.values()))}")

        print(
            f"Standard deviation of token count: {np.std(list(token_freqs.values()))}"
        )

        print(f"Min token count: {np.min(list(token_freqs.values()))}")
        print(f"Max token count: {np.max(list(token_freqs.values()))}")

        print(f"Top 30 most frequent tokens:")
        sorted_tokens = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True)
        for token, freq in sorted_tokens[:30]:
            print(f"{token}: {freq}")

        print(f"Bottom 30 most frequent tokens:")
        for token, freq in sorted_tokens[-30:]:
            print(f"{token}: {freq}")

        self._token_freqs = token_freqs
        self.total_num_tokens = total_num_tokens

        # NOTE(review): cutoff_thresh is read but never used below — the
        # cutoff amount is hard-coded to 10.
        cutoff_thresh = self.cutoff_thresh
        if self.use_vocab_size_instead:
            # Deprecated path: aborts the whole process.
            print("Using vocab size instead")
            print("deprecated")
            print("cope")
            exit()
            # Unreachable: exit() above terminates the interpreter.
            sorted_tokens = sorted(
                token_freqs.items(), key=lambda x: x[1], reverse=True
            )
            allowed_tokens = set(
                token for token, _ in sorted_tokens[: self.vocab_size - 1]
            )
            for i in range(len(tokens)):
                if tokens[i] not in allowed_tokens and tokens[i] != "<PAD>":
                    print(f"Replacing token with UNK: {tokens[i]}")
                    tokens[i] = "<UNK>"

        else:
            cutoff_amt = (
                10
            )
            print(f"Cuttoff amount: {cutoff_amt}")

            # Replace every token seen fewer than cutoff_amt times with <UNK>.
            low_freq_tokens = [
                token
                for token, freq in token_freqs.items()
                if freq < cutoff_amt and token != "<PAD>"
            ]
            low_freq_tokens_set = set(low_freq_tokens)
            tokens = [
                "<UNK>" if token in low_freq_tokens_set else token
                for token in tqdm(tokens)
            ]

        print(tokens[500:700])

        print("500-700")

        return [tok for tok in tokens if tok.strip()]

    def encode(self, code):
        """Map a space-separated token string to a list of ids.

        NOTE(review): unseen tokens are appended to the vocabulary instead
        of being mapped to <UNK> — encoding mutates token_to_id, which can
        desync it from a previously saved vocab file. Verify intended.
        """
        tokens = code.split(" ")
        ids = []

        for token in tokens:
            if token not in self.token_to_id:
                self.token_to_id[token] = len(self.token_to_id)
            ids.append(self.token_to_id[token])

        return ids

    def decode(self, ids):
        """Decode a tensor of ids to a space-joined token string.

        NOTE(review): O(len(ids) * vocab) reverse lookup by linear scan;
        ids with no matching entry are silently skipped.
        """
        result = ""
        for id in ids.tolist():
            for token, id_iterator in self.token_to_id.items():
                if id_iterator == id:
                    result += token
                    result += " "

        return result

    def raw_decode(self, id: int):
        """Return the token string for a single id (None if unknown)."""
        for token, id_iterator in self.token_to_id.items():
            if id_iterator == id:
                return token

    def format_code(self, code):
        """Run black then autopep8 on *code* via a temp file; best-effort —
        returns the input unchanged on any failure."""
        try:
            temp_file = os.path.join(self.root_dir, "temp_code.py")
            with open(temp_file, "w") as file:
                file.write(
                    code.replace("\t", " ")
                )

            subprocess.run(["black", temp_file, "--quiet"], check=True)
            subprocess.run(
                ["autopep8", "--in-place", "--ignore=E402", temp_file], check=True
            )

            with open(temp_file, "r") as file:
                formatted_code = file.read()

            return formatted_code
        except Exception as e:
            print(f"Error during code formatting: {e}.")
            return code

    def save_vocab(self, file_path):
        """Write the vocab as tab-separated ``token<TAB>id`` lines."""
        with open(file_path, "w") as file:
            for token, id in self.token_to_id.items():
                file.write(f"{token}\t{id}\n")

    def load_vocab(self, file_path):
        """Load the vocab written by save_vocab; raises FileNotFoundError
        if the file does not exist (caller rebuilds the vocab then)."""
        self.token_to_id = {}
        with open(file_path, "r") as file:
            for line in file.read().split("\n"):
                try:
                    token, id = line.strip().split("\t")
                    self.token_to_id[token] = int(id)
                except ValueError:
                    # Skip blank/malformed lines (e.g. trailing newline).
                    pass

    @staticmethod
    def attention_mask(encoded_sequence, mask_token_ids=[0]):
        """1 where the token is not in *mask_token_ids* (default: <PAD>).

        NOTE(review): unlike the base class, the mask tensor is not moved to
        encoded_sequence.device — fine on CPU, would fail for CUDA inputs.
        """
        mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int)

        return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int()

    def get_rarity_score(self, sequence):
        """Median of per-token rarity (corpus_size / token_frequency) over
        a tensor of token ids; unknown tokens score 0.

        NOTE(review): np.zeros_like(sequence) inherits the sequence's
        integer dtype, so each rarity value is truncated toward zero before
        the median is taken — confirm that is intended.
        """
        scores = np.zeros_like(sequence)
        for idx, token in enumerate(sequence):
            # Lazily (re)build frequency table and inverse vocab; note the
            # inverse map can go stale if encode() later adds tokens.
            if self._token_freqs == {}:
                self.make_token_freqs()
            if not self.id_to_token:
                self.id_to_token = {v: k for k, v in self.token_to_id.items()}
            token_count = self._token_freqs.get(self.id_to_token[token.item()], 0)
            rarity_score = self.total_num_tokens / token_count if token_count > 0 else 0
            scores[idx] = rarity_score

        return np.float32(np.median(scores))

    def get_entropy_score(self, sequence):
        """Shannon entropy of the id distribution in *sequence*, normalised
        to [0, 1] by log2(#unique) when more than one distinct id occurs."""
        if len(sequence) == 0:
            return 0.0

        unique, counts = np.unique(sequence, return_counts=True)

        probs = counts / counts.sum()
        entropy = -np.sum(probs * np.log2(probs))

        if len(unique) > 1:
            entropy /= np.log2(len(unique))

        return np.float32(entropy)
|
|
|
|
class DummySequentialDataManager:
    """Stand-in tokenizer manager for pipeline testing.

    Emits a fixed token sequence regardless of input and writes a dummy
    processed-corpus file so the dataset loader finds one.
    """

    def __init__(self, root_dir, vocab_size=5000):
        print("init")
        self.root_dir = root_dir
        self.vocab_size = vocab_size
        corpus_path = os.path.join(root_dir, "data/corpus_processed.txt")
        with open(corpus_path, "w+") as sink:
            sink.write("dummy")

    def encode(self, text: str):
        """Ignore *text*; always return ids 0..49 wrapped in a list."""
        return [list(range(50))]

    def decode(self, ids):
        """Render *ids* (tensor, int, or iterable) as a space-joined string."""
        values = ids
        if isinstance(values, torch.Tensor):
            values = ids.tolist()
        if isinstance(values, int):
            values = [values]

        return " ".join(str(v) for v in values)

    @staticmethod
    def attention_mask(encoded_sequence, mask_token_ids=[]):
        """All-ones mask by default; 0 wherever the id is in *mask_token_ids*."""
        masked_ids = torch.tensor(mask_token_ids, dtype=torch.int).to(
            encoded_sequence.device
        )
        keep = (encoded_sequence.unsqueeze(1) != masked_ids).all(dim=1)
        return keep.int()
|
|
|
|
class TextCorpusDataset(Dataset):
    """Fixed-length token-chunk dataset over a text/code corpus.

    Picks a tokenizer manager based on the IS_* flags, encodes the
    (already preprocessed) corpus, splits it into chunks of *max_length*
    ids (optionally with a sliding window), and caches chunks plus optional
    per-chunk rarity/entropy scores as .pt files under *root_dir*.

    __getitem__ returns (chunk, score) where score is the rarity score if
    requested, else the entropy score if requested, else a dummy tensor.
    """

    def __init__(
        self,
        root_dir="./test-data",
        train=False,
        max_length=512,
        vocab_size=10000,
        IS_DUMMY=False,
        IS_CODE=False,
        IS_CUSTOM=False,
        sliding_window=False,
        stride=1,
        get_rarity_score=False,
        get_entropy_score=False,
    ):
        print(root_dir)

        # Debug aid: dump every constructor argument via frame introspection.
        print("[TextCorpusDataset]")
        frame = inspect.currentframe()
        args, _, _, values = inspect.getargvalues(frame)
        print("Arguments passed:")
        for arg in args[1:]:
            print(f" {arg} = {values[arg]}")

        self.root = root_dir
        self.sliding_window = sliding_window
        self.window_size = max_length
        self.stride = stride
        self.get_rarity_score = get_rarity_score
        self.get_entropy_score = get_entropy_score

        # Manager selection: dummy > custom code tokenizer > code BPE > plain BPE.
        if IS_DUMMY:
            self.manager = DummySequentialDataManager(root_dir=root_dir)
        elif IS_CODE:
            if IS_CUSTOM:
                self.manager = CodeCustomTokenizerManager(root_dir=root_dir)
            else:
                self.manager = CodeBPEModelManager(
                    root_dir=root_dir, vocab_size=vocab_size
                )
        else:
            self.manager = BPEModelManager(root_dir=root_dir, vocab_size=vocab_size)

        self.max_length = max_length
        self.cache_file = os.path.join(root_dir, "encoded_chunked.pt")
        self.rarity_cache_file = os.path.join(root_dir, "rarity_scores.pt")
        self.entropy_cache_file = os.path.join(root_dir, "entropy_scores.pt")

        start_t = time.time()
        if os.path.exists(self.cache_file):
            self.chunks = torch.load(self.cache_file, weights_only=True)
            # Cached chunk length no longer matches max_length: offer an
            # interactive flatten-and-rechunk of the cached ids.
            if self.chunks.size(-1) != self.max_length:
                if (
                    input(
                        "Attempting to fix and re-chunk data to correct length. Continue? [y/N]: "
                    )
                    == "y"
                ):
                    self._chunk_and_save(torch.flatten(self.chunks).tolist())
                    print("Re-chunked successfully!")
                else:
                    # NOTE(review): on abort, self.chunks keeps the old length.
                    print("Operation aborted.")
        else:
            # No cache yet: encode the preprocessed corpus from scratch.
            with open(
                os.path.join(root_dir, "data/corpus_processed.txt"),
                "r",
                errors="ignore",
            ) as file:
                text = file.read()
            encoded = self.manager.encode(text)

            self._chunk_and_save(encoded)

        self._load_or_compute_scores()

        end_t = time.time()
        print(f"Dataset loading took {end_t - start_t} seconds.")

        # Keep everything resident on DEVICE so __getitem__ does no copies.
        self.chunks = self.chunks.to(DEVICE)
        if self.get_rarity_score:
            self.rarity_scores = self.rarity_scores.to(DEVICE)
        if self.get_entropy_score:
            self.entropy_scores = self.entropy_scores.to(DEVICE)
        self.dummy = torch.tensor([1], device=DEVICE)

    def _chunk_and_save(self, encoded):
        """Split the flat id list *encoded* into fixed-length chunks, pad
        the trailing chunk (non-sliding mode), stack, and cache to disk."""
        chunked_data = []
        if self.sliding_window:
            print("sliding!")
            # Overlapping windows of window_size every `stride` ids.
            for i in trange(
                0, len(encoded) - self.window_size + 1, self.stride, leave=False
            ):
                chunked_data.append(
                    torch.tensor(encoded[i : i + self.window_size], dtype=torch.int)
                )
        else:
            # Non-overlapping chunks; the last one may be short.
            for i in trange(0, len(encoded), self.max_length, leave=False):
                chunked_data.append(
                    torch.tensor(encoded[i : i + self.max_length], dtype=torch.int)
                )

            # Zero-pad (<PAD> id 0) the final short chunk up to max_length.
            padded_chunk = torch.zeros(self.max_length, dtype=torch.int)
            padded_chunk[: len(chunked_data[-1])] = chunked_data[-1]
            chunked_data[-1] = padded_chunk

        self.chunks = torch.stack(chunked_data)
        torch.save(self.chunks, self.cache_file)

    def _load_or_compute_scores(self):
        """Load cached scores or compute them if not available"""
        if self.get_rarity_score:
            if os.path.exists(self.rarity_cache_file):
                print("Loading cached rarity scores...")
                self.rarity_scores = torch.load(self.rarity_cache_file, weights_only=True)
                # Recompute if the cache was built for a different chunking.
                if len(self.rarity_scores) != len(self.chunks):
                    print("Rarity cache size mismatch, recomputing...")
                    self._compute_and_cache_rarity_scores()
            else:
                print("Computing rarity scores...")
                self._compute_and_cache_rarity_scores()

        if self.get_entropy_score:
            if os.path.exists(self.entropy_cache_file):
                print("Loading cached entropy scores...")
                self.entropy_scores = torch.load(self.entropy_cache_file, weights_only=True)
                if len(self.entropy_scores) != len(self.chunks):
                    print("Entropy cache size mismatch, recomputing...")
                    self._compute_and_cache_entropy_scores()
            else:
                print("Computing entropy scores...")
                self._compute_and_cache_entropy_scores()

    def _compute_and_cache_rarity_scores(self):
        """Compute rarity scores for all chunks and cache them"""
        rarity_scores = []
        print("Computing rarity scores for all chunks...")
        for i in trange(len(self.chunks), desc="Computing rarity scores"):
            # Delegated to the manager (requires a manager that implements
            # get_rarity_score, i.e. the custom tokenizer).
            score = self.manager.get_rarity_score(self.chunks[i])
            rarity_scores.append(score)

        self.rarity_scores = torch.tensor(rarity_scores, dtype=torch.float32)
        torch.save(self.rarity_scores, self.rarity_cache_file)
        print(f"Cached rarity scores to {self.rarity_cache_file}")

    def _compute_and_cache_entropy_scores(self):
        """Compute entropy scores for all chunks and cache them"""
        entropy_scores = []
        print("Computing entropy scores for all chunks...")
        for i in trange(len(self.chunks), desc="Computing entropy scores"):
            score = self.manager.get_entropy_score(self.chunks[i])
            entropy_scores.append(score)

        self.entropy_scores = torch.tensor(entropy_scores, dtype=torch.float32)
        torch.save(self.entropy_scores, self.entropy_cache_file)
        print(f"Cached entropy scores to {self.entropy_cache_file}")

    def __len__(self):
        return len(self.chunks)

    def __getitem__(
        self, idx
    ):
        # Rarity takes precedence over entropy if both were requested.
        seq = self.chunks[idx]
        if self.get_rarity_score:
            return seq, self.rarity_scores[idx]
        if self.get_entropy_score:
            return seq, self.entropy_scores[idx]
        return seq, self.dummy
|
|
|
|
class Datasplit_chunker(Dataset):
    """Disk-cached view of a torch Subset, optionally re-windowed.

    On first use, concatenates every sequence of *subset* into one tensor
    (optionally re-split into sliding windows) and saves it under *root*;
    later constructions with the same *name* reuse the cached tensor.
    """

    def __init__(self, root, name, subset, slide=False, stride=1, length=512):
        super().__init__()

        self.root = root
        cache_path = os.path.join(root, f"encoded_chunked_{name}.pt")
        if os.path.exists(cache_path):
            # Reuse the tensor cached by a previous run.
            self.items = torch.load(cache_path, weights_only=True)
        else:
            # Flatten the subset into a single long tensor of ids.
            pieces = [subset.dataset[idx][0] for idx in subset.indices]
            self.items = torch.cat(pieces)

            if slide:
                self.items = self._sliding_window(
                    self.items, window_size=length, stride=stride
                )

            torch.save(self.items, cache_path)
        print("saved!")
        self.chunks = self.items
        self.dummy = torch.tensor([1], device=DEVICE)

    def _sliding_window(self, sequence, window_size, stride):
        """Zero-copy strided view: row i is sequence[i*stride : i*stride+window_size]."""
        count = (len(sequence) - window_size) // stride + 1
        return torch.as_strided(
            sequence, size=(count, window_size), stride=(stride, 1)
        )

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.chunks[idx], self.dummy
|
|
|
|
| |
# Module-level dataset construction: note this runs at import time and has
# heavy side effects (tokenizes/loads the corpus and writes cache files
# under root_dir).
dataset = TextCorpusDataset(
    root_dir=os.path.expanduser(
        "~/torch_datasets/github-python/mega_licensed_corpus"
    ),
    vocab_size=33819,
    IS_CODE=True,
    IS_CUSTOM=True,
    max_length=256,
    sliding_window=False,
    stride=10,
    get_rarity_score=True,
)


# 80/20 train/test split.
dset_size = int(len(dataset))
train_size = int(0.8 * dset_size)
test_size = int(dset_size - train_size)
if test_size == 2:
    print("alert! test size is 2 or whatever. Change this back please.")


# Fixed seed so the split is reproducible across runs.
torch.manual_seed(3407)


# The third split length is always 0 (train + test == dset_size); it is kept
# only to satisfy the three-way call shape.
train_dataset, test_dataset, _ = random_split(
    dataset, [train_size, test_size, len(dataset) - train_size - test_size]
)
|
|
|
|
| |
| |
|
|
|
|
| |
|
|
| |
|
|
|
|
def get_train_dataset():
    """Return the module-level 80% training split."""
    return train_dataset
|
|
|
|
def get_test_dataset():
    """Return the module-level 20% test split."""
    return test_dataset
|
|
|
|
def get_dataloader(dataset, batch_size=64):
    """Wrap *dataset* in a shuffling DataLoader with the given batch size."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return loader
|
|
|
|
def fromDataset(dataset):
    """Split *dataset* 80/20 into (train, test) subsets.

    Seeds torch's RNG with a fixed value first, so the split is
    reproducible across calls and processes.
    """
    total = int(len(dataset))
    n_train = int(0.8 * total)
    n_test = int(total - n_train)
    if n_test == 2:
        print("alert! test size is 2 or whatever. Change this back please.")

    torch.manual_seed(3407)

    # Third split is always empty (n_train + n_test == total).
    leftover = len(dataset) - n_train - n_test
    train_split, test_split, _ = random_split(dataset, [n_train, n_test, leftover])

    return train_split, test_split
|
|
|
|
| if __name__ == "__main__": |
| d = get_train_dataset() |
| print("Number of samples: ", len(d)) |
| for a, b in d: |
| |
| manager = dataset.manager |
| print(a) |
| print(manager.decode(a)) |
| |
| print("--- sep batch --- ") |
|
|
| print(f"Number of tokens used: {len(dataset.manager.token_to_id)}") |
| break |
|
|