"""
JSON-optimized tokenizer.

Design principles:
1. Structural tokens: JSON grammar symbols ({, }, [, ], :, comma) each get
   a dedicated single token — no wasted subword splits on syntax.
2. Key vocabulary: Frequently occurring JSON keys get their own tokens
   (Key(name), Key(id), etc.), massively reducing token count for
   repetitive schemas.
3. Type markers: String values are wrapped in [STR] delimiters, numbers
   carry a [NUM] prefix, and true/false/null are single literal tokens, so
   JSON types are preserved for a lossless roundtrip.
4. BPE for value content: String and number content, plus any key that is
   not in the key vocabulary, is tokenized via a BPE codec trained on the
   keys and values of the training data.
5. Nesting tokens: object start/end ({ }) and array start/end ([ ]) tokens
   encode hierarchy without ambiguity.
"""
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| from collections import Counter |
| from typing import Any, Optional, Union |
|
|
| from json_tokenizer.bpe import BPETrainer |
|
|
|
|
| |
class StructuralTokens:
    """Fixed token IDs for JSON syntax and control markers.

    IDs 0-15 have a meaning assigned below. IDs 16-31 are unassigned but
    still reserved: ``RESERVED_END`` is the exclusive upper bound of the
    reserved range.
    """

    # Sequence control.
    PAD = 0
    START = 1
    END = 2

    # JSON punctuation.
    OBJ_START = 3
    OBJ_END = 4
    ARR_START = 5
    ARR_END = 6
    COLON = 7
    COMMA = 8

    # JSON literals.
    NULL = 9
    TRUE = 10
    FALSE = 11

    # Markers that introduce BPE-encoded content.
    STR_DELIM = 12
    NUM_PREFIX = 13
    KEY_PREFIX = 14

    # Fallback token.
    UNK = 15

    # First ID past the reserved range.
    RESERVED_END = 32

    @classmethod
    def name(cls, token_id: int) -> str:
        """Return the display name of a reserved token ID.

        IDs without an assigned meaning are rendered as ``[RESERVED_<id>]``.
        """
        # Position in this tuple == token ID.
        labels = (
            "[PAD]",
            "[START]",
            "[END]",
            "{",
            "}",
            "[",
            "]",
            ":",
            ",",
            "null",
            "true",
            "false",
            "[STR]",
            "[NUM]",
            "[KEY]",
            "[UNK]",
        )
        by_id = dict(enumerate(labels))
        try:
            return by_id[token_id]
        except KeyError:
            return f"[RESERVED_{token_id}]"
|
|
|
|
class JSONTokenizer:
    """Tokenizer optimized for JSON structures.

    Encodes JSON into a compact token sequence with:
    - Single tokens for structural elements
    - Dedicated key tokens for common keys
    - BPE subword tokens for string/number values and out-of-vocabulary keys
    - Type-preserving roundtrip: decoding restores the same structure, types
      and values (assuming the BPE codec reproduces text exactly); output
      whitespace is whatever ``json.dumps`` emits

    Token ID layout:
    - ``[0, StructuralTokens.RESERVED_END)``: structural / reserved tokens
    - ``[RESERVED_END, bpe_offset)``: one ID per key in the key vocabulary
    - ``[bpe_offset, vocab_size)``: BPE token IDs shifted by ``bpe_offset``

    Usage:
        tokenizer = JSONTokenizer()
        tokenizer.train_from_json_files(["data1.json", "data2.json"])
        ids = tokenizer.encode('{"name": "Alice", "age": 30}')
        decoded = tokenizer.decode(ids)
    """

    def __init__(
        self,
        bpe_vocab_size: int = 4096,
        max_key_vocab: int = 1024,
        min_key_freq: int = 2,
        bpe_min_freq: int = 2,
    ):
        """Create an untrained tokenizer.

        Args:
            bpe_vocab_size: ``vocab_size`` passed to the BPE trainer.
            max_key_vocab: Maximum number of keys that get a dedicated token.
            min_key_freq: Minimum count for a key to get a dedicated token.
            bpe_min_freq: ``min_frequency`` passed to the BPE trainer.
        """
        self.bpe_vocab_size = bpe_vocab_size
        self.max_key_vocab = max_key_vocab
        self.min_key_freq = min_key_freq
        self.bpe_min_freq = bpe_min_freq

        # Key vocabulary: key string <-> dedicated token ID.
        self._key_to_id: dict[str, int] = {}
        self._id_to_key: dict[int, str] = {}
        self._key_offset = StructuralTokens.RESERVED_END

        # BPE codec for value content and out-of-vocabulary keys.
        self._bpe = BPETrainer(vocab_size=bpe_vocab_size, min_frequency=bpe_min_freq)
        # BPE IDs are shifted past the key vocabulary. Start at the key offset
        # (what train() yields for an empty key vocabulary) rather than 0, so
        # an untrained tokenizer never emits BPE IDs that collide with
        # structural tokens.
        self._bpe_offset = self._key_offset

        # Combined id <-> display-name tables, filled by _build_vocab_lookup().
        self._id_to_token: dict[int, str] = {}
        self._token_to_id: dict[str, int] = {}
        self._trained = False

    @property
    def vocab_size(self) -> int:
        """Total vocabulary size.

        Before training, only the reserved structural range is counted.
        """
        if not self._trained:
            return StructuralTokens.RESERVED_END
        return self._bpe_offset + len(self._bpe.vocab)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def train(self, json_objects: list[Any]) -> None:
        """Train the tokenizer from a list of parsed JSON objects.

        Builds the key vocabulary from the most frequent keys (rebuilt from
        scratch on every call) and trains the BPE codec on all keys and
        string/number values. If no keys or values were found, the BPE codec
        is not trained.

        Args:
            json_objects: List of parsed JSON values (dicts, lists, primitives).
        """
        key_counter: Counter[str] = Counter()
        value_strings: list[str] = []
        for obj in json_objects:
            self._extract_keys_and_values(obj, key_counter, value_strings)

        # Most frequent keys first, dropping those below the frequency floor.
        top_keys = [
            key
            for key, count in key_counter.most_common(self.max_key_vocab)
            if count >= self.min_key_freq
        ]
        self._key_to_id = {key: self._key_offset + i for i, key in enumerate(top_keys)}
        self._id_to_key = {tid: key for key, tid in self._key_to_id.items()}

        # BPE IDs start right after the last key ID.
        self._bpe_offset = self._key_offset + len(self._key_to_id)

        if value_strings:
            self._bpe.train(value_strings)

        self._build_vocab_lookup()
        self._trained = True

    def train_from_json_strings(self, json_strings: list[str]) -> None:
        """Train from raw JSON strings; strings that fail to parse are skipped."""
        objects: list[Any] = []
        for text in json_strings:
            try:
                objects.append(json.loads(text))
            except json.JSONDecodeError:
                # Best-effort: malformed samples are ignored.
                continue
        self.train(objects)

    def train_from_json_files(self, file_paths: list[str]) -> None:
        """Train from JSON files (one JSON document per file, or JSONL).

        Each file is first parsed as a single JSON document; a top-level
        array contributes its elements as separate samples. If that fails,
        the file is treated as JSON Lines and every non-blank line that
        parses is used. Malformed lines are skipped.
        """
        objects: list[Any] = []
        for path in file_paths:
            # JSON text is UTF-8; don't depend on the platform's locale
            # encoding. "utf-8-sig" also drops a leading BOM, which
            # json.loads would otherwise reject.
            with open(path, encoding="utf-8-sig") as f:
                content = f.read().strip()

            try:
                obj = json.loads(content)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, list):
                    objects.extend(obj)
                else:
                    objects.append(obj)
                continue

            # Fall back to JSON Lines.
            for line in content.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    objects.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        self.train(objects)

    def _extract_keys_and_values(
        self,
        obj: Any,
        key_counter: Counter[str],
        value_strings: list[str],
    ) -> None:
        """Recursively count keys and collect BPE training text from ``obj``.

        Keys are counted in ``key_counter`` and also appended to
        ``value_strings``, so that keys missing from the key vocabulary can
        still be BPE-encoded. Strings and numbers are appended using the same
        text form the encoder uses.
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                key_counter[key] += 1
                value_strings.append(key)
                self._extract_keys_and_values(value, key_counter, value_strings)
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                self._extract_keys_and_values(item, key_counter, value_strings)
        elif isinstance(obj, str):
            value_strings.append(obj)
        elif isinstance(obj, (int, float)) and not isinstance(obj, bool):
            # bool is a subclass of int, but booleans are always encoded as
            # TRUE/FALSE tokens, so "True"/"False" must not enter the corpus.
            value_strings.append(self._number_text(obj))

    def _build_vocab_lookup(self) -> None:
        """Rebuild the combined id <-> display-name tables."""
        self._id_to_token = {}
        self._token_to_id = {}

        # Structural / reserved range.
        for tid in range(StructuralTokens.RESERVED_END):
            name = StructuralTokens.name(tid)
            self._id_to_token[tid] = name
            self._token_to_id[name] = tid

        # Key vocabulary.
        for key, tid in self._key_to_id.items():
            token_name = f"Key({key})"
            self._id_to_token[tid] = token_name
            self._token_to_id[token_name] = tid

        # BPE tokens, shifted past the key vocabulary.
        for bpe_token, bpe_id in self._bpe.vocab.items():
            full_id = self._bpe_offset + bpe_id
            token_name = f"BPE({bpe_token})"
            self._id_to_token[full_id] = token_name
            self._token_to_id[token_name] = full_id

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------

    def encode(self, json_input: Union[str, Any]) -> list[int]:
        """Encode a JSON string or parsed object into token IDs.

        Args:
            json_input: Either a JSON string (parsed with ``json.loads``) or
                an already-parsed Python object.

        Returns:
            List of integer token IDs, wrapped in ``[START]`` ... ``[END]``.

        Raises:
            ValueError: If ``json_input`` is a string that is not valid JSON.
        """
        if isinstance(json_input, str):
            try:
                obj = json.loads(json_input)
            except json.JSONDecodeError as err:
                snippet = json_input[:100]
                if len(json_input) > 100:
                    snippet += "..."
                raise ValueError(f"Invalid JSON: {snippet}") from err
        else:
            obj = json_input

        tokens = [StructuralTokens.START]
        self._encode_value(obj, tokens)
        tokens.append(StructuralTokens.END)
        return tokens

    def _encode_value(self, value: Any, tokens: list[int]) -> None:
        """Recursively encode a JSON value, appending to ``tokens``."""
        if isinstance(value, dict):
            self._encode_object(value, tokens)
        elif isinstance(value, (list, tuple)):
            # Tuples serialize as JSON arrays, matching json.dumps.
            self._encode_array(value, tokens)
        elif isinstance(value, str):
            self._encode_string(value, tokens)
        elif isinstance(value, bool):
            # Must precede the number check: bool is a subclass of int.
            tokens.append(StructuralTokens.TRUE if value else StructuralTokens.FALSE)
        elif isinstance(value, (int, float)):
            self._encode_number(value, tokens)
        elif value is None:
            tokens.append(StructuralTokens.NULL)
        else:
            tokens.append(StructuralTokens.UNK)

    def _encode_object(self, obj: dict, tokens: list[int]) -> None:
        """Encode a JSON object: ``{ key : value , ... }``."""
        tokens.append(StructuralTokens.OBJ_START)
        for i, (key, value) in enumerate(obj.items()):
            if i > 0:
                tokens.append(StructuralTokens.COMMA)
            self._encode_key(key, tokens)
            tokens.append(StructuralTokens.COLON)
            self._encode_value(value, tokens)
        tokens.append(StructuralTokens.OBJ_END)

    def _encode_array(self, arr: Union[list, tuple], tokens: list[int]) -> None:
        """Encode a JSON array: ``[ value , ... ]``."""
        tokens.append(StructuralTokens.ARR_START)
        for i, item in enumerate(arr):
            if i > 0:
                tokens.append(StructuralTokens.COMMA)
            self._encode_value(item, tokens)
        tokens.append(StructuralTokens.ARR_END)

    def _encode_key(self, key: str, tokens: list[int]) -> None:
        """Encode a key as one vocabulary token, or KEY_PREFIX + BPE tokens."""
        key_id = self._key_to_id.get(key)
        if key_id is not None:
            tokens.append(key_id)
        else:
            tokens.append(StructuralTokens.KEY_PREFIX)
            self._append_bpe(key, tokens)

    def _encode_string(self, value: str, tokens: list[int]) -> None:
        """Encode a string value as ``[STR] bpe... [STR]``."""
        tokens.append(StructuralTokens.STR_DELIM)
        if value:
            self._append_bpe(value, tokens)
        tokens.append(StructuralTokens.STR_DELIM)

    def _encode_number(self, value: Union[int, float], tokens: list[int]) -> None:
        """Encode a number as ``[NUM]`` followed by BPE tokens of its text."""
        tokens.append(StructuralTokens.NUM_PREFIX)
        self._append_bpe(self._number_text(value), tokens)

    def _append_bpe(self, text: str, tokens: list[int]) -> None:
        """BPE-encode ``text`` and append the IDs shifted into the BPE range."""
        offset = self._bpe_offset
        tokens.extend(offset + bid for bid in self._bpe.encode_to_ids(text))

    @staticmethod
    def _number_text(value: Union[int, float]) -> str:
        """Render a number as text that ``int()``/``float()`` parse back exactly.

        ``repr`` of a float is the shortest string that round-trips, and it
        never raises. Calling ``int(value)`` raises on inf/nan, which
        ``json.loads`` produces from ``Infinity``/``NaN``.
        """
        if isinstance(value, int):
            return str(value)
        return repr(value)

    # ------------------------------------------------------------------
    # Decoding
    # ------------------------------------------------------------------

    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back to a JSON string.

        Args:
            token_ids: List of integer token IDs from encode().

        Returns:
            A JSON string equivalent to the original input. Formatting
            (whitespace, separators) is whatever ``json.dumps`` emits.
        """
        obj = self._decode_to_object(token_ids)
        return json.dumps(obj, ensure_ascii=False)

    def decode_to_object(self, token_ids: list[int]) -> Any:
        """Decode token IDs back to a Python object."""
        return self._decode_to_object(token_ids)

    def _decode_to_object(self, token_ids: list[int]) -> Any:
        """Strip optional ``[START]``/``[END]`` and parse the remaining value."""
        ids = list(token_ids)
        if ids and ids[0] == StructuralTokens.START:
            ids = ids[1:]
        if ids and ids[-1] == StructuralTokens.END:
            ids = ids[:-1]

        result, _ = self._parse_value(ids, 0)
        return result

    def _parse_value(self, ids: list[int], pos: int) -> tuple[Any, int]:
        """Parse one value starting at ``pos``; return it and the next position.

        Unknown tokens (including ``[UNK]``) decode as ``None`` and are consumed.
        """
        if pos >= len(ids):
            return None, pos

        tid = ids[pos]
        if tid == StructuralTokens.OBJ_START:
            return self._parse_object(ids, pos)
        if tid == StructuralTokens.ARR_START:
            return self._parse_array(ids, pos)
        if tid == StructuralTokens.STR_DELIM:
            return self._parse_string(ids, pos)
        if tid == StructuralTokens.NUM_PREFIX:
            return self._parse_number(ids, pos)
        if tid == StructuralTokens.TRUE:
            return True, pos + 1
        if tid == StructuralTokens.FALSE:
            return False, pos + 1
        # NULL and anything unrecognised.
        return None, pos + 1

    def _parse_object(self, ids: list[int], pos: int) -> tuple[dict, int]:
        """Parse a JSON object whose OBJ_START token is at ``pos``."""
        assert ids[pos] == StructuralTokens.OBJ_START
        pos += 1
        result: dict[str, Any] = {}

        while pos < len(ids) and ids[pos] != StructuralTokens.OBJ_END:
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue

            key, pos = self._parse_key(ids, pos)

            # Tolerate a missing colon rather than failing.
            if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                pos += 1

            value, pos = self._parse_value(ids, pos)
            result[key] = value

        # A truncated sequence may lack the closing token.
        if pos < len(ids) and ids[pos] == StructuralTokens.OBJ_END:
            pos += 1
        return result, pos

    def _parse_array(self, ids: list[int], pos: int) -> tuple[list, int]:
        """Parse a JSON array whose ARR_START token is at ``pos``."""
        assert ids[pos] == StructuralTokens.ARR_START
        pos += 1
        result: list[Any] = []

        while pos < len(ids) and ids[pos] != StructuralTokens.ARR_END:
            if ids[pos] == StructuralTokens.COMMA:
                pos += 1
                continue
            value, pos = self._parse_value(ids, pos)
            result.append(value)

        if pos < len(ids) and ids[pos] == StructuralTokens.ARR_END:
            pos += 1
        return result, pos

    def _parse_key(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse an object key starting at ``pos``."""
        tid = ids[pos]

        # Keys from the key vocabulary are a single token.
        if tid in self._id_to_key:
            return self._id_to_key[tid], pos + 1

        # Out-of-vocabulary key: KEY_PREFIX followed by a run of BPE IDs.
        if tid == StructuralTokens.KEY_PREFIX:
            pos += 1
            bpe_tokens: list[str] = []
            while pos < len(ids) and ids[pos] >= self._bpe_offset:
                bpe_tokens.append(self._bpe.id_to_token(ids[pos] - self._bpe_offset))
                pos += 1
                # Defensive stop at the separator; only relevant if
                # bpe_offset <= COLON (e.g. an offset loaded from a config).
                if pos < len(ids) and ids[pos] == StructuralTokens.COLON:
                    break
            return self._bpe.decode_tokens(bpe_tokens), pos

        # Not a key token: emit a placeholder and keep parsing past it.
        return f"<unknown_key_{tid}>", pos + 1

    def _parse_string(self, ids: list[int], pos: int) -> tuple[str, int]:
        """Parse a ``[STR] bpe... [STR]`` string starting at ``pos``."""
        assert ids[pos] == StructuralTokens.STR_DELIM
        pos += 1

        bpe_tokens: list[str] = []
        while pos < len(ids) and ids[pos] != StructuralTokens.STR_DELIM:
            bpe_tokens.append(self._bpe.id_to_token(ids[pos] - self._bpe_offset))
            pos += 1

        # Consume the closing delimiter if present (truncated input may lack it).
        if pos < len(ids) and ids[pos] == StructuralTokens.STR_DELIM:
            pos += 1
        return self._bpe.decode_tokens(bpe_tokens), pos

    def _parse_number(self, ids: list[int], pos: int) -> tuple[Union[int, float], int]:
        """Parse a ``[NUM] bpe...`` number starting at ``pos``.

        The number ends at the first non-BPE token. Text that is neither an
        int nor a float decodes as ``0``.
        """
        assert ids[pos] == StructuralTokens.NUM_PREFIX
        pos += 1

        bpe_tokens: list[str] = []
        while pos < len(ids) and ids[pos] >= self._bpe_offset:
            bpe_tokens.append(self._bpe.id_to_token(ids[pos] - self._bpe_offset))
            pos += 1

        text = self._bpe.decode_tokens(bpe_tokens).strip()
        # Integers never contain "." or an exponent, so try int first. float()
        # then covers decimals, exponents and the "inf"/"nan" spellings that
        # repr() emits for non-finite floats.
        try:
            return int(text), pos
        except ValueError:
            pass
        try:
            return float(text), pos
        except ValueError:
            return 0, pos

    # ------------------------------------------------------------------
    # Inspection
    # ------------------------------------------------------------------

    def decode_tokens_readable(self, token_ids: list[int]) -> list[str]:
        """Convert token IDs to human-readable token names."""
        result: list[str] = []
        for tid in token_ids:
            if tid in self._id_to_token:
                result.append(self._id_to_token[tid])
            elif tid in self._id_to_key:
                result.append(f"Key({self._id_to_key[tid]})")
            elif tid < self._bpe_offset:
                # Reserved/structural ID with no lookup entry (e.g. before
                # training): never hand a negative ID to the BPE codec.
                result.append(StructuralTokens.name(tid))
            else:
                token_str = self._bpe.id_to_token(tid - self._bpe_offset)
                result.append(f"BPE({token_str!r})")
        return result

    def token_count(self, json_input: Union[str, Any]) -> int:
        """Return how many tokens ``encode(json_input)`` produces.

        The count includes the ``[START]``/``[END]`` wrapper tokens.
        """
        return len(self.encode(json_input))

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, directory: str) -> None:
        """Save the full tokenizer state to ``directory`` (created if missing).

        Writes ``bpe_model.json`` via the BPE codec's own ``save`` and
        ``tokenizer_config.json`` with the hyperparameters, key vocabulary
        and ID offsets.
        """
        import os

        os.makedirs(directory, exist_ok=True)
        self._bpe.save(os.path.join(directory, "bpe_model.json"))

        config = {
            "version": "json-tokenizer-v1",
            "bpe_vocab_size": self.bpe_vocab_size,
            "max_key_vocab": self.max_key_vocab,
            "min_key_freq": self.min_key_freq,
            "bpe_min_freq": self.bpe_min_freq,
            "key_vocab": self._key_to_id,
            "key_offset": self._key_offset,
            "bpe_offset": self._bpe_offset,
        }
        config_path = os.path.join(directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def load(cls, directory: str) -> "JSONTokenizer":
        """Load a tokenizer previously written by ``save`` from ``directory``."""
        import os

        with open(os.path.join(directory, "tokenizer_config.json"), encoding="utf-8") as f:
            config = json.load(f)

        tokenizer = cls(
            bpe_vocab_size=config["bpe_vocab_size"],
            max_key_vocab=config["max_key_vocab"],
            min_key_freq=config["min_key_freq"],
            bpe_min_freq=config["bpe_min_freq"],
        )

        tokenizer._key_to_id = config["key_vocab"]
        tokenizer._id_to_key = {int(v): k for k, v in config["key_vocab"].items()}
        tokenizer._key_offset = config["key_offset"]
        tokenizer._bpe_offset = config["bpe_offset"]

        tokenizer._bpe = BPETrainer.load(os.path.join(directory, "bpe_model.json"))

        tokenizer._build_vocab_lookup()
        tokenizer._trained = True
        return tokenizer
|
|